Initial version of the open-source distributed TensorFlow runtime.
This includes a gRPC server (grpc_tensorflow_server) that can serve as both the master of a distributed TensorFlow computation, and an individual worker in the computation. The GrpcSession class is included to allow client programs (including Python clients) to interact with a server. See tensorflow/core/distributed_runtime/README.md for usage instructions. This change partially addresses issue #23. Change: 115634191
This commit is contained in:
parent
d27da251bc
commit
00986d48bb
WORKSPACE
tensorflow
core
BUILD
distributed_runtime
BUILDREADME.mdbase_rendezvous_mgr.ccbase_rendezvous_mgr.hbuild_graph_options.ccbuild_graph_options.hcall_options.cccall_options.hcall_options_test.ccexecutor_test.ccgraph_mgr.ccgraph_mgr.hmaster.ccmaster.hmaster_env.hmaster_interface.hmaster_session.ccmaster_session.hmaster_session_interface.hmaster_test.ccprocess_util.ccprocess_util.hremote_device.ccremote_device.hremote_device_test.ccrendezvous_mgr_interface.h
rpc
BUILDasync_service_interface.hgrpc_call.hgrpc_channel.ccgrpc_channel.hgrpc_channel_test.ccgrpc_client_cq_tag.hgrpc_master_service.ccgrpc_master_service.hgrpc_remote_master.ccgrpc_remote_master.hgrpc_remote_worker.ccgrpc_remote_worker.hgrpc_server_lib.ccgrpc_server_lib.hgrpc_session.ccgrpc_session.hgrpc_session_test.ccgrpc_tensorflow_server.ccgrpc_tensorflow_server_lib.ccgrpc_testlib.ccgrpc_testlib.hgrpc_testlib_ops.ccgrpc_testlib_server.ccgrpc_util.hgrpc_worker_cache.ccgrpc_worker_cache.hgrpc_worker_service.ccgrpc_worker_service.hrpc_rendezvous_mgr.ccrpc_rendezvous_mgr.hrpc_rendezvous_mgr_test.cc
simple_graph_execution_state.ccsimple_graph_execution_state.hworker_cache.hworker_cache_logger.ccworker_cache_logger.hworker_cache_partial.ccworker_cache_partial.hworker_env.hworker_interface.hframework
platform/default
protobuf
python
31
WORKSPACE
31
WORKSPACE
@ -15,6 +15,37 @@
|
||||
load("//tensorflow:workspace.bzl", "tf_workspace")
|
||||
tf_workspace()
|
||||
|
||||
# grpc expects //external:protobuf_clib and //external:protobuf_compiler
|
||||
# to point to the protobuf's compiler library.
|
||||
bind(
|
||||
name = "protobuf_clib",
|
||||
actual = "//google/protobuf:protoc_lib",
|
||||
)
|
||||
|
||||
bind(
|
||||
name = "protobuf_compiler",
|
||||
actual = "//google/protobuf:protoc_lib",
|
||||
)
|
||||
|
||||
git_repository(
|
||||
name = "grpc",
|
||||
commit = "73979f4",
|
||||
init_submodules = True,
|
||||
remote = "https://github.com/grpc/grpc.git",
|
||||
)
|
||||
|
||||
# protobuf expects //external:grpc_cpp_plugin to point to grpc's
|
||||
# C++ plugin code generator.
|
||||
bind(
|
||||
name = "grpc_cpp_plugin",
|
||||
actual = "@grpc//:grpc_cpp_plugin",
|
||||
)
|
||||
|
||||
bind(
|
||||
name = "grpc_lib",
|
||||
actual = "@grpc//:grpc++_unsecure",
|
||||
)
|
||||
|
||||
# TENSORBOARD_BOWER_AUTOGENERATED_BELOW_THIS_LINE_DO_NOT_EDIT
|
||||
|
||||
new_git_repository(
|
||||
|
@ -61,6 +61,7 @@ load("//tensorflow:tensorflow.bzl", "tf_gpu_kernel_library")
|
||||
load(
|
||||
"//tensorflow/core:platform/default/build_config.bzl",
|
||||
"tf_proto_library",
|
||||
"tf_proto_library_cc",
|
||||
"tf_additional_lib_srcs",
|
||||
"tf_additional_stream_executor_srcs",
|
||||
"tf_additional_test_deps",
|
||||
@ -77,7 +78,15 @@ load(
|
||||
|
||||
tf_proto_library(
|
||||
name = "protos_all",
|
||||
srcs = glob(["**/*.proto"]),
|
||||
srcs = glob(
|
||||
["**/*.proto"],
|
||||
exclude = [
|
||||
"protobuf/worker.proto",
|
||||
"protobuf/worker_service.proto",
|
||||
"protobuf/master.proto",
|
||||
"protobuf/master_service.proto",
|
||||
],
|
||||
),
|
||||
cc_api_version = 2,
|
||||
go_api_version = 2,
|
||||
java_api_version = 2,
|
||||
@ -85,6 +94,54 @@ tf_proto_library(
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
tf_proto_library_cc(
|
||||
name = "worker_proto",
|
||||
srcs = ["protobuf/worker.proto"],
|
||||
cc_api_version = 2,
|
||||
cc_libs = [":protos_all_cc"],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
)
|
||||
|
||||
tf_proto_library_cc(
|
||||
name = "worker_service_proto",
|
||||
srcs = ["protobuf/worker_service.proto"],
|
||||
has_services = 1,
|
||||
cc_api_version = 2,
|
||||
cc_grpc_version = 1,
|
||||
cc_libs = [":worker_proto_cc"],
|
||||
cc_stubby_versions = ["2"],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
)
|
||||
|
||||
tf_proto_library_cc(
|
||||
name = "master_proto",
|
||||
srcs = ["protobuf/master.proto"],
|
||||
cc_api_version = 2,
|
||||
cc_libs = [":protos_all_cc"],
|
||||
py_api_version = 2,
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
)
|
||||
|
||||
tf_proto_library_cc(
|
||||
name = "master_service_proto",
|
||||
srcs = ["protobuf/master_service.proto"],
|
||||
has_services = 1,
|
||||
cc_api_version = 2,
|
||||
cc_grpc_version = 1,
|
||||
cc_libs = [":master_proto_cc"],
|
||||
cc_stubby_versions = ["2"],
|
||||
py_api_version = 2,
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "lib",
|
||||
hdrs = [
|
||||
|
306
tensorflow/core/distributed_runtime/BUILD
Normal file
306
tensorflow/core/distributed_runtime/BUILD
Normal file
@ -0,0 +1,306 @@
|
||||
# Description:
|
||||
# A distributed runtime for TensorFlow, which allows graph execution
|
||||
# to be distributed and performed in parallel across multiple
|
||||
# processes.
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
["**/*"],
|
||||
exclude = [
|
||||
"**/METADATA",
|
||||
"**/OWNERS",
|
||||
],
|
||||
),
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "c_srcs",
|
||||
data = glob([
|
||||
"**/*.cc",
|
||||
"**/*.h",
|
||||
]),
|
||||
)
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_tests")
|
||||
|
||||
# For platform specific build config
|
||||
load(
|
||||
"//tensorflow/core:platform/default/build_config.bzl",
|
||||
"tf_kernel_tests_linkstatic",
|
||||
)
|
||||
load(
|
||||
"//tensorflow/core:platform/default/build_config_root.bzl",
|
||||
"tf_cuda_tests_tags",
|
||||
)
|
||||
|
||||
package(default_visibility = [
|
||||
"//tensorflow:internal",
|
||||
])
|
||||
|
||||
cc_library(
|
||||
name = "worker_env",
|
||||
hdrs = ["worker_env.h"],
|
||||
deps = [],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "worker_interface",
|
||||
hdrs = ["worker_interface.h"],
|
||||
deps = [
|
||||
":call_options",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:worker_proto_cc",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "call_options",
|
||||
srcs = ["call_options.cc"],
|
||||
hdrs = ["call_options.h"],
|
||||
deps = [
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "call_options_test",
|
||||
size = "small",
|
||||
srcs = ["call_options_test.cc"],
|
||||
deps = [
|
||||
":call_options",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "worker_cache",
|
||||
hdrs = ["worker_cache.h"],
|
||||
deps = ["//tensorflow/core:protos_all_cc"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "remote_device",
|
||||
srcs = ["remote_device.cc"],
|
||||
hdrs = ["remote_device.h"],
|
||||
deps = [
|
||||
":process_util",
|
||||
":worker_cache",
|
||||
":worker_interface",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:tensorflow_opensource",
|
||||
"//tensorflow/core:worker_proto_cc",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "master_interface",
|
||||
hdrs = ["master_interface.h"],
|
||||
deps = [
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:master_proto_cc",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "master",
|
||||
srcs = ["master.cc"],
|
||||
hdrs = ["master.h"],
|
||||
deps = [
|
||||
":call_options",
|
||||
":master_env",
|
||||
":master_session_interface",
|
||||
":process_util",
|
||||
":remote_device",
|
||||
":worker_cache",
|
||||
":worker_interface",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:master_proto_cc",
|
||||
"//tensorflow/core:tensorflow_opensource",
|
||||
"//tensorflow/core:worker_proto_cc",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "master_session",
|
||||
srcs = ["master_session.cc"],
|
||||
hdrs = ["master_session.h"],
|
||||
deps = [
|
||||
":master_env",
|
||||
":master_session_interface",
|
||||
":process_util",
|
||||
":simple_graph_execution_state",
|
||||
":worker_cache",
|
||||
":worker_interface",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:master_proto_cc",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "build_graph_options",
|
||||
srcs = ["build_graph_options.cc"],
|
||||
hdrs = ["build_graph_options.h"],
|
||||
deps = [
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "simple_graph_execution_state",
|
||||
srcs = ["simple_graph_execution_state.cc"],
|
||||
hdrs = ["simple_graph_execution_state.h"],
|
||||
deps = [
|
||||
":build_graph_options",
|
||||
":process_util",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:worker_proto_cc",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "rendezvous_mgr_interface",
|
||||
srcs = [],
|
||||
hdrs = ["rendezvous_mgr_interface.h"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "master_session_interface",
|
||||
srcs = [],
|
||||
hdrs = ["master_session_interface.h"],
|
||||
deps = ["//tensorflow/core:lib"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "base_rendezvous_mgr",
|
||||
srcs = ["base_rendezvous_mgr.cc"],
|
||||
hdrs = ["base_rendezvous_mgr.h"],
|
||||
deps = [
|
||||
":process_util",
|
||||
":rendezvous_mgr_interface",
|
||||
":worker_cache",
|
||||
":worker_env",
|
||||
":worker_interface",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:tensorflow_opensource",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "master_env",
|
||||
hdrs = ["master_env.h"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "graph_mgr",
|
||||
srcs = ["graph_mgr.cc"],
|
||||
hdrs = ["graph_mgr.h"],
|
||||
deps = [
|
||||
":process_util",
|
||||
":rendezvous_mgr_interface",
|
||||
":worker_env",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:tensorflow_opensource",
|
||||
"//tensorflow/core:worker_proto_cc",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "process_util",
|
||||
srcs = ["process_util.cc"],
|
||||
hdrs = ["process_util.h"],
|
||||
deps = [
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:tensorflow_opensource",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "worker_cache_partial",
|
||||
srcs = ["worker_cache_partial.cc"],
|
||||
hdrs = ["worker_cache_partial.h"],
|
||||
deps = [
|
||||
":process_util",
|
||||
":worker_cache",
|
||||
":worker_interface",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:worker_proto_cc",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "worker_cache_logger",
|
||||
srcs = ["worker_cache_logger.cc"],
|
||||
hdrs = ["worker_cache_logger.h"],
|
||||
deps = [
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
)
|
||||
|
||||
# TODO(mrry): Move executor_test.cc to ../common_runtime when once it no longer depends
|
||||
# on grpc_testlib.
|
||||
tf_cc_tests(
|
||||
linkstatic = tf_kernel_tests_linkstatic(),
|
||||
tags = tf_cuda_tests_tags(),
|
||||
tests = [
|
||||
"executor_test.cc",
|
||||
"master_test.cc",
|
||||
"remote_device_test.cc",
|
||||
],
|
||||
deps = [
|
||||
"@grpc//:grpc++_unsecure",
|
||||
":master",
|
||||
":process_util",
|
||||
":remote_device",
|
||||
":worker_interface",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:master_proto_cc",
|
||||
"//tensorflow/core:master_service_proto_cc",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:tensorflow_opensource",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_channel",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_testlib",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache",
|
||||
],
|
||||
)
|
197
tensorflow/core/distributed_runtime/README.md
Normal file
197
tensorflow/core/distributed_runtime/README.md
Normal file
@ -0,0 +1,197 @@
|
||||
# Distributed TensorFlow
|
||||
|
||||
This directory contains the initial open-source implementation of the
|
||||
distributed TensorFlow runtime, using [gRPC](http://grpc.io) for inter-process
|
||||
communication.
|
||||
|
||||
## Quick start
|
||||
|
||||
To get started, you will need to build the TensorFlow server binary
|
||||
(`grpc_tensorflow_server`) and a gRPC-based client. Currently this is only
|
||||
available using the source-based installation of TensorFlow, but it will be
|
||||
included in future binary releases. You can build the server binary using one of
|
||||
the following commands:
|
||||
|
||||
```shell
|
||||
# CPU-only build.
|
||||
$ bazel build -c opt //tensorflow/core/distributed_runtime/rpc:grpc_tensorflow_server
|
||||
|
||||
# GPU build.
|
||||
$ bazel build -c opt --config=cuda //tensorflow/core/distributed_runtime/rpc:grpc_tensorflow_server
|
||||
```
|
||||
|
||||
If you build the latest Python (PIP) package from source, it will contain a
|
||||
gRPC-based client. If you are using a previous binary release, you may need to
|
||||
rebuild and install an up-to-date PIP package by following
|
||||
[these installation instructions](https://www.tensorflow.org/versions/master/get_started/os_setup.html#create-the-pip-package-and-install).
|
||||
|
||||
Once you have successfully built the distributed TensorFlow components, you can
|
||||
test your installation by starting a server as follows:
|
||||
|
||||
```shell
|
||||
# Start a TensorFlow server as a single-process "cluster".
|
||||
$ bazel-bin/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server \
|
||||
--cluster_spec='local|localhost:2222' --job_name=local --task_index=0 &
|
||||
```
|
||||
|
||||
...then start a Python interpreter and create a remote session:
|
||||
|
||||
```python
|
||||
$ python
|
||||
>>> import tensorflow as tf
|
||||
>>> c = tf.constant("Hello, distributed TensorFlow!")
|
||||
>>> sess = tf.Session("grpc://localhost:2222")
|
||||
>>> sess.run(c)
|
||||
'Hello, distributed TensorFlow!'
|
||||
```
|
||||
|
||||
## Cluster definition
|
||||
|
||||
The command-line arguments to `grpc_tensorflow_server` define the membership of a TensorFlow cluster. The `--cluster_spec` flag determines the set of processes in the cluster, as a list of *jobs*, each of which contains a list of *task* endpoints. All processes in the cluster must be started with the same `--cluster_spec`. Example values include:
|
||||
|
||||
<table>
|
||||
<tr><th><code>--cluster_spec='...'</code></th><th>Available tasks</th>
|
||||
<tr>
|
||||
<td><code>local|localhost:2222</code></td><td><code>/job:local/task:0</code></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><code>local|localhost:2222;localhost:2223</code></td><td><code>/job:local/task:0</code><br/><code>/job:local/task:1</code></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><code>worker|worker0:2222;worker1:2222;worker2:2222,</code><br/><code>ps|ps0:2222;ps1:2222</code></td><td><code>/job:worker/task:0</code><br/><code>/job:worker/task:1</code><br/><code>/job:worker/task:2</code><br/><code>/job:ps/task:0</code><br/><code>/job:ps/task:1</code></td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
The `--job_name` and `--task_index` flags indicate which task will run in this
|
||||
process, out of the jobs and tasks defined in `--cluster_spec`. For example,
|
||||
`--job_name=local --task_index=0` means that the process will be task
|
||||
`/job:local/task:0`, and TensorFlow devices in the process will have names
|
||||
starting with that prefix.
|
||||
|
||||
**N.B.** Manually specifying these command lines can be tedious, especially for
|
||||
large clusters. We are working on tools for launching tasks programmatically,
|
||||
e.g. using a cluster manager like [Kubernetes](http://kubernetes.io). If there
|
||||
are particular cluster managers for which you'd like to see support, please
|
||||
raise a [GitHub issue](https://github.com/tensorflow/tensorflow/issues).
|
||||
|
||||
## Specifying distributed devices in your model
|
||||
|
||||
To place operations on a particular process, you can use the same
|
||||
[`tf.device()`](https://www.tensorflow.org/versions/master/api_docs/python/framework.html#device)
|
||||
function that is used to specify whether ops run on the CPU or GPU. For example:
|
||||
|
||||
```python
|
||||
with tf.device("/job:ps/task:0"):
|
||||
weights_1 = tf.Variable(...)
|
||||
biases_1 = tf.Variable(...)
|
||||
|
||||
with tf.device("/job:ps/task:1"):
|
||||
weights_2 = tf.Variable(...)
|
||||
biases_2 = tf.Variable(...)
|
||||
|
||||
with tf.device("/job:worker/task:7"):
|
||||
input, labels = ...
|
||||
layer_1 = tf.nn.relu(tf.matmul(input, weights_1) + biases_1)
|
||||
logits = tf.nn.relu(tf.matmul(layer_1, weights_2) + biases_2)
|
||||
# ...
|
||||
train_op = ...
|
||||
|
||||
with tf.Session("grpc://worker7:2222") as sess:
|
||||
for _ in range(10000):
|
||||
sess.run(train_op)
|
||||
```
|
||||
|
||||
In the above example, the variables are created on two tasks in the `ps` job,
|
||||
and the compute-intensive part of the model is created in the `worker`
|
||||
job. TensorFlow will insert the appropriate data transfers between the jobs
|
||||
(from `ps` to `worker` for the forward pass, and from `worker` to `ps` for
|
||||
applying gradients).
|
||||
|
||||
## Replicated training
|
||||
|
||||
A common training configuration ("data parallel training") involves multiple
|
||||
tasks in a `worker` job training the same model, using shared parameters hosted
|
||||
in a one or more tasks in a `ps` job. Each task will typically run on a
|
||||
different machine. There are many ways to specify this structure in TensorFlow,
|
||||
and we are building libraries that will simplify the work of specifying a
|
||||
replicated model. Possible approaches include:
|
||||
|
||||
* Building a single graph containing one set of parameters (in `tf.Variable`
|
||||
nodes pinned to `/job:ps`), and multiple copies of the "model" pinned to
|
||||
different tasks in `/job:worker`. Each copy of the model can have a different
|
||||
`train_op`, and one or more client threads can call `sess.run(train_ops[i])`
|
||||
for each worker `i`. This implements *asynchronous* training.
|
||||
|
||||
This approach uses a single `tf.Session` whose target is one of the workers in
|
||||
the cluster.
|
||||
|
||||
* As above, but where the gradients from all workers are averaged. See the
|
||||
[CIFAR-10 multi-GPU trainer](https://www.tensorflow.org/code/tensorflow/models/image/cifar10/cifar10_multi_gpu_train.py)
|
||||
for an example of this form of replication. The implements *synchronous* training
|
||||
|
||||
* The "distributed trainer" approach uses multiple graphs—one per
|
||||
worker—where each graph contains one set of parameters (pinned to
|
||||
`/job:ps`) and one copy of the model (pinned to a particular
|
||||
`/job:worker/task:i`). The "container" mechanism is used to share variables
|
||||
between different graphs: when each variable is constructed, the optional
|
||||
`container` argument is specified with the same value in each copy of the
|
||||
graph. For large models, this can be more efficient, because the overall graph
|
||||
is smaller.
|
||||
|
||||
This approach uses multiple `tf.Session` objects: one per worker process,
|
||||
where the `target` of each is the address of a different worker. The
|
||||
`tf.Session` objects can all be created in a single Python client, or you can
|
||||
use multiple Python clients to better distribute the trainer load.
|
||||
|
||||
## Glossary
|
||||
|
||||
<dl>
|
||||
<dt>Client</dt>
|
||||
<dd>
|
||||
A client is typically a program that builds a TensorFlow graph and
|
||||
constructs a `tensorflow::Session` to interact with a cluster. Clients are
|
||||
typically written in Python or C++. A single client process can directly
|
||||
interact with multiple TensorFlow servers (see "Replicated training" above),
|
||||
and a single server can serve multiple clients.
|
||||
</dd>
|
||||
<dt>Cluster</dt>
|
||||
<dd>
|
||||
A TensorFlow cluster comprises one or more TensorFlow servers, divided into
|
||||
a set of named jobs, which in turn comprise lists of tasks. A cluster is
|
||||
typically dedicated to a particular high-level objective, such as training a
|
||||
neural network, using many machines in parallel.
|
||||
</dd>
|
||||
<dt>Job</dt>
|
||||
<dd>
|
||||
A job comprises a list of "tasks", which typically serve a common
|
||||
purpose. For example, a job named `ps` (for "parameter server") typically
|
||||
hosts nodes that store and update variables; while a job named `worker`
|
||||
typically hosts stateless nodes that perform compute-intensive tasks.
|
||||
The tasks in a job typically run on different machines.
|
||||
</dd>
|
||||
<dt>Master service</dt>
|
||||
<dd>
|
||||
An RPC service that provides remote access to a set of distributed
|
||||
devices. The master service implements the <code>tensorflow::Session</code>
|
||||
interface, and is responsible for coordinating work across one or more
|
||||
"worker services".
|
||||
</dd>
|
||||
<dt>Task</dt>
|
||||
<dd>
|
||||
A task typically corresponds to a single TensorFlow server process,
|
||||
belonging to a particular "job" and with a particular index within that
|
||||
job's list of tasks.
|
||||
</dd>
|
||||
|
||||
<dt>TensorFlow server</dt>
|
||||
<dd>
|
||||
A process running the <code>grpc_tensorflow_server</code> binary, which is a
|
||||
member of a cluster, and exports a "master service" and "worker service".
|
||||
</dd>
|
||||
<dt>Worker service</dt>
|
||||
<dd>
|
||||
An RPC service that executes parts of a TensorFlow graph using its local
|
||||
devices. A worker service implements <a
|
||||
href="./worker_service.proto"><code>worker_service.proto</code></a>.
|
||||
</dd>
|
||||
</dl>
|
318
tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc
Normal file
318
tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc
Normal file
@ -0,0 +1,318 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h"
|
||||
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/common_runtime/copy_tensor.h"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
#include "tensorflow/core/common_runtime/dma_helper.h"
|
||||
#include "tensorflow/core/distributed_runtime/process_util.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_cache.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_interface.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/strings/numbers.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
BaseRendezvousMgr::BaseRendezvousMgr(const WorkerEnv* env) : worker_env_(env) {}
|
||||
|
||||
BaseRendezvousMgr::~BaseRendezvousMgr() {
|
||||
for (auto& p : table_) {
|
||||
BaseRemoteRendezvous* rendez = p.second;
|
||||
rendez->StartAbort(errors::Aborted("Shutdown"));
|
||||
rendez->Unref();
|
||||
}
|
||||
}
|
||||
|
||||
Rendezvous* BaseRendezvousMgr::Find(int64 step_id) {
|
||||
return FindOrCreate(step_id);
|
||||
}
|
||||
|
||||
BaseRemoteRendezvous* BaseRendezvousMgr::FindOrCreate(int64 step_id) {
|
||||
mutex_lock l(mu_);
|
||||
Table::iterator iter = table_.find(step_id);
|
||||
if (iter == table_.end()) {
|
||||
auto rr = Create(step_id, worker_env_);
|
||||
iter = table_.insert({step_id, rr}).first;
|
||||
}
|
||||
iter->second->Ref();
|
||||
return iter->second;
|
||||
}
|
||||
|
||||
void BaseRendezvousMgr::RecvLocalAsync(int64 step_id, const string& key,
|
||||
Rendezvous::DoneCallback done) {
|
||||
BaseRemoteRendezvous* rendez = FindOrCreate(step_id);
|
||||
rendez->RecvLocalAsync(
|
||||
key, [rendez, done](const Status& s, const Rendezvous::Args& send_args,
|
||||
const Rendezvous::Args& recv_args, const Tensor& v,
|
||||
bool dead) {
|
||||
rendez->Unref();
|
||||
done(s, send_args, recv_args, v, dead);
|
||||
});
|
||||
}
|
||||
|
||||
Status BaseRendezvousMgr::RecvLocal(int64 step_id, const string& key,
|
||||
Tensor* val, bool* is_dead) {
|
||||
Status ret;
|
||||
Notification n;
|
||||
RecvLocalAsync(step_id, key,
|
||||
[val, is_dead, &ret, &n](const Status& s,
|
||||
const Rendezvous::Args& send_args,
|
||||
const Rendezvous::Args& recv_args,
|
||||
const Tensor& v, const bool dead) {
|
||||
ret = s;
|
||||
*val = v;
|
||||
*is_dead = dead;
|
||||
n.Notify();
|
||||
});
|
||||
n.WaitForNotification();
|
||||
return ret;
|
||||
}
|
||||
|
||||
void BaseRendezvousMgr::Cleanup(int64 step_id) {
|
||||
Rendezvous* rendez = nullptr;
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
Table::iterator iter = table_.find(step_id);
|
||||
if (iter != table_.end()) {
|
||||
rendez = iter->second;
|
||||
table_.erase(iter);
|
||||
}
|
||||
}
|
||||
if (!rendez) return;
|
||||
rendez->StartAbort(errors::Aborted("Cleanup ", step_id));
|
||||
rendez->Unref();
|
||||
}
|
||||
|
||||
void BaseRendezvousMgr::CleanupAll() {
|
||||
std::vector<Rendezvous*> rendezs;
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
for (const auto& entry : table_) {
|
||||
rendezs.push_back(entry.second);
|
||||
}
|
||||
table_.clear();
|
||||
}
|
||||
for (auto rendez : rendezs) {
|
||||
rendez->StartAbort(errors::Aborted("Shutdown"));
|
||||
rendez->Unref();
|
||||
}
|
||||
}
|
||||
|
||||
BaseRemoteRendezvous::BaseRemoteRendezvous(const WorkerEnv* env, int64 step_id,
|
||||
bool tolerate_dup_recv)
|
||||
: env_(env),
|
||||
step_id_(step_id),
|
||||
tolerate_dup_recv_(tolerate_dup_recv),
|
||||
local_(NewLocalRendezvous(tolerate_dup_recv)) {}
|
||||
|
||||
BaseRemoteRendezvous::~BaseRemoteRendezvous() {
|
||||
CHECK(active_.empty());
|
||||
local_->Unref();
|
||||
}
|
||||
|
||||
// Returns true if "device_name" is a valid full name of local device
|
||||
// of the "worker". This helper is purely based on the worker name
|
||||
// and device name and does no lookups in the worker->device_mgr.
|
||||
static bool IsLocalDevice(const WorkerEnv& worker,
|
||||
const StringPiece device_name) {
|
||||
return device_name.starts_with(worker.worker_name);
|
||||
}
|
||||
|
||||
Status BaseRemoteRendezvous::Send(const string& key,
|
||||
const Rendezvous::Args& args,
|
||||
const Tensor& val, const bool is_dead) {
|
||||
VLOG(1) << "BaseRemoteRendezvous Send " << this << " " << key;
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
if (!status_.ok()) return status_;
|
||||
}
|
||||
Rendezvous::ParsedKey parsed;
|
||||
TF_RETURN_IF_ERROR(Rendezvous::ParseKey(key, &parsed));
|
||||
if (!IsLocalDevice(*env_, parsed.src_device)) {
|
||||
return errors::InvalidArgument("Invalid rendezvous key (src): ", key, " @ ",
|
||||
env_->worker_name);
|
||||
}
|
||||
// Buffers "val" and "device_context" in local_.
|
||||
return local_->Send(key, args, val, is_dead);
|
||||
}
|
||||
|
||||
Status BaseRemoteRendezvous::ParseKey(const string& key, bool is_src,
|
||||
Rendezvous::ParsedKey* parsed) {
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
if (!status_.ok()) return status_;
|
||||
}
|
||||
TF_RETURN_IF_ERROR(Rendezvous::ParseKey(key, parsed));
|
||||
if (is_src && !IsLocalDevice(*env_, parsed->src_device)) {
|
||||
return errors::InvalidArgument("Invalid rendezvous key (src): ", key, " @ ",
|
||||
env_->worker_name);
|
||||
}
|
||||
if (!is_src && !IsLocalDevice(*env_, parsed->dst_device)) {
|
||||
return errors::InvalidArgument("Invalid rendezvous key (dst): ", key, " @ ",
|
||||
env_->worker_name);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void BaseRemoteRendezvous::SameWorkerRecvDone(
|
||||
const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& send_args,
|
||||
const Rendezvous::Args& recv_args, const Tensor& in, Tensor* out,
|
||||
StatusCallback done) {
|
||||
// Do a quick copy (sharing the underlying buffer) if both tensors
|
||||
// are on host memory.
|
||||
const bool src_host =
|
||||
(send_args.alloc_attrs.on_host() || parsed.src.type == "CPU");
|
||||
const bool dst_host =
|
||||
(recv_args.alloc_attrs.on_host() || parsed.dst.type == "CPU");
|
||||
if (src_host && dst_host) {
|
||||
*out = in;
|
||||
done(Status::OK());
|
||||
return;
|
||||
}
|
||||
|
||||
// This copy must involve a GPU. Hence, "in" must support DMA
|
||||
// (e.g., string tensors do not work on GPU).
|
||||
if (!DMAHelper::CanUseDMA(&in)) {
|
||||
done(errors::InvalidArgument("Non-DMA-safe ", DataTypeString(in.dtype()),
|
||||
" tensor may not be copied from/to a GPU."));
|
||||
return;
|
||||
}
|
||||
|
||||
Device* src_device;
|
||||
Status s = env_->device_mgr->LookupDevice(parsed.src_device, &src_device);
|
||||
if (!s.ok()) {
|
||||
done(s);
|
||||
return;
|
||||
}
|
||||
Device* dst_device;
|
||||
s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_device);
|
||||
if (!s.ok()) {
|
||||
done(s);
|
||||
return;
|
||||
}
|
||||
|
||||
AllocatorAttributes attr = recv_args.alloc_attrs;
|
||||
attr.set_gpu_compatible(send_args.alloc_attrs.gpu_compatible() ||
|
||||
recv_args.alloc_attrs.gpu_compatible());
|
||||
Allocator* out_allocator = dst_device->GetAllocator(attr);
|
||||
Tensor copy(out_allocator, in.dtype(), in.shape());
|
||||
*out = copy;
|
||||
|
||||
// The following function takes care of cpu->gpu, gpu->cpu, gpu->gpu copies,
|
||||
// etc.
|
||||
CopyTensor::ViaDMA(parsed.edge_name, send_args.device_context,
|
||||
recv_args.device_context, src_device, dst_device,
|
||||
send_args.alloc_attrs, recv_args.alloc_attrs, &in, out,
|
||||
done);
|
||||
}
|
||||
|
||||
bool BaseRemoteRendezvous::IsSameWorker(DeviceNameUtils::ParsedName src,
|
||||
DeviceNameUtils::ParsedName dst) {
|
||||
return DeviceNameUtils::IsSameAddressSpace(src, dst);
|
||||
}
|
||||
|
||||
void BaseRemoteRendezvous::RecvAsync(const string& key,
|
||||
const Rendezvous::Args& recv_args,
|
||||
DoneCallback done) {
|
||||
VLOG(1) << "RemoteRendezvous Recv " << this << " " << key;
|
||||
|
||||
Rendezvous::ParsedKey parsed;
|
||||
Status s = ParseKey(key, false /*!is_src*/, &parsed);
|
||||
if (!s.ok()) {
|
||||
done(s, Args(), recv_args, Tensor(), false);
|
||||
return;
|
||||
}
|
||||
|
||||
// Are src and dst in the same worker?
|
||||
if (IsSameWorker(parsed.src, parsed.dst)) {
|
||||
// Recv the tensor from local_.
|
||||
local_->RecvAsync(
|
||||
key, recv_args, [this, parsed, done](const Status& status,
|
||||
const Rendezvous::Args& send_args,
|
||||
const Rendezvous::Args& recv_args,
|
||||
const Tensor& in, bool is_dead) {
|
||||
Status s = status;
|
||||
Tensor* out = new Tensor;
|
||||
StatusCallback final_callback = [done, send_args, recv_args, out,
|
||||
is_dead](const Status& s) {
|
||||
done(s, send_args, recv_args, *out, is_dead);
|
||||
delete out;
|
||||
};
|
||||
|
||||
if (s.ok()) {
|
||||
SameWorkerRecvDone(parsed, send_args, recv_args, in, out,
|
||||
final_callback);
|
||||
} else {
|
||||
final_callback(s);
|
||||
}
|
||||
});
|
||||
return;
|
||||
} else {
|
||||
RecvFromRemoteAsync(key, parsed, recv_args, done);
|
||||
}
|
||||
}
|
||||
|
||||
void BaseRemoteRendezvous::RecvLocalAsync(const string& key,
|
||||
DoneCallback done) {
|
||||
Rendezvous::ParsedKey parsed;
|
||||
Status s = ParseKey(key, true /* is_src */, &parsed);
|
||||
if (!s.ok()) {
|
||||
done(s, Args(), Args(), Tensor(), false);
|
||||
return;
|
||||
}
|
||||
local_->RecvAsync(key, Args(), done);
|
||||
}
|
||||
|
||||
void BaseRemoteRendezvous::StartAbort(const Status& s) {
|
||||
CHECK(!s.ok());
|
||||
local_->StartAbort(s);
|
||||
{
|
||||
// Aborts all active RecvTensor calls.
|
||||
mutex_lock l(mu_);
|
||||
if (status_.ok()) {
|
||||
status_ = s;
|
||||
for (BaseRecvTensorCall* call : active_) {
|
||||
call->StartAbort(s);
|
||||
}
|
||||
active_.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void BaseRemoteRendezvous::RegisterCall(BaseRecvTensorCall* call) {
|
||||
mutex_lock l(mu_);
|
||||
if (!status_.ok()) {
|
||||
call->StartAbort(status_);
|
||||
} else {
|
||||
CHECK(active_.insert(call).second);
|
||||
}
|
||||
}
|
||||
|
||||
void BaseRemoteRendezvous::DeregisterCall(BaseRecvTensorCall* call) {
|
||||
mutex_lock l(mu_);
|
||||
active_.erase(call);
|
||||
}
|
||||
|
||||
} // end namespace tensorflow
|
212
tensorflow/core/distributed_runtime/base_rendezvous_mgr.h
Normal file
212
tensorflow/core/distributed_runtime/base_rendezvous_mgr.h
Normal file
@ -0,0 +1,212 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_BASE_RENDEZVOUS_MGR_H_
|
||||
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_BASE_RENDEZVOUS_MGR_H_
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_env.h"
|
||||
#include "tensorflow/core/framework/control_flow.h"
|
||||
#include "tensorflow/core/framework/rendezvous.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/device_name_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class BaseRemoteRendezvous;
|
||||
class BaseRecvTensorCall;
|
||||
|
||||
// RendezvousMgr keeps track of a set of local rendezvous instances.
|
||||
// All tensors sent by this worker are buffered in a RendezvousMgr
|
||||
// until the tensor is received. Each global unique "step_id"
|
||||
// corresponds to one local rendezvous instance managed by a
|
||||
// RendezvousMgr.
|
||||
//
|
||||
// E.g.,
|
||||
// Rendezvous* rendez = worker_env->rendezvous_mgr->Find(0x8935);
|
||||
// fork execution of a graph executor using "rendez" on thread 1;
|
||||
// fork execution of another graph executor using "rendez" on thread 2;
|
||||
// ...
|
||||
// join threads 1 and 2;
|
||||
//
|
||||
// In the example above, execution in thread 1 and 2 communicates with
|
||||
// each other by send/recv operations through `rendez`.
|
||||
//
|
||||
// Tensors sent and received through a rendezvous managed by this
|
||||
// RendezvousMgr must have keys generated by Rendezvous::CreateKey().
|
||||
class BaseRendezvousMgr : public RendezvousMgrInterface {
|
||||
public:
|
||||
explicit BaseRendezvousMgr(const WorkerEnv* worker_env);
|
||||
~BaseRendezvousMgr() override;
|
||||
|
||||
// Returns Rendezvous supporting send and recv among workers in the
|
||||
// "step_id". The caller takes ownership of one reference on the
|
||||
// returned Rendezvous instance.
|
||||
Rendezvous* Find(int64 step_id) override;
|
||||
|
||||
// Finds the local rendezvous instance for the "step_id". Runs
|
||||
// "done" when the tensor for "key" is produced or an error occurs.
|
||||
//
|
||||
// This method is used by the rpc handler of RecvTensor.
|
||||
void RecvLocalAsync(int64 step_id, const string& key,
|
||||
Rendezvous::DoneCallback done) override;
|
||||
|
||||
// Synchronous wrapper for RecvLocalAsync.
|
||||
Status RecvLocal(int64 step_id, const string& key, Tensor* val,
|
||||
bool* is_dead) override;
|
||||
|
||||
// Removes rendezvous for "step_id".
|
||||
//
|
||||
// TODO(zhifengc): Have a background thread in worker that
|
||||
// periodically calls CleanupAll().
|
||||
void Cleanup(int64 step_id) override;
|
||||
|
||||
// Removed all rendezvous.
|
||||
void CleanupAll() override;
|
||||
|
||||
protected:
|
||||
virtual BaseRemoteRendezvous* Create(int64 step_id,
|
||||
const WorkerEnv* worker_env) = 0;
|
||||
|
||||
private:
|
||||
// Maps step_id to rendezvous.
|
||||
typedef std::unordered_map<int64, BaseRemoteRendezvous*> Table;
|
||||
|
||||
// Not owned.
|
||||
const WorkerEnv* const worker_env_;
|
||||
|
||||
mutex mu_;
|
||||
Table table_ GUARDED_BY(mu_);
|
||||
|
||||
BaseRemoteRendezvous* FindOrCreate(int64 step_id);
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(BaseRendezvousMgr);
|
||||
};
|
||||
|
||||
// RemoteRendezvous is a Rendezvous which can handle either
|
||||
// the producer or consumer being in a remote process.
|
||||
//
|
||||
// Buffering of Tensor values is delegated to a "local" Rendezvous
|
||||
// obtained from NewLocalRendezvous(). This class just adds
|
||||
// functionality to coordinate with remote workers.
|
||||
class BaseRemoteRendezvous : public Rendezvous {
|
||||
public:
|
||||
BaseRemoteRendezvous(const WorkerEnv* env, int64 step_id,
|
||||
bool tolerate_dup_recv);
|
||||
|
||||
// Forwards to local_, where the Tensor "val" will be buffered and
|
||||
// any waiting callback stored.
|
||||
Status Send(const string& key, const Rendezvous::Args& args,
|
||||
const Tensor& val, const bool is_dead) override;
|
||||
|
||||
// This method is called only by the RecvOp. It tests to see
|
||||
// whether the value will be produced by a local or remote device
|
||||
// and handles accordingly. In the local case it forwards to
|
||||
// local_, in the remote case it initiates an RPC request.
|
||||
void RecvAsync(const string& key, const Rendezvous::Args& args,
|
||||
DoneCallback done) override;
|
||||
|
||||
void StartAbort(const Status& status) override;
|
||||
|
||||
// This method is called only by the local Worker, forwarded through
|
||||
// the same method on RendezvousMgr. This occurs when the Worker
|
||||
// has received a RecvTensor request, either locally or over the
|
||||
// network. In either case it needs to retrieve a locally buffered
|
||||
// value from local_, and give it to its caller.
|
||||
//
|
||||
// Runs "done" as soon as the tensor for "key" is available or an error
|
||||
// is detected.
|
||||
//
|
||||
// REQUIRES: "key" is one that will be Saved into the local rendezvous.
|
||||
void RecvLocalAsync(const string& key, DoneCallback done);
|
||||
|
||||
protected:
|
||||
virtual void RecvFromRemoteAsync(const string& key,
|
||||
const Rendezvous::ParsedKey& parsed,
|
||||
const Rendezvous::Args& args,
|
||||
DoneCallback done) = 0;
|
||||
|
||||
// Returns true if "src" and "dst" are located in the same worker,
|
||||
// and hence may use a local rendezvous.
|
||||
virtual bool IsSameWorker(DeviceNameUtils::ParsedName src,
|
||||
DeviceNameUtils::ParsedName dst);
|
||||
|
||||
// If aborted, aborts "call". Otherwise, adds "call" into active_.
|
||||
void RegisterCall(BaseRecvTensorCall* call);
|
||||
|
||||
// Removes "call" from active_ if "call" is in active_.
|
||||
void DeregisterCall(BaseRecvTensorCall* call);
|
||||
|
||||
~BaseRemoteRendezvous() override;
|
||||
|
||||
const WorkerEnv* const env_; // Not owned.
|
||||
const int64 step_id_;
|
||||
|
||||
private:
|
||||
const bool tolerate_dup_recv_;
|
||||
Rendezvous* local_; // Owns a Ref on this object.
|
||||
|
||||
mutable mutex mu_;
|
||||
|
||||
// Status given by StartAbort() if any.
|
||||
Status status_ GUARDED_BY(mu_);
|
||||
|
||||
// Active outstanding RecvTensor calls.
|
||||
std::unordered_set<BaseRecvTensorCall*> active_ GUARDED_BY(mu_);
|
||||
|
||||
// Parses "key" into "parsed". If "is_src" is true, checks that the
|
||||
// rendezvous key's source is in this process. If "is_src" is false,
|
||||
// checks that the rendezvous key's destination is in this process.
|
||||
Status ParseKey(const string& key, bool is_src,
|
||||
Rendezvous::ParsedKey* parsed);
|
||||
|
||||
// Callback handling the case when a rendezvous has been
|
||||
// accomplished in local_ and the consumer is local to this process.
|
||||
// Tensor "in" will be copied into "out". The key "parsed" encodes
|
||||
// the src and dst devices.
|
||||
void SameWorkerRecvDone(const Rendezvous::ParsedKey& parsed,
|
||||
const Rendezvous::Args& in_args,
|
||||
const Rendezvous::Args& out_args, const Tensor& in,
|
||||
Tensor* out, StatusCallback done);
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(BaseRemoteRendezvous);
|
||||
};
|
||||
|
||||
class BaseRecvTensorCall {
|
||||
public:
|
||||
BaseRecvTensorCall() {}
|
||||
virtual ~BaseRecvTensorCall() {}
|
||||
|
||||
virtual void Start(std::function<void()> recv_done) = 0;
|
||||
|
||||
virtual void StartAbort(const Status& s) = 0;
|
||||
|
||||
virtual Status status() const = 0;
|
||||
|
||||
private:
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(BaseRecvTensorCall);
|
||||
};
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_BASE_RENDEZVOUS_MGR_H_
|
38
tensorflow/core/distributed_runtime/build_graph_options.cc
Normal file
38
tensorflow/core/distributed_runtime/build_graph_options.cc
Normal file
@ -0,0 +1,38 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/build_graph_options.h"
|
||||
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
string BuildGraphOptions::DebugString() const {
|
||||
string rv = "Feed endpoints: ";
|
||||
for (auto& s : feed_endpoints) {
|
||||
strings::StrAppend(&rv, s, ", ");
|
||||
}
|
||||
strings::StrAppend(&rv, "\nFetch endpoints: ");
|
||||
for (auto& s : fetch_endpoints) {
|
||||
strings::StrAppend(&rv, s, ", ");
|
||||
}
|
||||
strings::StrAppend(&rv, "\nTarget nodes: ");
|
||||
for (auto& s : target_nodes) {
|
||||
strings::StrAppend(&rv, s, ", ");
|
||||
}
|
||||
return rv;
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
38
tensorflow/core/distributed_runtime/build_graph_options.h
Normal file
38
tensorflow/core/distributed_runtime/build_graph_options.h
Normal file
@ -0,0 +1,38 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_BUILD_GRAPH_OPTIONS_H_
|
||||
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_BUILD_GRAPH_OPTIONS_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
struct BuildGraphOptions {
|
||||
std::vector<string> feed_endpoints;
|
||||
std::vector<string> fetch_endpoints;
|
||||
|
||||
// TODO(vrv): Remove this when we unify target_nodes and fetch_endpoint,
|
||||
// the former via "ref" fetch_endpoints.
|
||||
std::vector<string> target_nodes;
|
||||
|
||||
string DebugString() const;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_BUILD_GRAPH_OPTIONS_H_
|
44
tensorflow/core/distributed_runtime/call_options.cc
Normal file
44
tensorflow/core/distributed_runtime/call_options.cc
Normal file
@ -0,0 +1,44 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/call_options.h"
|
||||
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
CallOptions::CallOptions() {}
|
||||
|
||||
void CallOptions::StartCancel() {
|
||||
mutex_lock l(mu_);
|
||||
if (cancel_func_ != nullptr) {
|
||||
// NOTE: We must call the cancel_func_ with mu_ held. This ensure
|
||||
// that ClearCancelCallback() does not race with StartCancel().
|
||||
cancel_func_();
|
||||
// NOTE: We can clear cancel_func_ if needed.
|
||||
}
|
||||
}
|
||||
|
||||
void CallOptions::SetCancelCallback(CancelFunction cancel_func) {
|
||||
mutex_lock l(mu_);
|
||||
cancel_func_ = cancel_func;
|
||||
}
|
||||
|
||||
void CallOptions::ClearCancelCallback() {
|
||||
mutex_lock l(mu_);
|
||||
cancel_func_ = nullptr;
|
||||
}
|
||||
|
||||
} // end namespace tensorflow
|
72
tensorflow/core/distributed_runtime/call_options.h
Normal file
72
tensorflow/core/distributed_runtime/call_options.h
Normal file
@ -0,0 +1,72 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CALL_OPTIONS_H_
|
||||
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CALL_OPTIONS_H_
|
||||
|
||||
#include <functional>
|
||||
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Options passed to interface calls. This class provides portable
|
||||
// functionality across different RPC systems on top of
|
||||
// platform-specific mechanisms (for client and server contexts,
|
||||
// cancellation, etc.).
|
||||
//
|
||||
// TODO(zhifengc): Maybe change all RPC methods to take CallOptions.
|
||||
class CallOptions {
|
||||
public:
|
||||
CallOptions();
|
||||
|
||||
// Cancellation.
|
||||
//
|
||||
// The caller may call StartCancel() anytime as long as this
|
||||
// CallOptions object is alive. The callee may or may not receive
|
||||
// the cancellation notification depending on the rpc layer
|
||||
// implementation.
|
||||
void StartCancel();
|
||||
|
||||
// The callee (the rpc layer implementation) must set a cancellation
|
||||
// notifier before its blocking operation and clear the notifier
|
||||
// before the call returns.
|
||||
//
|
||||
// "cancel_func" may be called zero, once or more time. Therefore, it
|
||||
// should _not_ be responsible for memory management of any objects.
|
||||
//
|
||||
// "cancel_func" must be very light-weight. It should not block on
|
||||
// IO or locking. Typically, it just calls the rpc implementation
|
||||
// layer's specific cancellation mechanism and does nothing else.
|
||||
//
|
||||
// NOTE: "cancel_func" itself is pass-by-value. Therefore, we do not
|
||||
// worry about its ownership here.
|
||||
typedef std::function<void()> CancelFunction;
|
||||
void SetCancelCallback(CancelFunction cancel_func);
|
||||
void ClearCancelCallback();
|
||||
|
||||
private:
|
||||
mutex mu_;
|
||||
CancelFunction cancel_func_ GUARDED_BY(mu_);
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(CallOptions);
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CALL_OPTIONS_H_
|
39
tensorflow/core/distributed_runtime/call_options_test.cc
Normal file
39
tensorflow/core/distributed_runtime/call_options_test.cc
Normal file
@ -0,0 +1,39 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/call_options.h"
|
||||
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
TEST(CallOptions, Cancel) {
|
||||
int num_calls = 0;
|
||||
CallOptions opts;
|
||||
opts.StartCancel();
|
||||
EXPECT_EQ(num_calls, 0);
|
||||
opts.SetCancelCallback([&num_calls]() { num_calls++; });
|
||||
EXPECT_EQ(num_calls, 0);
|
||||
opts.StartCancel();
|
||||
EXPECT_EQ(num_calls, 1);
|
||||
opts.StartCancel();
|
||||
EXPECT_EQ(num_calls, 2);
|
||||
opts.ClearCancelCallback();
|
||||
EXPECT_EQ(num_calls, 2);
|
||||
opts.StartCancel();
|
||||
EXPECT_EQ(num_calls, 2);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
407
tensorflow/core/distributed_runtime/executor_test.cc
Normal file
407
tensorflow/core/distributed_runtime/executor_test.cc
Normal file
@ -0,0 +1,407 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/common_runtime/executor.h"
|
||||
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
|
||||
#include "tensorflow/core/common_runtime/step_stats_collector.h"
|
||||
#include "tensorflow/core/distributed_runtime/process_util.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/rendezvous.h"
|
||||
#include "tensorflow/core/framework/step_stats.pb.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/random/simple_philox.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
#include "tensorflow/core/platform/tracing.h"
|
||||
#include "tensorflow/core/public/session_options.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class ExecutorTest : public ::testing::Test {
|
||||
protected:
|
||||
ExecutorTest()
|
||||
: device_(DeviceFactory::NewDevice("CPU", {},
|
||||
"/job:localhost/replica:0/task:0")),
|
||||
|
||||
step_stats_collector_(&step_stats_) {
|
||||
SessionOptions options;
|
||||
thread_pool_ = ComputePool(options);
|
||||
}
|
||||
|
||||
~ExecutorTest() override {
|
||||
// There should always be exactly one Ref left on the Rendezvous
|
||||
// when the test completes.
|
||||
CHECK(rendez_->Unref());
|
||||
delete exec_;
|
||||
delete device_;
|
||||
}
|
||||
|
||||
// Resets executor_ with a new executor based on a graph 'gdef'.
|
||||
void Create(const Graph* graph) {
|
||||
const int version = graph->versions().producer();
|
||||
LocalExecutorParams params;
|
||||
params.device = device_;
|
||||
params.create_kernel = [this, version](const NodeDef& ndef,
|
||||
OpKernel** kernel) {
|
||||
return CreateNonCachedKernel(device_, nullptr, ndef, version, kernel);
|
||||
};
|
||||
params.delete_kernel = [](OpKernel* kernel) {
|
||||
DeleteNonCachedKernel(kernel);
|
||||
};
|
||||
delete exec_;
|
||||
TF_CHECK_OK(NewLocalExecutor(params, graph, &exec_));
|
||||
runner_ = [this](std::function<void()> fn) { thread_pool_->Schedule(fn); };
|
||||
rendez_ = NewLocalRendezvous();
|
||||
}
|
||||
|
||||
Status Run(Rendezvous* rendez) {
|
||||
Executor::Args args;
|
||||
args.rendezvous = rendez;
|
||||
args.stats_collector = &step_stats_collector_;
|
||||
args.runner = runner_;
|
||||
return exec_->Run(args);
|
||||
}
|
||||
|
||||
thread::ThreadPool* thread_pool_ = nullptr;
|
||||
Device* device_ = nullptr;
|
||||
Executor* exec_ = nullptr;
|
||||
StepStatsCollector step_stats_collector_;
|
||||
StepStats step_stats_;
|
||||
Executor::Args::Runner runner_;
|
||||
Rendezvous* rendez_ = nullptr;
|
||||
};
|
||||
|
||||
// A float val -> Tensor<float>
|
||||
Tensor V(const float val) {
|
||||
Tensor tensor(DT_FLOAT, TensorShape({}));
|
||||
tensor.scalar<float>()() = val;
|
||||
return tensor;
|
||||
}
|
||||
|
||||
// An int32 val -> Tensor<int32>
|
||||
Tensor VI(const int32 val) {
|
||||
Tensor tensor(DT_INT32, TensorShape({}));
|
||||
tensor.scalar<int32>()() = val;
|
||||
return tensor;
|
||||
}
|
||||
|
||||
// A bool val -> Tensor<bool>
|
||||
Tensor VB(const bool val) {
|
||||
Tensor tensor(DT_BOOL, TensorShape({}));
|
||||
tensor.scalar<bool>()() = val;
|
||||
return tensor;
|
||||
}
|
||||
|
||||
// A double val -> Tensor<double>
|
||||
Tensor VD(const double val) {
|
||||
Tensor tensor(DT_DOUBLE, TensorShape({}));
|
||||
tensor.scalar<double>()() = val;
|
||||
return tensor;
|
||||
}
|
||||
|
||||
// Tensor<float> -> a float val.
|
||||
float V(const Tensor& tensor) {
|
||||
CHECK_EQ(tensor.dtype(), DT_FLOAT);
|
||||
CHECK(TensorShapeUtils::IsScalar(tensor.shape()));
|
||||
return tensor.scalar<float>()();
|
||||
}
|
||||
|
||||
static uint64 kIncarnation = 1;  // Used in the following tests.
|
||||
|
||||
string Key(const string& sender, const uint64 incarnation,
|
||||
const string& receiver, const string& name) {
|
||||
return Rendezvous::CreateKey(sender, incarnation, receiver, name,
|
||||
FrameAndIter(0, 0));
|
||||
}
|
||||
|
||||
#define ALICE "/job:j/replica:0/task:0/cpu:0"
|
||||
#define BOB "/job:j/replica:0/task:0/gpu:0"
|
||||
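// A hedged illustration (not part of the original test file): given the ALICE
// and BOB device names above, Key(ALICE, 1, BOB, "a") builds a single
// ';'-separated rendezvous key naming the sending device, the sender's
// incarnation, the receiving device, the tensor name, and the frame/iteration,
// roughly:
//
//   /job:j/replica:0/task:0/cpu:0;<incarnation>;/job:j/replica:0/task:0/gpu:0;a;0:0
//
// The exact incarnation encoding is up to Rendezvous::CreateKey; the tests
// below only rely on Send() and Recv() agreeing on the same key string.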
|
||||
TEST_F(ExecutorTest, SimpleAdd) {
|
||||
// c = a + b
|
||||
Graph* g = new Graph(OpRegistry::Global());
|
||||
auto in0 = test::graph::Recv(g, "a", "float", ALICE, 1, BOB);
|
||||
auto in1 = test::graph::Recv(g, "b", "float", ALICE, 1, BOB);
|
||||
auto tmp = test::graph::Add(g, in0, in1);
|
||||
test::graph::Send(g, tmp, "c", BOB, 1, ALICE);
|
||||
Create(g);
|
||||
Rendezvous::Args args;
|
||||
TF_ASSERT_OK(rendez_->Send(Key(ALICE, kIncarnation, BOB, "a"), args, V(1.0),
|
||||
false)); // in0 = 1.0
|
||||
TF_ASSERT_OK(rendez_->Send(Key(ALICE, kIncarnation, BOB, "b"), args, V(1.0),
|
||||
false)); // in1 = 1.0
|
||||
TF_ASSERT_OK(Run(rendez_));
|
||||
Tensor out = V(-1);
|
||||
bool is_dead = false;
|
||||
TF_ASSERT_OK(
|
||||
rendez_->Recv(Key(BOB, kIncarnation, ALICE, "c"), args, &out, &is_dead));
|
||||
EXPECT_EQ(2.0, V(out)); // out = 1.0 + 1.0 = 2.0
|
||||
}
|
||||
|
||||
TEST_F(ExecutorTest, SelfAdd) {
|
||||
// v0 <- a
|
||||
// v1 = v0 + v0
|
||||
// v2 = v1 + v1
|
||||
// ... ...
|
||||
// v10 = v9 + v9
|
||||
//
|
||||
// b <- v10
|
||||
// All nodes are executed by one thread.
|
||||
Graph* g = new Graph(OpRegistry::Global());
|
||||
auto v = test::graph::Recv(g, "a", "float", ALICE, 1, BOB);
|
||||
const int N = 10;
|
||||
for (int i = 1; i <= N; ++i) {
|
||||
v = test::graph::Add(g, v, v);
|
||||
}
|
||||
// out <- v10
|
||||
test::graph::Send(g, v, "b", BOB, 1, ALICE);
|
||||
Create(g);
|
||||
Rendezvous::Args args;
|
||||
// a = 1.0
|
||||
TF_ASSERT_OK(
|
||||
rendez_->Send(Key(ALICE, kIncarnation, BOB, "a"), args, V(1.0), false));
|
||||
TF_ASSERT_OK(Run(rendez_));
|
||||
Tensor out = V(-1);
|
||||
bool is_dead = false;
|
||||
TF_ASSERT_OK(
|
||||
rendez_->Recv(Key(BOB, kIncarnation, ALICE, "b"), args, &out, &is_dead));
|
||||
EXPECT_EQ(1024.0, V(out)); // b=v10=2*v9=4*v8=...=1024*a=1024.0
|
||||
}
|
||||
|
||||
// Builds a graph which adds N copies of one variable "in". I.e.,
|
||||
// a + a + a + ... + a
|
||||
// The returned graph is parenthesized randomly. I.e.,
|
||||
// a + ((a + a) + a)
|
||||
// (a + a) + (a + a)
|
||||
// ((a + a) + a) + a
|
||||
// are all possibly generated.
|
||||
void BuildTree(int N, Graph* g) {
|
||||
CHECK_GT(N, 1);
|
||||
// A single input node "in".
|
||||
auto in = test::graph::Recv(g, "a", "float", ALICE, 1, BOB);
|
||||
std::vector<Node*> nodes;
|
||||
int i = 0;
|
||||
// Duplicate "in" N times. Each copy is named l0, l1, l2, ....
|
||||
for (; i < N; ++i) {
|
||||
nodes.push_back(test::graph::Identity(g, in, 0));
|
||||
}
|
||||
random::PhiloxRandom philox(testing::RandomSeed(), 17);
|
||||
random::SimplePhilox rnd(&philox);
|
||||
while (nodes.size() > 1) {
|
||||
// Randomly pick two from nodes and add them. The resulting node
|
||||
// is named like n10, n11, .... and is put back into "nodes".
|
||||
int x = rnd.Uniform(nodes.size());
|
||||
auto in0 = nodes[x];
|
||||
nodes[x] = nodes.back();
|
||||
nodes.resize(nodes.size() - 1);
|
||||
x = rnd.Uniform(nodes.size());
|
||||
auto in1 = nodes[x];
|
||||
// node = in0 + in1.
|
||||
nodes[x] = test::graph::Add(g, in0, in1);
|
||||
}
|
||||
// The final output node "out".
|
||||
test::graph::Send(g, nodes.back(), "b", BOB, 1, ALICE);
|
||||
}
|
||||
|
||||
TEST_F(ExecutorTest, RandomTree) {
|
||||
Graph* g = new Graph(OpRegistry::Global());
|
||||
BuildTree(4096, g);
|
||||
Create(g);
|
||||
Rendezvous::Args args;
|
||||
TF_ASSERT_OK(
|
||||
rendez_->Send(Key(ALICE, kIncarnation, BOB, "a"), args, V(1.0), false));
|
||||
TF_ASSERT_OK(Run(rendez_));
|
||||
Tensor out = V(-1);
|
||||
bool is_dead = false;
|
||||
TF_ASSERT_OK(
|
||||
rendez_->Recv(Key(BOB, kIncarnation, ALICE, "b"), args, &out, &is_dead));
|
||||
EXPECT_EQ(4096.0, V(out));
|
||||
}
|
||||
|
||||
void BuildConcurrentAddAssign(Graph* g) {
|
||||
auto one = test::graph::Constant(g, V(1.0));
|
||||
// A variable holds one float.
|
||||
auto var = test::graph::Var(g, DT_FLOAT, TensorShape({}));
|
||||
// Initialize the variable with 1.0.
|
||||
auto init = test::graph::Assign(g, var, one);
|
||||
// Output
|
||||
auto out = test::graph::Send(g, var, "out", ALICE, kIncarnation, BOB);
|
||||
// Have many concurrent computations. Each does v = v + 1.
|
||||
for (int i = 0; i < 1024; ++i) {
|
||||
auto add = test::graph::Add(g, var, one);
|
||||
g->AddControlEdge(init, add); // Ensures run after init.
|
||||
auto assign = test::graph::Assign(g, var, add);
|
||||
g->AddControlEdge(assign, out);
|
||||
}
|
||||
}
|
||||
|
||||
#ifndef THREAD_SANITIZER
|
||||
TEST_F(ExecutorTest, ConcurrentAddAssign) {
|
||||
Graph* g = new Graph(OpRegistry::Global());
|
||||
BuildConcurrentAddAssign(g);
|
||||
Create(g);
|
||||
for (int iters = 0; iters < 16; ++iters) {
|
||||
Rendezvous* rendez = NewLocalRendezvous();
|
||||
TF_ASSERT_OK(Run(rendez));
|
||||
Rendezvous::Args args;
|
||||
Tensor out;
|
||||
bool is_dead;
|
||||
TF_ASSERT_OK(rendez->Recv(Key(ALICE, kIncarnation, BOB, "out"), args, &out,
|
||||
&is_dead));
|
||||
VLOG(1) << "Get " << V(out);
|
||||
EXPECT_LE(V(out), 1025.0);
|
||||
rendez->Unref();
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
TEST_F(ExecutorTest, SimpleSwitchLive) {
|
||||
Graph* g = new Graph(OpRegistry::Global());
|
||||
auto in0 = test::graph::Recv(g, "a", "float", ALICE, 1, BOB);
|
||||
auto in1 = test::graph::Constant(g, VB(false));
|
||||
auto tmp = test::graph::Switch(g, in0, in1);
|
||||
test::graph::Send(g, tmp, "c", BOB, 1, ALICE);
|
||||
Create(g);
|
||||
Rendezvous::Args args;
|
||||
TF_ASSERT_OK(rendez_->Send(Key(ALICE, kIncarnation, BOB, "a"), args, V(1.0),
|
||||
false)); // in0 = 1.0
|
||||
TF_ASSERT_OK(Run(rendez_));
|
||||
Tensor out = V(-1);
|
||||
bool is_dead = false;
|
||||
TF_ASSERT_OK(
|
||||
rendez_->Recv(Key(BOB, kIncarnation, ALICE, "c"), args, &out, &is_dead));
|
||||
EXPECT_EQ(1.0, V(out)); // out = 1.0
|
||||
EXPECT_FALSE(is_dead);
|
||||
}
|
||||
|
||||
TEST_F(ExecutorTest, SimpleSwitchDead) {
|
||||
Graph* g = new Graph(OpRegistry::Global());
|
||||
auto in0 = test::graph::Recv(g, "a", "float", ALICE, 1, BOB);
|
||||
auto in1 = test::graph::Constant(g, VB(true));
|
||||
auto tmp = test::graph::Switch(g, in0, in1);
|
||||
test::graph::Send(g, tmp, "c", BOB, 1, ALICE);
|
||||
Create(g);
|
||||
Rendezvous::Args args;
|
||||
TF_ASSERT_OK(rendez_->Send(Key(ALICE, kIncarnation, BOB, "a"), args, V(1.0),
|
||||
false)); // in0 = 1.0
|
||||
TF_ASSERT_OK(Run(rendez_));
|
||||
Tensor out = V(-1);
|
||||
bool is_dead = false;
|
||||
TF_ASSERT_OK(
|
||||
rendez_->Recv(Key(BOB, kIncarnation, ALICE, "c"), args, &out, &is_dead));
|
||||
EXPECT_TRUE(is_dead);
|
||||
}
|
||||
|
||||
TEST_F(ExecutorTest, Abort) {
|
||||
// e = a + b + c + d
|
||||
Graph* g = new Graph(OpRegistry::Global());
|
||||
auto in0 = test::graph::Recv(g, "a", "float", ALICE, 1, BOB);
|
||||
auto in1 = test::graph::Recv(g, "b", "float", ALICE, 1, BOB);
|
||||
auto in2 = test::graph::Recv(g, "c", "float", ALICE, 1, BOB);
|
||||
auto in3 = test::graph::Recv(g, "d", "float", ALICE, 1, BOB);
|
||||
auto add0 = test::graph::Add(g, in0, in1);
|
||||
auto add1 = test::graph::Add(g, in2, in3);
|
||||
auto add2 = test::graph::Add(g, add0, add1);
|
||||
test::graph::Send(g, add2, "e", BOB, 1, ALICE);
|
||||
Create(g);
|
||||
|
||||
// Needs 4 inputs (recv). One of them is aborted.
|
||||
rendez_->Ref();
|
||||
SchedClosure([this]() {
|
||||
Env::Default()->SleepForMicroseconds(100 * 1000);
|
||||
Status s = rendez_->Send(Key(ALICE, kIncarnation, BOB, "a"),
|
||||
Rendezvous::Args(), V(1.0), false);
|
||||
rendez_->Unref();
|
||||
});
|
||||
rendez_->Ref();
|
||||
SchedClosure([this]() {
|
||||
Env::Default()->SleepForMicroseconds(100 * 1000);
|
||||
Status s = rendez_->Send(Key(ALICE, kIncarnation, BOB, "b"),
|
||||
Rendezvous::Args(), V(1.0), false);
|
||||
rendez_->Unref();
|
||||
});
|
||||
rendez_->Ref();
|
||||
SchedClosure([this]() {
|
||||
Env::Default()->SleepForMicroseconds(100 * 1000);
|
||||
Status s = rendez_->Send(Key(ALICE, kIncarnation, BOB, "c"),
|
||||
Rendezvous::Args(), V(1.0), false);
|
||||
rendez_->Unref();
|
||||
});
|
||||
rendez_->Ref();
|
||||
SchedClosure([this]() {
|
||||
Env::Default()->SleepForMicroseconds(100 * 1000);
|
||||
rendez_->StartAbort(errors::Aborted(""));
|
||||
rendez_->Unref();
|
||||
});
|
||||
EXPECT_TRUE(errors::IsAborted(Run(rendez_)));
|
||||
Tensor out = V(-1);
|
||||
bool is_dead = false;
|
||||
EXPECT_TRUE(errors::IsAborted(rendez_->Recv(
|
||||
Key(BOB, kIncarnation, ALICE, "c"), Rendezvous::Args(), &out, &is_dead)));
|
||||
// At this point there can still be pending (albeit Aborted) Send
|
||||
// closures holding Refs on rendez_. We need to wait for them, or
|
||||
// else there can be a memory leak at termination.
|
||||
while (!rendez_->RefCountIsOne())
|
||||
;
|
||||
}
|
||||
|
||||
TEST_F(ExecutorTest, RecvInvalidDtype) {
|
||||
Graph* g = new Graph(OpRegistry::Global());
|
||||
// An input vector of type float of size 1.
|
||||
auto one = test::graph::Recv(g, "one", "float", ALICE, 1, BOB);
|
||||
// A floating point variable vector of size 1.
|
||||
auto var = test::graph::Var(g, DT_FLOAT, TensorShape({1}));
|
||||
// Initialize the variable with input.
|
||||
auto init = test::graph::Assign(g, var, one);
|
||||
// Output
|
||||
auto* two = test::graph::Send(g, var, "two", BOB, 1, ALICE);
|
||||
g->AddControlEdge(init, two); // Ensures run after init.
|
||||
Create(g);
|
||||
Rendezvous* rendez = NewLocalRendezvous();
|
||||
// Send a double instead of float.
|
||||
TF_ASSERT_OK(rendez->Send(Key(ALICE, 1, BOB, "one"), Rendezvous::Args(),
|
||||
VD(1.0), false));
|
||||
// Fails due to invalid dtype.
|
||||
EXPECT_TRUE(errors::IsInternal(Run(rendez)));
|
||||
Tensor output;
|
||||
bool is_dead;
|
||||
EXPECT_TRUE(errors::IsInternal(rendez->Recv(
|
||||
Key(BOB, 1, ALICE, "two"), Rendezvous::Args(), &output, &is_dead)));
|
||||
rendez->Unref();
|
||||
}
|
||||
|
||||
TEST_F(ExecutorTest, RecvInvalidRefDtype) {
|
||||
Graph* g = new Graph(OpRegistry::Global());
|
||||
// A var that always produces an invalid dtype.
|
||||
auto var = test::graph::InvalidRefType(g, DT_FLOAT, DT_DOUBLE);
|
||||
test::graph::Send(g, var, "out", BOB, 1, ALICE);
|
||||
Create(g);
|
||||
Rendezvous* rendez = NewLocalRendezvous();
|
||||
EXPECT_TRUE(errors::IsInternal(Run(rendez)));
|
||||
Tensor output;
|
||||
bool is_dead;
|
||||
EXPECT_TRUE(errors::IsInternal(rendez->Recv(
|
||||
Key(BOB, 1, ALICE, "out"), Rendezvous::Args(), &output, &is_dead)));
|
||||
rendez->Unref();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
368	tensorflow/core/distributed_runtime/graph_mgr.cc	Normal file
@ -0,0 +1,368 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/graph_mgr.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/common_runtime/constant_folding.h"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/common_runtime/graph_optimizer.h"
|
||||
#include "tensorflow/core/common_runtime/memory_types.h"
|
||||
#include "tensorflow/core/distributed_runtime/process_util.h"
|
||||
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
|
||||
#include "tensorflow/core/framework/cancellation.h"
|
||||
#include "tensorflow/core/framework/config.pb.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/graph/graph_partition.h"
|
||||
#include "tensorflow/core/graph/validate.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/notification.h"
|
||||
#include "tensorflow/core/lib/strings/stringprintf.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/protobuf/worker.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
GraphMgr::GraphMgr(const WorkerEnv* worker_env)
|
||||
: worker_env_(worker_env), table_(5) {}
|
||||
|
||||
GraphMgr::~GraphMgr() {
|
||||
for (auto p : table_) p.second->Unref();
|
||||
}
|
||||
|
||||
GraphMgr::Item::~Item() {
|
||||
for (const auto& unit : this->units) {
|
||||
CHECK_NOTNULL(unit.device);
|
||||
delete unit.root;
|
||||
delete unit.lib;
|
||||
unit.device->op_segment()->RemoveHold(this->session);
|
||||
}
|
||||
delete this->lib_def;
|
||||
}
|
||||
|
||||
// NOTE: node->device_name() is not set by GraphConstructor. We
// expect that the NodeDefs in the GraphDef given to workers fully
// specify device names.
|
||||
static string SplitByDevice(const Node* node) {
|
||||
return node->assigned_device_name();
|
||||
}
|
||||
|
||||
// Validates "gdef" device specifications.
|
||||
static Status ValidateGraphDefForDevices(const GraphDef& gdef) {
|
||||
DeviceNameUtils::ParsedName parsed;
|
||||
for (const auto& ndef : gdef.node()) {
|
||||
if (!DeviceNameUtils::ParseFullName(ndef.device(), &parsed)) {
|
||||
return errors::InvalidArgument("Missing device name in: ",
|
||||
SummarizeNodeDef(ndef));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Creates executors given a graph definition "gdef" of a "session".
|
||||
// If a node in "gdef" is shared by other graphs in "session", the
|
||||
// same op kernel is reused. E.g., typically a params node is shared
|
||||
// by multiple graphs in a session.
|
||||
//
|
||||
// If "gdef" is assigned to multiple devices, extra nodes (e.g.,
// send/recv nodes) may be added. The extra nodes' names are generated
// by calling "new_name(old_name)".
|
||||
//
|
||||
// "executors" are filled with one executor per device on success, and
// the caller takes ownership of the returned executors.
|
||||
Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
|
||||
const GraphOptions& graph_options, Item* item) {
|
||||
item->session = session;
|
||||
item->lib_def = new FunctionLibraryDefinition(gdef.library());
|
||||
|
||||
TF_RETURN_IF_ERROR(ValidateGraphDefForDevices(gdef));
|
||||
|
||||
if (gdef.versions().producer() >= 5) {
|
||||
// Validate the graph: we assume that merging two valid graphs
|
||||
// should maintain graph validity.
|
||||
TF_RETURN_IF_ERROR(graph::ValidateGraphDef(gdef, *item->lib_def));
|
||||
}
|
||||
|
||||
// Constructs the graph out of "gdef".
|
||||
Graph graph(item->lib_def);
|
||||
GraphConstructorOptions opts;
|
||||
opts.allow_internal_ops = true;
|
||||
opts.expect_device_spec = true;
|
||||
TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, gdef, &graph));
|
||||
|
||||
// Splits "graph" into multiple subgraphs by device names.
|
||||
std::unordered_map<string, GraphDef> partitions;
|
||||
PartitionOptions popts;
|
||||
popts.node_to_loc = SplitByDevice;
|
||||
popts.new_name = [this](const string& prefix) {
|
||||
mutex_lock l(mu_);
|
||||
return strings::StrCat(prefix, "_G", next_id_++);
|
||||
};
|
||||
popts.get_incarnation = [this](const string& name) {
|
||||
Device* device = nullptr;
|
||||
Status s = worker_env_->device_mgr->LookupDevice(name, &device);
|
||||
if (s.ok()) {
|
||||
return device->attributes().incarnation();
|
||||
} else {
|
||||
return PartitionOptions::kIllegalIncarnation;
|
||||
}
|
||||
};
|
||||
popts.control_flow_added = true;
|
||||
popts.scheduling_for_recvs = graph_options.enable_recv_scheduling();
|
||||
TF_RETURN_IF_ERROR(Partition(popts, &graph, &partitions));
|
||||
if (popts.scheduling_for_recvs) {
|
||||
TF_RETURN_IF_ERROR(AddControlEdges(popts, &partitions));
|
||||
}
|
||||
|
||||
thread::ThreadPool* pool = worker_env_->compute_pool;
|
||||
auto runner = [pool](std::function<void()> fn) { pool->Schedule(fn); };
|
||||
|
||||
LocalExecutorParams params;
|
||||
|
||||
Status s;
|
||||
item->units.reserve(partitions.size());
|
||||
const auto& optimizer_opts = graph_options.optimizer_options();
|
||||
GraphOptimizer optimizer(optimizer_opts);
|
||||
for (auto&& p : partitions) {
|
||||
const string& device_name = p.first;
|
||||
GraphDef* def = &p.second;
|
||||
item->units.resize(item->units.size() + 1);
|
||||
ExecutionUnit* unit = &(item->units.back());
|
||||
|
||||
// Find the device.
|
||||
s = worker_env_->device_mgr->LookupDevice(device_name, &unit->device);
|
||||
if (!s.ok()) break;
|
||||
|
||||
// Construct the subgraph.
|
||||
Graph* subgraph = new Graph(item->lib_def);
|
||||
// Give the device an opportunity to rewrite its subgraph.
|
||||
unit->device->MaybeRewriteGraph(gdef.library(), def);
|
||||
s = ConvertGraphDefToGraph(opts, *def, subgraph);
|
||||
if (!s.ok()) {
|
||||
delete subgraph;
|
||||
break;
|
||||
}
|
||||
// Top-level nodes in the graph use the op segment to cache
|
||||
// kernels. Therefore, as long as the executor is alive, we need
|
||||
// to ensure the kernels cached for the session are alive.
|
||||
auto opseg = unit->device->op_segment();
|
||||
opseg->AddHold(session);
|
||||
|
||||
// Function library runtime.
|
||||
unit->lib = NewFunctionLibraryRuntime(
|
||||
unit->device, runner, def->versions().producer(), item->lib_def,
|
||||
graph_options.optimizer_options());
|
||||
|
||||
// Construct the root executor for the subgraph.
|
||||
params.device = unit->device;
|
||||
auto lib = unit->lib;
|
||||
params.function_library = lib;
|
||||
params.create_kernel = [session, lib, opseg](const NodeDef& ndef,
|
||||
OpKernel** kernel) {
|
||||
// Caches the kernel only if the node is stateful.
|
||||
if (!lib->IsStateful(ndef.op())) {
|
||||
return lib->CreateKernel(ndef, kernel);
|
||||
}
|
||||
auto create_fn = [lib, &ndef](OpKernel** kernel) {
|
||||
return lib->CreateKernel(ndef, kernel);
|
||||
};
|
||||
// Kernels created for subgraph nodes need to be cached. On
|
||||
// cache miss, create_fn() is invoked to create a kernel based
|
||||
// on the function library here + global op registry.
|
||||
return opseg->FindOrCreate(session, ndef.name(), kernel, create_fn);
|
||||
};
|
||||
params.delete_kernel = [lib](OpKernel* kernel) {
|
||||
// If the node is stateful, opseg owns it. Otherwise, delete it.
|
||||
if (kernel && !lib->IsStateful(kernel->type_string())) {
|
||||
delete kernel;
|
||||
}
|
||||
};
|
||||
|
||||
optimizer.Optimize(lib, &subgraph);
|
||||
s = ValidateMemoryTypes(DeviceType(unit->device->device_type()), subgraph);
|
||||
if (!s.ok()) {
|
||||
delete subgraph;
|
||||
break;
|
||||
}
|
||||
s = NewLocalExecutor(params, subgraph, &unit->root);
|
||||
if (!s.ok()) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
return s;
|
||||
}
|
||||
|
||||
Status GraphMgr::Register(const string& session, const GraphDef& gdef,
|
||||
const GraphOptions& graph_options, string* handle) {
|
||||
Item* item = new Item;
|
||||
Status s = InitItem(session, gdef, graph_options, item);
|
||||
if (!s.ok()) {
|
||||
item->Unref();
|
||||
return s;
|
||||
}
|
||||
|
||||
// Inserts one item into table_.
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
*handle = strings::Printf("%016llx", ++next_id_);
|
||||
item->handle = *handle;
|
||||
CHECK(table_.insert({*handle, item}).second);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GraphMgr::Deregister(const string& handle) {
|
||||
Item* item = nullptr;
|
||||
// Removes one item from table_.
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
auto iter = table_.find(handle);
|
||||
if (iter == table_.end()) {
|
||||
return errors::Aborted("Graph handle is not found: ", handle,
|
||||
". Possibly, this worker just restarted.");
|
||||
}
|
||||
item = iter->second;
|
||||
table_.erase(iter);
|
||||
}
|
||||
item->Unref();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GraphMgr::DeregisterAll() {
|
||||
std::vector<Item*> items;
|
||||
// Removes all items from table_.
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
for (const auto& entry : table_) {
|
||||
items.push_back(entry.second);
|
||||
}
|
||||
table_.clear();
|
||||
}
|
||||
for (auto item : items) {
|
||||
item->Unref();
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GraphMgr::Execute(const string& handle, const int64 step_id,
|
||||
const ExecutorOpts& opts,
|
||||
StepStatsCollector* collector,
|
||||
CancellationManager* cancellation_manager,
|
||||
const NamedTensors& in, NamedTensors* out) {
|
||||
Notification n;
|
||||
Status status;
|
||||
ExecuteAsync(handle, step_id, opts, collector, cancellation_manager, in, out,
|
||||
[&n, &status](const Status& s) {
|
||||
status = s;
|
||||
n.Notify();
|
||||
});
|
||||
n.WaitForNotification();
|
||||
return status;
|
||||
}
|
||||
|
||||
void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
|
||||
const ExecutorOpts& opts,
|
||||
StepStatsCollector* collector,
|
||||
CancellationManager* cancellation_manager,
|
||||
const NamedTensors& in, NamedTensors* out,
|
||||
StatusCallback done) {
|
||||
// Lookup an item. Holds one ref while executing.
|
||||
Item* item = nullptr;
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
auto iter = table_.find(handle);
|
||||
if (iter != table_.end()) {
|
||||
item = iter->second;
|
||||
item->Ref();
|
||||
}
|
||||
}
|
||||
|
||||
if (item == nullptr) {
|
||||
done(errors::Aborted("Graph handle is not found: ", handle));
|
||||
return;
|
||||
}
|
||||
|
||||
const int num_units = item->units.size();
|
||||
CHECK_GE(num_units, 1);
|
||||
|
||||
Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
|
||||
|
||||
// Sends values specified by the caller.
|
||||
for (const auto& p : in) {
|
||||
const string& key = p.first;
|
||||
const Tensor& val = p.second;
|
||||
const Status s = rendezvous->Send(key, Rendezvous::Args(), val, false);
|
||||
if (!s.ok()) {
|
||||
done(s);
|
||||
item->Unref();
|
||||
rendezvous->Unref();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Starts parallel Executors.
|
||||
//
|
||||
// NOTE: Transfer one ref of rendezvous and one ref of item to
|
||||
// RunAllDone.
|
||||
ExecutorBarrier* barrier = new ExecutorBarrier(
|
||||
num_units, rendezvous, std::bind(&ME::RunAllDone, this, item, rendezvous,
|
||||
out, done, std::placeholders::_1));
|
||||
Executor::Args args;
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
args.step_id = ++next_id_;
|
||||
}
|
||||
args.rendezvous = rendezvous;
|
||||
args.cancellation_manager = cancellation_manager;
|
||||
args.stats_collector = collector;
|
||||
VLOG(1) << "Step " << args.step_id << " is for handle " << handle
|
||||
<< ", graph-local step " << step_id;
|
||||
thread::ThreadPool* pool = worker_env_->compute_pool;
|
||||
args.runner = [pool](std::function<void()> fn) { pool->Schedule(fn); };
|
||||
for (const auto& unit : item->units) {
|
||||
unit.root->RunAsync(args, barrier->Get());
|
||||
}
|
||||
}
|
||||
|
||||
void GraphMgr::RunAllDone(Item* item, Rendezvous* rendezvous, NamedTensors* out,
|
||||
StatusCallback done, Status s) {
|
||||
if (s.ok()) {
|
||||
// Receives values requested by the caller.
|
||||
for (auto& p : *out) {
|
||||
const string& key = p.first;
|
||||
Tensor* val = &p.second;
|
||||
bool is_dead = false;
|
||||
s = rendezvous->Recv(key, Rendezvous::Args(), val, &is_dead);
|
||||
if (is_dead) {
|
||||
s = errors::InvalidArgument("The tensor returned for ", key,
|
||||
" was not valid.");
|
||||
}
|
||||
if (!s.ok()) break;
|
||||
}
|
||||
}
|
||||
done(s);
|
||||
rendezvous->Unref();
|
||||
item->Unref();
|
||||
}
|
||||
|
||||
} // end namespace tensorflow
|
147	tensorflow/core/distributed_runtime/graph_mgr.h	Normal file
@ -0,0 +1,147 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_GRAPH_MGR_H_
|
||||
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_GRAPH_MGR_H_
|
||||
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/common_runtime/executor.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_env.h"
|
||||
#include "tensorflow/core/framework/cancellation.h"
|
||||
#include "tensorflow/core/framework/config.pb.h"
|
||||
#include "tensorflow/core/lib/core/refcount.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class ExecutorOpts;
|
||||
class StepStatsCollector;
|
||||
|
||||
// GraphMgr keeps track of a set of graphs that are registered with a
|
||||
// TensorFlow worker. Each registered graph is identified by a handle
|
||||
// that is generated by GraphMgr and returned to the caller.
|
||||
//
|
||||
// After a successful registration, the caller executes a graph using
|
||||
// the graph handle. Each execution is distinguished from others by a
// caller-generated, globally unique id "step_id". Multiple executions
// can use the same graph concurrently and independently as long as
// the "step_id"s used are different.
|
||||
//
|
||||
// Multiple threads can call GraphMgr methods concurrently.
|
||||
//
|
||||
// E.g.,
|
||||
// GraphMgr gmgr(worker_env);
|
||||
// string handle;
|
||||
// TF_CHECK_OK(gmgr.Register("session", { graph computes c = a + b },
|
||||
// &handle));
|
||||
// GraphMgr::NamedTensors in = { { "a", Tensor({1, 2}) },
|
||||
// { "b", Tensor({3, 4}) } };
|
||||
// GraphMgr::NamedTensors out = { { "c", Tensor() } };
|
||||
// TF_CHECK_OK(gmgr.Execute(handle, 0x0001, in, &out));
|
||||
// EXPECT_EQ(out["c"], Tensor({4, 6}));
|
||||
class GraphMgr {
|
||||
public:
|
||||
explicit GraphMgr(const WorkerEnv* worker_env);
|
||||
~GraphMgr();
|
||||
|
||||
// Registers a graph. Fills in "handle".
|
||||
Status Register(const string& session, const GraphDef& gdef,
|
||||
const GraphOptions& graph_options, string* handle);
|
||||
|
||||
// Executes one step of a registered graph "handle".
|
||||
//
|
||||
// If "out" is not nullptr, "out" specifies all keys the execution
|
||||
// should receive upon finish.
|
||||
typedef std::map<string, Tensor> NamedTensors;
|
||||
typedef std::function<void(const Status&)> StatusCallback;
|
||||
void ExecuteAsync(const string& handle, const int64 step_id,
|
||||
const ExecutorOpts& opts, StepStatsCollector* collector,
|
||||
CancellationManager* cancellation_manager,
|
||||
const NamedTensors& in, NamedTensors* out,
|
||||
StatusCallback done);
|
||||
|
||||
// Synchronous wrapper.
|
||||
Status Execute(const string& handle, const int64 step_id,
|
||||
const ExecutorOpts& opts,
|
||||
StepStatsCollector* step_stats_collector,
|
||||
CancellationManager* cancellation_manager,
|
||||
const NamedTensors& in, NamedTensors* out);
|
||||
|
||||
// Deregisters a graph.
|
||||
Status Deregister(const string& handle);
|
||||
|
||||
// Deregisters all graphs.
|
||||
Status DeregisterAll();
|
||||
|
||||
private:
|
||||
typedef GraphMgr ME;
|
||||
|
||||
struct ExecutionUnit {
|
||||
Device* device = nullptr;
|
||||
Executor* root = nullptr;
|
||||
FunctionLibraryRuntime* lib = nullptr;
|
||||
};
|
||||
|
||||
struct Item : public core::RefCounted {
|
||||
// TODO(zhifengc): Keep a copy of the original graph if the need arises.
// TODO(zhifengc): Stats, updated by multiple runs potentially.
// TODO(zhifengc): Dup-detection. Ensure each step_id is only run once.
|
||||
~Item() override;
|
||||
|
||||
// Session handle.
|
||||
string session;
|
||||
|
||||
// Graph handle.
|
||||
string handle;
|
||||
|
||||
// The definition of the library is shared by all partitions.
|
||||
FunctionLibraryDefinition* lib_def = nullptr;
|
||||
|
||||
// A graph is partitioned over multiple devices. Each partition
|
||||
// has a root executor which may call into the runtime library.
|
||||
std::vector<ExecutionUnit> units;
|
||||
};
|
||||
|
||||
// Not owned.
|
||||
const WorkerEnv* worker_env_;
|
||||
|
||||
// Owned.
|
||||
mutex mu_;
|
||||
int64 next_id_ GUARDED_BY(mu_) = 0;
|
||||
|
||||
// Table mapping graph handles to registered graphs.
|
||||
//
|
||||
// TODO(zhifengc): If the client does not call Deregister, we'll
|
||||
// lose memory over time. We should implement a timeout-based
|
||||
// mechanism to gc these graphs.
|
||||
std::unordered_map<string, Item*> table_;
|
||||
|
||||
void RunAllDone(Item* item, Rendezvous* rendezvous, NamedTensors* out,
|
||||
StatusCallback done, Status run_status);
|
||||
|
||||
Status InitItem(const string& session, const GraphDef& gdef,
|
||||
const GraphOptions& graph_options, Item* item);
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(GraphMgr);
|
||||
};
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_GRAPH_MGR_H_
|
413	tensorflow/core/distributed_runtime/master.cc	Normal file
@ -0,0 +1,413 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Master implements the service MasterService.
|
||||
//
|
||||
// A Master maintains the state of live graph computation
|
||||
// sessions, each session orchestrates both local and remote devices
|
||||
// to carry out the graph computation.
|
||||
//
|
||||
// A Master knows ahead of time which local devices are available as
// client devices.
|
||||
//
|
||||
// A Master discovers remote devices on-demand and keeps track of
|
||||
// statistics of those remote devices.
|
||||
//
|
||||
// Each session analyses the graph, places nodes across available
|
||||
// devices, and ultimately drives the graph computation by initiating
|
||||
// RunGraph on the workers.
|
||||
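// A hedged, illustrative summary (not part of the original file) of the
// client-visible RPC flow that the methods below implement, using the
// request/response protos from master.proto:
//
//   CreateSessionRequest -> CreateSessionResponse   // register a GraphDef,
//                                                   // obtain a session_handle
//   ExtendSessionRequest -> ExtendSessionResponse   // grow the session's graph
//   RunStepRequest       -> RunStepResponse         // feed/fetch one step
//   CloseSessionRequest  -> CloseSessionResponse    // release the session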
|
||||
#include "tensorflow/core/distributed_runtime/master.h"
|
||||
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/process_util.h"
|
||||
#include "tensorflow/core/distributed_runtime/remote_device.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_cache.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_interface.h"
|
||||
#include "tensorflow/core/framework/graph_def_util.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/notification.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/protobuf/master.pb.h"
|
||||
#include "tensorflow/core/protobuf/worker.pb.h"
|
||||
#include "tensorflow/core/public/session_options.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
Master::Master(MasterEnv* env, double session_gc_seconds)
|
||||
: env_(env),
|
||||
last_1000_steps_(1000),
|
||||
step_count_(0),
|
||||
session_gc_seconds_(session_gc_seconds) {
|
||||
// Right now, a master service must be co-located with a device.
|
||||
// Otherwise, fetches do not work.
|
||||
CHECK(!env->local_devices.empty());
|
||||
|
||||
if (session_gc_seconds_ > 0.0) {
|
||||
SchedClosure([this]() { GC(); });
|
||||
}
|
||||
}
|
||||
|
||||
Master::~Master() {
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
shutdown_ = true;
|
||||
shutdown_cv_.notify_all();
|
||||
}
|
||||
gc_stopped_.WaitForNotification();
|
||||
}
|
||||
|
||||
void Master::GC() {
|
||||
Env* env = Env::Default();
|
||||
while (true) {
|
||||
mutex_lock l(mu_);
|
||||
const int kTimeoutMilliseconds = 10 * 1000; // 10 seconds.
|
||||
WaitForMilliseconds(&l, &shutdown_cv_, kTimeoutMilliseconds);
|
||||
if (shutdown_) {
|
||||
break;
|
||||
}
|
||||
std::vector<string> handles;
|
||||
const int64 num_micros = static_cast<int64>(session_gc_seconds_ * 1000000);
|
||||
for (const auto& entry : sessions_) {
|
||||
auto lat = entry.second->last_access_time_usec();
|
||||
if (env->NowMicros() - lat > num_micros) {
|
||||
handles.push_back(entry.first);
|
||||
auto* sess = entry.second;
|
||||
SchedClosure([this, sess]() {
|
||||
LOG(WARNING) << "GC session " << sess->handle() << " after "
|
||||
<< session_gc_seconds_ << " seconds. "
|
||||
<< "Note that if you are starting multiple replicas "
|
||||
<< "on a staggered delay, session_gc_seconds may need "
|
||||
<< "to be raised.";
|
||||
sess->Close();
|
||||
});
|
||||
}
|
||||
}
|
||||
for (const auto& handle : handles) sessions_.erase(handle);
|
||||
}
|
||||
gc_stopped_.Notify();
|
||||
}
|
||||
|
||||
class DeviceFinder {
|
||||
public:
|
||||
explicit DeviceFinder(
|
||||
const protobuf::RepeatedPtrField<string>& device_filters, MasterEnv* env)
|
||||
: env_(env) {
|
||||
auto process_filter = [this](const string& filter) {
|
||||
DeviceNameUtils::ParsedName parsed;
|
||||
if (DeviceNameUtils::ParseFullName(filter, &parsed)) {
|
||||
filters_.push_back(parsed);
|
||||
} else {
|
||||
LOG(FATAL) << "Skipping invalid filter: " << filter;
|
||||
}
|
||||
};
|
||||
for (const string& filter : device_filters) {
|
||||
process_filter(filter);
|
||||
}
|
||||
}
|
||||
|
||||
~DeviceFinder() {
|
||||
for (Device* dev : found_) delete dev;
|
||||
}
|
||||
|
||||
void Start() {
|
||||
// Enumerates all known workers' targets. A target name is a
|
||||
// prefix of a device name. E.g., /job:mnist/replica:0/task:10.
|
||||
std::vector<string> workers;
|
||||
env_->worker_cache->ListWorkers(&workers);
|
||||
std::vector<string> targets;
|
||||
if (filters_.empty()) {
|
||||
swap(workers, targets);
|
||||
} else {
|
||||
for (const string& name : workers) {
|
||||
if (MatchFilters(name)) {
|
||||
targets.push_back(name);
|
||||
}
|
||||
}
|
||||
}
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
num_pending_ = targets.size();
|
||||
if (num_pending_ == 0) {
|
||||
pending_zero_.notify_all();
|
||||
}
|
||||
}
|
||||
// Talk to all workers to get the list of available devices.
|
||||
using std::placeholders::_1;
|
||||
using std::placeholders::_2;
|
||||
for (size_t i = 0; i < targets.size(); ++i) {
|
||||
NewRemoteDevices(env_->env, env_->worker_cache, targets[i],
|
||||
std::bind(&ME::WhenFound, this, _1, _2));
|
||||
}
|
||||
}
|
||||
|
||||
void Wait() {
|
||||
mutex_lock l(mu_);
|
||||
while (num_pending_ != 0) {
|
||||
pending_zero_.wait(l);
|
||||
}
|
||||
}
|
||||
|
||||
// The caller takes ownership of the returned remote devices.
|
||||
void GetRemoteDevices(const std::vector<Device*>& local,
|
||||
std::vector<Device*>* remote) {
|
||||
std::unordered_set<string> names(local.size());
|
||||
for (Device* dev : local) names.insert(dev->name());
|
||||
mutex_lock l(mu_);
|
||||
for (Device* dev : found_) {
|
||||
const string& name = dev->name();
|
||||
if (names.insert(name).second && MatchFilters(name)) {
|
||||
remote->push_back(dev);
|
||||
} else {
|
||||
delete dev;
|
||||
}
|
||||
}
|
||||
found_.clear();
|
||||
}
|
||||
|
||||
private:
|
||||
typedef DeviceFinder ME;
|
||||
const MasterEnv* env_;
|
||||
std::vector<DeviceNameUtils::ParsedName> filters_;
|
||||
|
||||
mutex mu_;
|
||||
int num_pending_ GUARDED_BY(mu_);
|
||||
condition_variable pending_zero_;
|
||||
std::vector<Device*> found_ GUARDED_BY(mu_);
|
||||
|
||||
void WhenFound(const Status& s, std::vector<Device*>* devices) {
|
||||
mutex_lock l(mu_);
|
||||
if (!s.ok()) {
|
||||
LOG(ERROR) << "Master init: " << s;
|
||||
} else {
|
||||
found_.insert(found_.end(), devices->begin(), devices->end());
|
||||
devices->clear();
|
||||
}
|
||||
--num_pending_;
|
||||
if (num_pending_ == 0) {
|
||||
pending_zero_.notify_all();
|
||||
}
|
||||
}
|
||||
|
||||
// Returns true iff the set of devices allowed by 'x' intersects
|
||||
// with the set of devices allowed by 'y'.
|
||||
bool Intersects(const DeviceNameUtils::ParsedName& x,
|
||||
const DeviceNameUtils::ParsedName& y) {
|
||||
return (!x.has_job || !y.has_job || x.job == y.job) &&
|
||||
(!x.has_replica || !y.has_replica || x.replica == y.replica) &&
|
||||
(!x.has_task || !y.has_task || x.task == y.task) &&
|
||||
(!x.has_type || !y.has_type || x.type == y.type) &&
|
||||
(!x.has_id || !y.has_id || x.id == y.id);
|
||||
}
|
||||
|
||||
// Returns true iff 'name' matches one of the filters_.
|
||||
bool MatchFilters(const string& name) {
|
||||
if (filters_.empty()) return true;
|
||||
DeviceNameUtils::ParsedName x;
|
||||
if (DeviceNameUtils::ParseFullName(name, &x)) {
|
||||
for (const auto& filter : filters_) {
|
||||
if (Intersects(x, filter)) return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
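// A hedged illustration (not part of the original file) of the filter
// semantics above: an unset component intersects with anything, so a session
// configured with the device filter "/job:mnist/replica:0" would match the
// worker device "/job:mnist/replica:0/task:10/cpu:0", while a filter such as
// "/job:ps" would not.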
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(DeviceFinder);
|
||||
};
|
||||
|
||||
void Master::CreateSession(const CreateSessionRequest* req,
|
||||
CreateSessionResponse* resp, MyClosure done) {
|
||||
SchedClosure([this, req, resp, done]() {
|
||||
Status status = ValidateExternalGraphDefSyntax(req->graph_def());
|
||||
if (status.ok()) {
|
||||
// Ping all the workers and build the list of devices that the
|
||||
// session will use.
|
||||
DeviceFinder finder(req->config().device_filters(), env_);
|
||||
finder.Start();
|
||||
finder.Wait();
|
||||
std::vector<Device*> remote_devices;
|
||||
finder.GetRemoteDevices(env_->local_devices, &remote_devices);
|
||||
SessionOptions options;
|
||||
options.config = req->config();
|
||||
MasterSessionInterface* session =
|
||||
env_->master_session_factory(options, env_, &remote_devices);
|
||||
GraphDef* gdef =
|
||||
const_cast<CreateSessionRequest*>(req)->mutable_graph_def();
|
||||
Status create_status = session->Create(gdef);
|
||||
if (!create_status.ok()) {
|
||||
done(create_status);
|
||||
return;
|
||||
}
|
||||
resp->set_session_handle(session->handle());
|
||||
// Insert into the session map.
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
CHECK(sessions_.insert({session->handle(), session}).second);
|
||||
}
|
||||
}
|
||||
done(status);
|
||||
});
|
||||
}
|
||||
|
||||
void Master::ExtendSession(const ExtendSessionRequest* req,
|
||||
ExtendSessionResponse* resp, MyClosure done) {
|
||||
mu_.lock();
|
||||
MasterSessionInterface* session = nullptr;
|
||||
session = gtl::FindPtrOrNull(sessions_, req->session_handle());
|
||||
if (session == nullptr) {
|
||||
mu_.unlock();
|
||||
done(errors::Aborted("Session ", req->session_handle(), " is not found."));
|
||||
return;
|
||||
}
|
||||
|
||||
SchedClosure([session, req, resp, done]() {
|
||||
Status status = ValidateExternalGraphDefSyntax(req->graph_def());
|
||||
if (status.ok()) {
|
||||
status = session->Extend(req, resp);
|
||||
}
|
||||
done(status);
|
||||
});
|
||||
mu_.unlock();
|
||||
}
|
||||
|
||||
void Master::RunStep(CallOptions* opts, const RunStepRequest* req,
|
||||
RunStepResponse* resp, MyClosure done) {
|
||||
mu_.lock();
|
||||
uint64 start_time = env_->env->NowMicros();
|
||||
MasterSessionInterface* session =
|
||||
gtl::FindPtrOrNull(sessions_, req->session_handle());
|
||||
if (session == nullptr) {
|
||||
mu_.unlock();
|
||||
done(errors::Aborted("Session ", req->session_handle(), " is not found."));
|
||||
return;
|
||||
}
|
||||
|
||||
SchedClosure([this, start_time, session, opts, req, resp, done]() {
|
||||
Status status = session->Run(opts, req, resp);
|
||||
uint64 done_time = env_->env->NowMicros();
|
||||
done(status);
|
||||
mutex_lock l(mu_);
|
||||
last_1000_steps_.AddValue((done_time - start_time) / 1e9);
|
||||
++step_count_;
|
||||
});
|
||||
mu_.unlock();
|
||||
}
|
||||
|
||||
void Master::CloseSession(const CloseSessionRequest* req,
|
||||
CloseSessionResponse* resp, MyClosure done) {
|
||||
MasterSessionInterface* session = nullptr;
|
||||
{
|
||||
mu_.lock();
|
||||
auto iter = sessions_.find(req->session_handle());
|
||||
if (iter == sessions_.end()) {
|
||||
mu_.unlock();
|
||||
done(errors::Aborted(
|
||||
"Session ", req->session_handle(),
|
||||
" is not found. Possibly, this master has restarted."));
|
||||
return;
|
||||
}
|
||||
session = iter->second;
|
||||
sessions_.erase(iter);
|
||||
mu_.unlock();
|
||||
}
|
||||
|
||||
// Session Close() blocks on thread shutdown. Therefore, we need to
|
||||
// delete it in a non-critical thread.
|
||||
SchedClosure([session, done]() {
|
||||
Status s = session->Close();
|
||||
done(s);
|
||||
});
|
||||
}
|
||||
|
||||
void Master::ListDevices(const ListDevicesRequest* req,
|
||||
ListDevicesResponse* resp, MyClosure done) {
|
||||
SchedClosure([this, req, resp, done]() {
|
||||
DeviceFinder finder({}, env_);
|
||||
finder.Start();
|
||||
finder.Wait();
|
||||
std::vector<Device*> remote_devices;
|
||||
finder.GetRemoteDevices(env_->local_devices, &remote_devices);
|
||||
for (Device* dev : env_->local_devices) {
|
||||
*(resp->add_local_device()) = dev->attributes();
|
||||
}
|
||||
for (Device* dev : remote_devices) {
|
||||
*(resp->add_remote_device()) = dev->attributes();
|
||||
delete dev;
|
||||
}
|
||||
done(Status::OK());
|
||||
});
|
||||
}
|
||||
|
||||
void Master::CleanupWorkers(const ResetRequest& reset) {
|
||||
std::vector<string> worker_names;
|
||||
env_->worker_cache->ListWorkers(&worker_names);
|
||||
if (!worker_names.empty()) {
|
||||
const int num_workers = worker_names.size();
|
||||
std::vector<Notification> n(num_workers);
|
||||
CleanupAllRequest req;
|
||||
(*req.mutable_container()) = reset.container();
|
||||
std::vector<CleanupAllResponse> resp(num_workers);
|
||||
int c = 0;
|
||||
for (int i = 0; i < num_workers; ++i) {
|
||||
auto worker = env_->worker_cache->CreateWorker(worker_names[i]);
|
||||
if (worker) {
|
||||
worker->CleanupAllAsync(&req, &resp[i], [&n, worker, c](Status s) {
|
||||
TF_CHECK_OK(s);
|
||||
delete worker;
|
||||
n[c].Notify();
|
||||
});
|
||||
} else {
|
||||
n[c].Notify();
|
||||
}
|
||||
++c;
|
||||
}
|
||||
for (int i = 0; i < n.size(); ++i) {
|
||||
n[i].WaitForNotification();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Master::Reset(const ResetRequest* req, ResetResponse* resp,
|
||||
MyClosure done) {
|
||||
// Vector to hold the session pointers present in the sessions_
|
||||
// (string->Session*) map.
|
||||
std::vector<MasterSessionInterface*> sessions;
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
for (const auto& entry : sessions_) {
|
||||
sessions.push_back(entry.second);
|
||||
}
|
||||
sessions_.clear();
|
||||
}
|
||||
|
||||
CleanupWorkers(*req);
|
||||
|
||||
SchedClosure([sessions, done]() {
|
||||
Status s;
|
||||
for (MasterSessionInterface* session : sessions) {
|
||||
s.Update(session->Close());
|
||||
}
|
||||
done(s);
|
||||
});
|
||||
}
|
||||
|
||||
} // end namespace tensorflow
|
98	tensorflow/core/distributed_runtime/master.h	Normal file
@ -0,0 +1,98 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_H_
|
||||
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_H_
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/distributed_runtime/call_options.h"
|
||||
#include "tensorflow/core/distributed_runtime/master_env.h"
|
||||
#include "tensorflow/core/distributed_runtime/master_session_interface.h"
|
||||
#include "tensorflow/core/lib/core/notification.h"
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/protobuf/master.pb.h"
|
||||
#include "tensorflow/core/util/util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class Master {
|
||||
public:
|
||||
explicit Master(MasterEnv* env, double session_gc_seconds);
|
||||
virtual ~Master();
|
||||
|
||||
// Convenient typedef for a closure passing a Status.
|
||||
typedef std::function<void(const Status&)> MyClosure;
|
||||
|
||||
void CreateSession(const CreateSessionRequest* req,
|
||||
CreateSessionResponse* resp, MyClosure done);
|
||||
|
||||
void ExtendSession(const ExtendSessionRequest* req,
|
||||
ExtendSessionResponse* resp, MyClosure done);
|
||||
|
||||
void RunStep(CallOptions* opts, const RunStepRequest* req,
|
||||
RunStepResponse* resp, MyClosure done);
|
||||
|
||||
void CloseSession(const CloseSessionRequest* req, CloseSessionResponse* resp,
|
||||
MyClosure done);
|
||||
|
||||
void ListDevices(const ListDevicesRequest* req, ListDevicesResponse* resp,
|
||||
MyClosure done);
|
||||
|
||||
void Reset(const ResetRequest* req, ResetResponse* resp, MyClosure done);
|
||||
|
||||
private:
|
||||
typedef Master ME;
|
||||
|
||||
// Not owned.
|
||||
MasterEnv* env_ = nullptr;
|
||||
|
||||
// Owned.
|
||||
mutex mu_;
|
||||
|
||||
// shutdown_ is set to true by the dtor.
|
||||
condition_variable shutdown_cv_;
|
||||
bool shutdown_ GUARDED_BY(mu_) = false;
|
||||
Notification gc_stopped_;
|
||||
|
||||
// Maps session handles to sessions.
|
||||
std::unordered_map<string, MasterSessionInterface*> sessions_ GUARDED_BY(mu_);
|
||||
|
||||
// Moving average of step times.
|
||||
MovingAverage last_1000_steps_ GUARDED_BY(mu_);
|
||||
|
||||
// Cumulative number of steps executed.
|
||||
int64 step_count_ GUARDED_BY(mu_);
|
||||
|
||||
// If a session is not active for this many seconds, it will be
|
||||
// closed automatically.
|
||||
const double session_gc_seconds_;
|
||||
|
||||
// Call CleanupAll on all workers.
|
||||
void CleanupWorkers(const ResetRequest& reset);
|
||||
|
||||
// Cleans up unused sessions.
|
||||
void GC();
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(Master);
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_H_
|
66	tensorflow/core/distributed_runtime/master_env.h	Normal file
@ -0,0 +1,66 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_ENV_H_
|
||||
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_ENV_H_
|
||||
|
||||
#include <functional>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/public/session_options.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class Device;
|
||||
class Env;
|
||||
class MasterSessionInterface;
|
||||
class OpRegistryInterface;
|
||||
class WorkerCacheInterface;
|
||||
|
||||
// The master environment class, which holds a bag of pointers to
|
||||
// per-master state.
|
||||
//
|
||||
// MasterEnv does not own its member pointers.
|
||||
struct MasterEnv {
|
||||
Env* env = nullptr;
|
||||
|
||||
// Object from which WorkerInterface instances can be obtained.
|
||||
WorkerCacheInterface* worker_cache = nullptr;
|
||||
|
||||
// The operation definitions to use. Must be filled before use.
|
||||
const OpRegistryInterface* ops = nullptr;
|
||||
|
||||
// Local devices co-located with this master. Devices are not owned
|
||||
// by the master service.
|
||||
//
|
||||
// REQUIRES: !local_devices.empty().
|
||||
std::vector<Device*> local_devices;
|
||||
|
||||
// Factory for creating master sessions, given session options and a
|
||||
// vector of devices.
|
||||
//
|
||||
// The caller of the function takes ownership of the returned
|
||||
// `MasterSessionInterface`, which may not be null. Ownership of the
|
||||
// `MasterEnv*` is retained by the caller. The callee takes
|
||||
// ownership of the `std::vector<Device*>*` argument, but does not
|
||||
// take ownership of the `Device*` objects in the vector.
|
||||
std::function<MasterSessionInterface*(const SessionOptions&, MasterEnv*,
|
||||
std::vector<Device*>*)>
|
||||
master_session_factory;
|
||||
};
|
||||
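// A hedged sketch (not from the original change) of how a server binary might
// populate a MasterEnv before constructing a master. "worker_cache",
// "local_devices", and "NewMasterSession" are hypothetical placeholders for
// whatever the hosting binary provides:
//
//   MasterEnv master_env;
//   master_env.env = Env::Default();
//   master_env.worker_cache = worker_cache;      // WorkerCacheInterface*
//   master_env.ops = OpRegistry::Global();
//   master_env.local_devices = local_devices;    // std::vector<Device*>
//   master_env.master_session_factory =
//       [](const SessionOptions& options, MasterEnv* env,
//          std::vector<Device*>* remote_devs) {
//         return NewMasterSession(options, env, remote_devs);  // hypothetical
//       };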
|
||||
} // end namespace tensorflow
|
||||
|
||||
#endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_ENV_H_
|
52	tensorflow/core/distributed_runtime/master_interface.h	Normal file
@ -0,0 +1,52 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_INTERFACE_H_
|
||||
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_INTERFACE_H_
|
||||
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/protobuf/master.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Pure virtual interface for communicating with the TensorFlow Master service.
|
||||
//
|
||||
// This interface is intended to support in-process master
|
||||
// implementations that do not require an RPC roundtrip.
|
||||
class MasterInterface {
|
||||
public:
|
||||
virtual ~MasterInterface() {}
|
||||
virtual Status CreateSession(const CreateSessionRequest* request,
|
||||
CreateSessionResponse* response) = 0;
|
||||
|
||||
virtual Status ExtendSession(const ExtendSessionRequest* request,
|
||||
ExtendSessionResponse* response) = 0;
|
||||
|
||||
virtual Status RunStep(const RunStepRequest* request,
|
||||
RunStepResponse* response) = 0;
|
||||
|
||||
virtual Status CloseSession(const CloseSessionRequest* request,
|
||||
CloseSessionResponse* response) = 0;
|
||||
|
||||
virtual Status ListDevices(const ListDevicesRequest* request,
|
||||
ListDevicesResponse* response) = 0;
|
||||
|
||||
virtual Status Reset(const ResetRequest* request,
|
||||
ResetResponse* response) = 0;
|
||||
};
|
||||
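// A hedged usage sketch (not part of the original header). "master" stands in
// for any concrete MasterInterface implementation and "graph_def" for the
// client's graph:
//
//   CreateSessionRequest req;
//   *req.mutable_graph_def() = graph_def;
//   CreateSessionResponse resp;
//   Status s = master->CreateSession(&req, &resp);
//   if (s.ok()) {
//     const string& handle = resp.session_handle();
//     // ... issue RunStep calls using `handle` ...
//   }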
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_INTERFACE_H_
|
942	tensorflow/core/distributed_runtime/master_session.cc	Normal file
@ -0,0 +1,942 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/master_session.h"
|
||||
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/common_runtime/device_set.h"
|
||||
#include "tensorflow/core/distributed_runtime/master_env.h"
|
||||
#include "tensorflow/core/distributed_runtime/master_session_interface.h"
|
||||
#include "tensorflow/core/distributed_runtime/process_util.h"
|
||||
#include "tensorflow/core/distributed_runtime/simple_graph_execution_state.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_cache.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_interface.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/graph/graph_partition.h"
|
||||
#include "tensorflow/core/graph/tensor_id.h"
|
||||
#include "tensorflow/core/lib/core/notification.h"
|
||||
#include "tensorflow/core/lib/core/refcount.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
#include "tensorflow/core/lib/random/random.h"
|
||||
#include "tensorflow/core/lib/strings/numbers.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/lib/strings/stringprintf.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/tracing.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/protobuf/master.pb.h"
|
||||
#include "tensorflow/core/public/session_options.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
// A little bit of per-step state.
|
||||
struct PerStepState {
|
||||
Microseconds start_micros = Microseconds(0);
|
||||
Microseconds end_micros = Microseconds(0);
|
||||
std::vector<StepStats> step_stats; // per partition
|
||||
};
|
||||
|
||||
// A session encapsulates a graph computation (resource allocation,
|
||||
// placement, execution, etc.).
|
||||
class MasterSession : public MasterSessionInterface {
|
||||
public:
|
||||
// This session encapsulates the computation for a single graph.
|
||||
//
|
||||
// The session places nodes on devices in "remote_devs" and executes
|
||||
// operations on these devices.
|
||||
//
|
||||
// The caller takes ownership of all remote devices.
|
||||
MasterSession(const SessionOptions& options, const MasterEnv* env,
|
||||
std::vector<Device*>* remote_devs);
|
||||
|
||||
// Initialize the Session for "def". Must be called before Extend(),
|
||||
// Run(), or Close().
|
||||
//
|
||||
// The callee may clear "def".
|
||||
Status Create(GraphDef* def) override;
|
||||
|
||||
// Returns the session handle.
|
||||
const string& handle() const override { return handle_; }
|
||||
|
||||
// Returns the last access time (the number of micro-seconds since
|
||||
// some fixed point in time) of this session.
|
||||
uint64 last_access_time_usec() const override {
|
||||
return last_access_time_usec_.load();
|
||||
}
|
||||
|
||||
// Attempt to extend the graph according to the given "req".
|
||||
// (See master.proto for details of valid extensions.)
|
||||
//
|
||||
// PRECONDITION: The current version of this session's graph
|
||||
// is "req->current_graph_version".
|
||||
//
|
||||
// POSTCONDITION: The current version of this session's graph
|
||||
// is "resp->new_graph_version".
|
||||
//
|
||||
// Extend() may block the caller thread for a long time.
|
||||
Status Extend(const ExtendSessionRequest* req,
|
||||
ExtendSessionResponse* resp) override;
|
||||
|
||||
// Run one step.
|
||||
Status Run(CallOptions* opts, const RunStepRequest* req,
|
||||
RunStepResponse* resp) override;
|
||||
|
||||
// Close this session and delete "*this". Returns OK if all known
|
||||
// states are cleaned up successfully.
|
||||
//
|
||||
// Close() may block the caller thread for a long time.
|
||||
Status Close() override;
|
||||
|
||||
private:
|
||||
SessionOptions session_opts_;
|
||||
|
||||
// Not owned.
|
||||
const MasterEnv* env_;
|
||||
|
||||
// The opaque session handle.
|
||||
const string handle_;
|
||||
|
||||
// Owned.
|
||||
std::vector<Device*> remote_devs_;
|
||||
|
||||
// The device set used by this session.
|
||||
DeviceSet devices_;
|
||||
|
||||
// TODO(zhifengc): Support Extend().
|
||||
//
|
||||
// 'func_def_lib_' is a copy of the initial graph def's library.
|
||||
// 'flib_def_' is an index structure of 'func_def_lib_' keyed by
|
||||
// function names.
|
||||
FunctionDefLibrary func_def_lib_;
|
||||
FunctionLibraryDefinition* flib_def_ = nullptr;
|
||||
|
||||
std::atomic_ulong last_access_time_usec_;
|
||||
|
||||
mutex mu_;
|
||||
std::unique_ptr<SimpleGraphExecutionState> execution_state_;
|
||||
int64 graph_version_;
|
||||
|
||||
int32 steps_since_last_scheduling_ GUARDED_BY(mu_) = 0;
|
||||
int32 scheduling_period_steps_ GUARDED_BY(mu_) = 10;
|
||||
|
||||
// We keep a map from a signature of a run request to the
|
||||
// ReffedClientGraph that can execute it. We keep up to one old copy
|
||||
// of each ReffedClientGraph around because if it gets deallocated
|
||||
// before a new substitute has been created, Variables can go out of
|
||||
// scope and lose their state.
|
||||
class ReffedClientGraph;
|
||||
typedef std::unordered_map<uint64, ReffedClientGraph*> RCGMap;
|
||||
RCGMap runs_ GUARDED_BY(mu_);
|
||||
RCGMap obsolete_ GUARDED_BY(mu_);
|
||||
|
||||
// Active RunStep calls.
|
||||
condition_variable num_running_is_zero_;
|
||||
int32 num_running_ GUARDED_BY(mu_) = 0;
|
||||
|
||||
std::unordered_map<uint64, int64> subgraph_execution_counts_ GUARDED_BY(mu_);
|
||||
|
||||
// We need to ensure that certain nodes added (e.g., send and recv
|
||||
// nodes) are unique across all sub-graphs within this session.
|
||||
int64 next_node_id_ GUARDED_BY(mu_) = 0;
|
||||
|
||||
// Private dtor. The client must call Close().
|
||||
virtual ~MasterSession();
|
||||
|
||||
Status StartStep(const RunStepRequest& req, BuildGraphOptions* opts,
|
||||
int64* count, ReffedClientGraph** graph);
|
||||
void ClearRunsTable(std::vector<ReffedClientGraph*>* to_unref,
|
||||
RCGMap* rcg_map) EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||
Status DoRunWithLocalExecution(CallOptions* opts, const RunStepRequest* req,
|
||||
RunStepResponse* resp);
|
||||
void UpdateLastAccessTime();
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(MasterSession);
|
||||
};
|
||||
|
||||
// Session wraps ClientGraph in a reference counted object. This way,
|
||||
// Session can clear up the cache mapping Run requests to compiled
|
||||
// graphs while the compiled graph is still being used.
|
||||
//
|
||||
// TODO(zhifengc): Cleanup this class. It's becoming messy.
|
||||
class MasterSession::ReffedClientGraph : public core::RefCounted {
|
||||
public:
|
||||
ReffedClientGraph(const string& handle, const BuildGraphOptions& bopts,
|
||||
ClientGraph* cg, const GraphOptions& graph_opts)
|
||||
: session_handle_(handle),
|
||||
client_graph_(cg),
|
||||
bopts_(bopts),
|
||||
graph_opts_(graph_opts) {
|
||||
VLOG(1) << "Created ReffedClientGraph for node with "
|
||||
<< client_graph_->graph.num_node_ids();
|
||||
|
||||
const string key =
|
||||
strings::StrCat("{", str_util::Join(bopts.feed_endpoints, ","), "},{",
|
||||
str_util::Join(bopts.target_nodes, ","), "},{",
|
||||
str_util::Join(bopts.fetch_endpoints, ","), "}");
|
||||
// TODO(mrry): Publish information about the graph (such as
|
||||
// timelines, the pruned graph, statistics, etc.).
|
||||
}
|
||||
|
||||
~ReffedClientGraph() override {
|
||||
delete client_graph_;
|
||||
DeregisterPartitions();
|
||||
}
|
||||
|
||||
const ClientGraph* client_graph() { return client_graph_; }
|
||||
|
||||
// Local execution methods.
|
||||
|
||||
// Partitions the graph into subgraphs and registers them on
|
||||
// workers.
|
||||
Status RegisterPartitions(const MasterEnv* env, const PartitionOptions& popts,
|
||||
const FunctionDefLibrary& func_def_lib);
|
||||
|
||||
// Runs one step of all partitions.
|
||||
Status RunPartitions(const MasterEnv* env, int64 step_id,
|
||||
int64 execution_count,
|
||||
SimpleGraphExecutionState* execution_state,
|
||||
PerStepState* pss, CallOptions* opts,
|
||||
const RunStepRequest& req, RunStepResponse* resp);
|
||||
|
||||
// Calls workers to clean up state for the step "step_id". Waits
// until all cleanup RPCs complete.
|
||||
Status CleanupPartitions(int64 step_id);
|
||||
|
||||
// TODO(mrry): Runtime statistics collection.
|
||||
|
||||
private:
|
||||
const string session_handle_;
|
||||
ClientGraph* const client_graph_ = nullptr;
|
||||
std::unordered_set<const Node*> nodes_needing_input_mapping_;
|
||||
BuildGraphOptions bopts_;
|
||||
const GraphOptions graph_opts_;
|
||||
|
||||
// Graph partitioned into per-location subgraphs.
|
||||
struct Part {
|
||||
// Worker name.
|
||||
string name;
|
||||
|
||||
// Graph definition.
|
||||
GraphDef gdef;
|
||||
|
||||
// Maps feed names to rendezvous keys. Empty most of the time.
|
||||
std::unordered_map<string, string> feed_key;
|
||||
|
||||
// Maps rendezvous keys to fetch names. Empty most of the time.
|
||||
std::unordered_map<string, string> key_fetch;
|
||||
|
||||
// The interface to the worker. Owned.
|
||||
WorkerInterface* worker = nullptr;
|
||||
|
||||
// After registration with the worker, graph_handle identifies
|
||||
// this partition on the worker.
|
||||
string graph_handle;
|
||||
|
||||
Part() : feed_key(3), key_fetch(3) {}
|
||||
};
|
||||
|
||||
// partitions_ is immutable after RegisterPartitions() call
|
||||
// finishes. RunPartitions() can access partitions_ safely without
|
||||
// acquiring locks.
|
||||
std::vector<Part> partitions_;
|
||||
|
||||
mutable mutex mu_;
|
||||
|
||||
// Partition initialization and registration only needs to happen
|
||||
// once. init_started_ && !init_done_ indicates the initialization
|
||||
// is ongoing.
|
||||
bool init_started_ GUARDED_BY(mu_) = false;
|
||||
Notification init_done_;
|
||||
|
||||
// init_result_ remembers the initialization error if any.
|
||||
Status init_result_ GUARDED_BY(mu_);
|
||||
|
||||
// Send/Recv nodes that are the result of client-added
|
||||
// feeds and fetches must be tracked so that the tensors
|
||||
// can be added to the local rendezvous.
|
||||
static void TrackFeedsAndFetches(Part* part, const PartitionOptions& popts);
|
||||
|
||||
// The actual graph partitioning and registration implementation.
|
||||
Status DoRegisterPartitions(const MasterEnv* env,
|
||||
const PartitionOptions& popts,
|
||||
const FunctionDefLibrary& func_def_lib);
|
||||
|
||||
// Deregisters the partitions on the workers. Called in the
|
||||
// destructor and does not wait for the rpc completion.
|
||||
void DeregisterPartitions();
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(ReffedClientGraph);
|
||||
};
|
||||
|
||||
Status MasterSession::ReffedClientGraph::RegisterPartitions(
|
||||
const MasterEnv* env, const PartitionOptions& popts,
|
||||
const FunctionDefLibrary& func_def_lib) {
|
||||
{ // Ensure register once.
|
||||
mu_.lock();
|
||||
if (!init_started_) {
|
||||
init_started_ = true;
|
||||
mu_.unlock();
|
||||
Status s = DoRegisterPartitions(env, popts, func_def_lib);
|
||||
mu_.lock();
|
||||
init_result_ = s;
|
||||
init_done_.Notify();
|
||||
} else {
|
||||
mu_.unlock();
|
||||
init_done_.WaitForNotification();
|
||||
mu_.lock();
|
||||
}
|
||||
Status result = init_result_;
|
||||
mu_.unlock();
|
||||
return result;
|
||||
}
|
||||
}
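// Editor's note -- an illustrative sketch, not part of the original change.
// RegisterPartitions() above is a "run exactly once; later callers wait"
// idiom built from a mutex-guarded flag plus a Notification. With
// hypothetical names, the same pattern in isolation looks like this:
//
//   Status InitOnce(mutex* mu, bool* started, Notification* done,
//                   Status* result, const std::function<Status()>& init) {
//     mu->lock();
//     if (!*started) {
//       *started = true;              // The first caller claims the work.
//       mu->unlock();
//       Status s = init();            // Run the expensive step unlocked.
//       mu->lock();
//       *result = s;
//       done->Notify();               // Release all waiters.
//     } else {
//       mu->unlock();
//       done->WaitForNotification();  // Later callers block until done.
//       mu->lock();
//     }
//     Status copy = *result;
//     mu->unlock();
//     return copy;
//   }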
|
||||
|
||||
static string SplitByWorker(const Node* node) {
|
||||
string task;
|
||||
string device;
|
||||
CHECK(DeviceNameUtils::SplitDeviceName(node->assigned_device_name(), &task,
|
||||
&device))
|
||||
<< "node: " << node->name() << " dev: " << node->assigned_device_name();
|
||||
return task;
|
||||
}
|
||||
|
||||
void MasterSession::ReffedClientGraph::TrackFeedsAndFetches(
|
||||
Part* part, const PartitionOptions& popts) {
|
||||
for (int i = 0; i < part->gdef.node_size(); ++i) {
|
||||
NodeDef* ndef = part->gdef.mutable_node(i);
|
||||
const bool is_recv = ndef->op() == "_Recv";
|
||||
const bool is_send = ndef->op() == "_Send";
|
||||
|
||||
if (is_recv || is_send) {
|
||||
string name;
|
||||
TF_CHECK_OK(GetNodeAttr(*ndef, "tensor_name", &name));
|
||||
string send_device;
|
||||
TF_CHECK_OK(GetNodeAttr(*ndef, "send_device", &send_device));
|
||||
string recv_device;
|
||||
TF_CHECK_OK(GetNodeAttr(*ndef, "recv_device", &recv_device));
|
||||
uint64 send_device_incarnation;
|
||||
TF_CHECK_OK(
|
||||
GetNodeAttr(*ndef, "send_device_incarnation",
|
||||
reinterpret_cast<int64*>(&send_device_incarnation)));
|
||||
const string& key =
|
||||
Rendezvous::CreateKey(send_device, send_device_incarnation,
|
||||
recv_device, name, FrameAndIter(0, 0));
|
||||
|
||||
// Only send/recv nodes that were added as feeds and fetches
|
||||
// (client-terminated) should be tracked. Other send/recv nodes
|
||||
// are for transferring data between partitions / memory spaces.
|
||||
bool client_terminated;
|
||||
TF_CHECK_OK(GetNodeAttr(*ndef, "client_terminated", &client_terminated));
|
||||
if (client_terminated) {
|
||||
if (is_recv) {
|
||||
part->feed_key.insert({name, key});
|
||||
} else {
|
||||
part->key_fetch.insert({key, name});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Status MasterSession::ReffedClientGraph::DoRegisterPartitions(
|
||||
const MasterEnv* env, const PartitionOptions& popts,
|
||||
const FunctionDefLibrary& func_def_lib) {
|
||||
// Partition the graph.
|
||||
Status s;
|
||||
std::unordered_map<string, GraphDef> graph_partitions;
|
||||
s = Partition(popts, &client_graph_->graph, &graph_partitions);
|
||||
if (!s.ok()) return s;
|
||||
partitions_.reserve(graph_partitions.size());
|
||||
for (auto& name_def : graph_partitions) {
|
||||
partitions_.resize(partitions_.size() + 1);
|
||||
Part* part = &partitions_.back();
|
||||
part->name = name_def.first;
|
||||
part->gdef.Swap(&name_def.second);
|
||||
// For simplicity, we ship the library completely to every worker.
|
||||
*(part->gdef.mutable_library()) = func_def_lib;
|
||||
TrackFeedsAndFetches(part, popts);
|
||||
part->worker = env->worker_cache->CreateWorker(part->name);
|
||||
if (part->worker == nullptr) {
|
||||
s = errors::NotFound("worker ", part->name);
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!s.ok()) {
|
||||
for (Part& part : partitions_) {
|
||||
delete part.worker;
|
||||
}
|
||||
return s;
|
||||
}
|
||||
struct Call {
|
||||
RegisterGraphRequest req;
|
||||
RegisterGraphResponse resp;
|
||||
Status status;
|
||||
Notification done;
|
||||
};
|
||||
const int num = partitions_.size();
|
||||
gtl::InlinedVector<Call, 4> calls(num);
|
||||
for (int i = 0; i < num; ++i) {
|
||||
const Part& part = partitions_[i];
|
||||
Call* c = &calls[i];
|
||||
c->req.set_session_handle(session_handle_);
|
||||
*c->req.mutable_graph_def() = part.gdef;
|
||||
*c->req.mutable_graph_options() = graph_opts_;
|
||||
VLOG(2) << "Register " << part.gdef.DebugString();
|
||||
auto cb = [c](const Status& s) {
|
||||
c->status = s;
|
||||
c->done.Notify();
|
||||
};
|
||||
part.worker->RegisterGraphAsync(&c->req, &c->resp, cb);
|
||||
}
|
||||
for (int i = num - 1; i >= 0; --i) {
|
||||
Call* c = &calls[i];
|
||||
c->done.WaitForNotification();
|
||||
s.Update(c->status);
|
||||
partitions_[i].graph_handle = c->resp.graph_handle();
|
||||
}
|
||||
return s;
|
||||
}
|
||||
|
||||
static bool CopyIfNeeded(TensorProto* in, TensorProto* out) {
|
||||
if (in->tensor_content().empty()) {
|
||||
// If the tensor is not encoded in tensor_content or contains 0
|
||||
// elements, we can return it to the client directly.
|
||||
out->Swap(in);
|
||||
} else {
|
||||
Tensor t(in->dtype());
|
||||
if (!t.FromProto(cpu_allocator(), *in)) return false;
|
||||
t.AsProtoField(out);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Helper class to manage "num" parallel RunGraph calls.
|
||||
class RunManyGraphs {
|
||||
public:
|
||||
explicit RunManyGraphs(int num) : calls_(num), num_pending_(num) {}
|
||||
|
||||
~RunManyGraphs() {}
|
||||
|
||||
// Returns the index-th call.
|
||||
struct Call {
|
||||
CallOptions opts;
|
||||
RunGraphRequest req;
|
||||
RunGraphResponse resp;
|
||||
};
|
||||
Call* get(int index) { return &calls_[index]; }
|
||||
|
||||
// When the index-th call is done, updates the overall status.
|
||||
void WhenDone(int index, const Status& s) {
|
||||
TRACEPRINTF("Partition %d %s", index, s.ToString().c_str());
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
if (!s.ok()) {
|
||||
UpdateStatusLocked(s);
|
||||
}
|
||||
--num_pending_;
|
||||
cv_pending_.notify_all();
|
||||
}
|
||||
}
|
||||
|
||||
void StartCancel() {
|
||||
mutex_lock l(mu_);
|
||||
UpdateStatusLocked(errors::Cancelled("RunManyGraphs"));
|
||||
}
|
||||
|
||||
void Wait() {
|
||||
mutex_lock l(mu_);
|
||||
while (num_pending_ > 0) {
|
||||
cv_pending_.wait(l);
|
||||
}
|
||||
}
|
||||
|
||||
Status status() const {
|
||||
mutex_lock l(mu_);
|
||||
return status_;
|
||||
}
|
||||
|
||||
private:
|
||||
gtl::InlinedVector<Call, 4> calls_;
|
||||
|
||||
// TODO(jeff,sanjay): Replace bookkeeping state here with a
|
||||
// BlockingCounter abstraction that we define in
|
||||
// tensorflow/core/lib/core.
|
||||
mutable mutex mu_;
|
||||
condition_variable cv_pending_;
|
||||
int num_pending_;
|
||||
Status status_ GUARDED_BY(mu_);
|
||||
|
||||
void UpdateStatusLocked(const Status& s) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
if (status_.ok()) {
|
||||
status_ = s;
|
||||
for (Call& call : calls_) {
|
||||
call.opts.StartCancel();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(RunManyGraphs);
|
||||
};
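// Editor's note -- an illustrative sketch, not part of the original change.
// The TODO above suggests replacing the num_pending_/cv_pending_ bookkeeping
// with a reusable BlockingCounter. Assuming only the mutex and
// condition_variable types already used in this file, a minimal version
// could look like:
//
//   class BlockingCounter {
//    public:
//     explicit BlockingCounter(int initial_count) : count_(initial_count) {}
//     void DecrementCount() {
//       mutex_lock l(mu_);
//       if (--count_ == 0) all_done_.notify_all();
//     }
//     void Wait() {
//       mutex_lock l(mu_);
//       while (count_ > 0) all_done_.wait(l);
//     }
//    private:
//     mutex mu_;
//     condition_variable all_done_;
//     int count_ GUARDED_BY(mu_);
//   };
//
// RunManyGraphs::WhenDone() would then call DecrementCount(), and Wait()
// would replace the explicit loop over num_pending_.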
|
||||
|
||||
Status MasterSession::ReffedClientGraph::RunPartitions(
|
||||
const MasterEnv* env, int64 step_id, int64 execution_count,
|
||||
SimpleGraphExecutionState* execution_state, PerStepState* pss,
|
||||
CallOptions* call_opts, const RunStepRequest& req, RunStepResponse* resp) {
|
||||
VLOG(2) << "RunPartitions step_id " << step_id << " execution_count "
|
||||
<< execution_count;
|
||||
// Builds an index for feeds provided by the client.
|
||||
std::unordered_map<StringPiece, const TensorProto*, StringPiece::Hasher>
|
||||
feeds(3);
|
||||
|
||||
for (const auto& feed : req.feed()) {
|
||||
if (!feeds.insert({feed.name(), &feed.tensor()}).second) {
|
||||
return errors::InvalidArgument("Duplicated feeds: ", feed.name());
|
||||
}
|
||||
}
|
||||
|
||||
// Prepares a number of calls to workers. One call per partition.
|
||||
ExecutorOpts exec_opts;
|
||||
const int num = partitions_.size();
|
||||
RunManyGraphs calls(num);
|
||||
|
||||
for (int i = 0; i < num; ++i) {
|
||||
const Part& part = partitions_[i];
|
||||
RunManyGraphs::Call* c = calls.get(i);
|
||||
c->req.set_graph_handle(part.graph_handle);
|
||||
c->req.set_step_id(step_id);
|
||||
*c->req.mutable_exec_opts() = exec_opts;
|
||||
// If any feeds are provided, send the feed values together
|
||||
// in the RunGraph request.
|
||||
for (const auto& feed_key : part.feed_key) {
|
||||
const string& feed = feed_key.first;
|
||||
const string& key = feed_key.second;
|
||||
const TensorProto* val = feeds[feed];
|
||||
if (val == nullptr) {
|
||||
return errors::InvalidArgument("No feed is provided for feed=", feed,
|
||||
", key=", key);
|
||||
}
|
||||
auto* send = c->req.add_send();
|
||||
send->set_key(key);
|
||||
*(send->mutable_val()) = *val; // TODO(mrry): make it faster if needed.
|
||||
}
|
||||
for (const auto& key_fetch : part.key_fetch) {
|
||||
const string& key = key_fetch.first;
|
||||
c->req.add_recv_key(key);
|
||||
}
|
||||
}
|
||||
|
||||
// Issues RunGraph calls.
|
||||
for (int i = 0; i < num; ++i) {
|
||||
const Part& part = partitions_[i];
|
||||
RunManyGraphs::Call* call = calls.get(i);
|
||||
TRACEPRINTF("Partition %d %s", i, part.name.c_str());
|
||||
part.worker->RunGraphAsync(
|
||||
&call->opts, &call->req, &call->resp,
|
||||
std::bind(&RunManyGraphs::WhenDone, &calls, i, std::placeholders::_1));
|
||||
}
|
||||
|
||||
// Waits for the RunGraph calls.
|
||||
call_opts->SetCancelCallback([&calls]() { calls.StartCancel(); });
|
||||
calls.Wait();
|
||||
call_opts->ClearCancelCallback();
|
||||
|
||||
// Collects fetches.
|
||||
Status status = calls.status();
|
||||
if (status.ok()) {
|
||||
for (int i = 0; i < num; ++i) {
|
||||
const Part& part = partitions_[i];
|
||||
for (auto& recv : *(calls.get(i)->resp.mutable_recv())) {
|
||||
auto* ret = resp->add_tensor();
|
||||
auto iter = part.key_fetch.find(recv.key());
|
||||
if (iter == part.key_fetch.end()) {
|
||||
status.Update(errors::Internal("Unexpected fetch key: ", recv.key()));
|
||||
break;
|
||||
}
|
||||
const string& fetch = iter->second;
|
||||
ret->set_name(fetch);
|
||||
if (!CopyIfNeeded(recv.mutable_val(), ret->mutable_tensor())) {
|
||||
status.Update(
|
||||
errors::Internal("Unexpected unparseable tensor: ", recv.key()));
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (calls.get(i)->resp.has_step_stats()) {
|
||||
pss->step_stats[i].Swap(calls.get(i)->resp.mutable_step_stats());
|
||||
}
|
||||
}
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
Status MasterSession::ReffedClientGraph::CleanupPartitions(int64 step_id) {
|
||||
struct Call {
|
||||
CleanupGraphRequest req;
|
||||
CleanupGraphResponse resp;
|
||||
Notification done;
|
||||
Status status;
|
||||
};
|
||||
const int num = partitions_.size();
|
||||
gtl::InlinedVector<Call, 4> calls(num);
|
||||
for (int i = 0; i < num; ++i) {
|
||||
const Part& part = partitions_[i];
|
||||
Call* c = &calls[i];
|
||||
c->req.set_step_id(step_id);
|
||||
part.worker->CleanupGraphAsync(&c->req, &c->resp, [c](const Status& s) {
|
||||
c->status = s;
|
||||
c->done.Notify();
|
||||
});
|
||||
}
|
||||
Status s;
|
||||
for (int i = num - 1; i >= 0; --i) {
|
||||
Call* c = &calls[i];
|
||||
c->done.WaitForNotification();
|
||||
s.Update(c->status);
|
||||
}
|
||||
return s;
|
||||
}
|
||||
|
||||
// Makes async calls to workers to deregister subgraphs, without waiting for completion.
|
||||
void MasterSession::ReffedClientGraph::DeregisterPartitions() {
|
||||
struct Call {
|
||||
DeregisterGraphRequest req;
|
||||
DeregisterGraphResponse resp;
|
||||
};
|
||||
for (Part& part : partitions_) {
|
||||
Call* c = new Call;
|
||||
c->req.set_graph_handle(part.graph_handle);
|
||||
WorkerInterface* w = part.worker;
|
||||
auto cb = [c, w](const Status& s) {
|
||||
if (!s.ok()) {
|
||||
// This error is potentially benign, so we don't log at the
|
||||
// error level.
|
||||
LOG(INFO) << "DeregisterGraph error: " << s;
|
||||
}
|
||||
delete c;
|
||||
delete w;
|
||||
};
|
||||
w->DeregisterGraphAsync(&c->req, &c->resp, cb);
|
||||
}
|
||||
}
|
||||
|
||||
void BuildBuildGraphOptions(const RunStepRequest& req,
|
||||
BuildGraphOptions* opts) {
|
||||
for (const auto& feed : req.feed()) {
|
||||
opts->feed_endpoints.push_back(feed.name());
|
||||
}
|
||||
for (const auto& fetch : req.fetch()) {
|
||||
// TODO(touts): handle ref:
|
||||
opts->fetch_endpoints.push_back(fetch);
|
||||
}
|
||||
for (const auto& target : req.target()) {
|
||||
opts->target_nodes.push_back(target);
|
||||
}
|
||||
|
||||
std::sort(opts->feed_endpoints.begin(), opts->feed_endpoints.end());
|
||||
std::sort(opts->target_nodes.begin(), opts->target_nodes.end());
|
||||
std::sort(opts->fetch_endpoints.begin(), opts->fetch_endpoints.end());
|
||||
}
|
||||
|
||||
uint64 HashBuildGraphOptions(const BuildGraphOptions& opts) {
|
||||
uint64 h = 0x2b992ddfa23249d6ull;
|
||||
for (const string& name : opts.feed_endpoints) {
|
||||
h = Hash64(name.c_str(), name.size(), h);
|
||||
}
|
||||
for (const string& name : opts.target_nodes) {
|
||||
h = Hash64(name.c_str(), name.size(), h);
|
||||
}
|
||||
for (const string& name : opts.fetch_endpoints) {
|
||||
h = Hash64(name.c_str(), name.size(), h);
|
||||
}
|
||||
return h;
|
||||
}
|
||||
|
||||
string BuildGraphOptionsString(const BuildGraphOptions& opts) {
|
||||
string buf;
|
||||
for (const string& name : opts.feed_endpoints) {
|
||||
strings::StrAppend(&buf, " FdE: ", name);
|
||||
}
|
||||
strings::StrAppend(&buf, "\n");
|
||||
for (const string& name : opts.target_nodes) {
|
||||
strings::StrAppend(&buf, " TN: ", name);
|
||||
}
|
||||
strings::StrAppend(&buf, "\n");
|
||||
for (const string& name : opts.fetch_endpoints) {
|
||||
strings::StrAppend(&buf, " FeE: ", name);
|
||||
}
|
||||
strings::StrAppend(&buf, "\n");
|
||||
return buf;
|
||||
}
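// Editor's note -- an illustrative example, not part of the original change.
// Because BuildBuildGraphOptions() sorts the feed, target, and fetch lists
// before HashBuildGraphOptions() folds them into a key, requests that name
// the same endpoints in a different order share one cached subgraph. The
// fetch names below are purely illustrative:
//
//   RunStepRequest r1, r2;
//   r1.add_fetch("y:0");
//   r1.add_fetch("z:0");
//   r2.add_fetch("z:0");
//   r2.add_fetch("y:0");
//   BuildGraphOptions o1, o2;
//   BuildBuildGraphOptions(r1, &o1);
//   BuildBuildGraphOptions(r2, &o2);
//   // o1 and o2 now hold identical sorted endpoint lists, so
//   // HashBuildGraphOptions(o1) == HashBuildGraphOptions(o2) and
//   // MasterSession::StartStep() reuses the same ReffedClientGraph.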
|
||||
|
||||
MasterSession::MasterSession(const SessionOptions& opt, const MasterEnv* env,
|
||||
std::vector<Device*>* remote_devs)
|
||||
: session_opts_(opt),
|
||||
env_(env),
|
||||
handle_(strings::FpToString(random::New64())),
|
||||
graph_version_(0),
|
||||
runs_(5) {
|
||||
UpdateLastAccessTime();
|
||||
|
||||
swap(remote_devs_, *remote_devs);
|
||||
VLOG(1) << "Session " << handle_ << " #local " << env->local_devices.size()
|
||||
<< " #remote " << remote_devs_.size();
|
||||
for (Device* d : remote_devs_) {
|
||||
devices_.AddDevice(d);
|
||||
}
|
||||
int num_local_devices = 0;
|
||||
for (Device* d : env->local_devices) {
|
||||
devices_.AddDevice(d);
|
||||
if (num_local_devices == 0) {
|
||||
// Uses the first local device as the client device.
|
||||
devices_.set_client_device(d);
|
||||
}
|
||||
num_local_devices++;
|
||||
}
|
||||
}
|
||||
|
||||
MasterSession::~MasterSession() {
|
||||
for (const auto& iter : runs_) iter.second->Unref();
|
||||
for (const auto& iter : obsolete_) iter.second->Unref();
|
||||
delete flib_def_;
|
||||
for (Device* dev : remote_devs_) delete dev;
|
||||
}
|
||||
|
||||
void MasterSession::UpdateLastAccessTime() {
|
||||
last_access_time_usec_.store(Env::Default()->NowMicros());
|
||||
}
|
||||
|
||||
Status MasterSession::Create(GraphDef* graph_def) {
|
||||
// Keeps a copy of graph_def->library(); flib_def_ serves as the
// OpRegistryInterface used by the SimpleGraphExecutionState to construct the
// pre-partitioned graphs during DoRunWithLocalExecution().
|
||||
func_def_lib_.Swap(graph_def->mutable_library());
|
||||
flib_def_ = new FunctionLibraryDefinition(func_def_lib_);
|
||||
|
||||
SimpleGraphExecutionStateOptions options;
|
||||
options.device_set = &devices_;
|
||||
options.session_options = &session_opts_;
|
||||
execution_state_.reset(new SimpleGraphExecutionState(flib_def_, options));
|
||||
TF_RETURN_IF_ERROR(execution_state_->Create(graph_def));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status MasterSession::Extend(const ExtendSessionRequest* req,
|
||||
ExtendSessionResponse* resp) {
|
||||
UpdateLastAccessTime();
|
||||
std::unique_ptr<SimpleGraphExecutionState> old_execution_state;
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
// TODO(mrry): Redesign the locking with reader/writer locks to prevent
|
||||
// starvation due to concurrent steps being issued. This is not
|
||||
// immediately important because we expect Extend to be used in
|
||||
// development/interactive exploration, and not during high-throughput
|
||||
// training.
|
||||
while (num_running_ != 0) {
|
||||
num_running_is_zero_.wait(l);
|
||||
}
|
||||
|
||||
if (graph_version_ != req->current_graph_version()) {
|
||||
return errors::Aborted("Current version is ", graph_version_,
|
||||
" but caller expected ",
|
||||
req->current_graph_version(), ".");
|
||||
}
|
||||
|
||||
CHECK(execution_state_);
|
||||
SimpleGraphExecutionState* extended_execution_state = nullptr;
|
||||
Status s =
|
||||
execution_state_->Extend(req->graph_def(), &extended_execution_state);
|
||||
if (s.ok()) {
|
||||
CHECK(extended_execution_state);
|
||||
old_execution_state =
|
||||
std::move(execution_state_); // Will be released outside the lock
|
||||
execution_state_.reset(extended_execution_state);
|
||||
++graph_version_;
|
||||
resp->set_new_graph_version(graph_version_);
|
||||
}
|
||||
|
||||
return s;
|
||||
}
|
||||
}
|
||||
|
||||
Status MasterSession::StartStep(const RunStepRequest& req,
|
||||
BuildGraphOptions* opts, int64* count,
|
||||
ReffedClientGraph** rcg) {
|
||||
BuildBuildGraphOptions(req, opts);
|
||||
const uint64 hash = HashBuildGraphOptions(*opts);
|
||||
ReffedClientGraph* to_unref = nullptr;
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
// Keep track of how many times this subgraph has been executed in
|
||||
// this session.
|
||||
int64* c = &subgraph_execution_counts_[hash];
|
||||
*count = (*c)++;
|
||||
auto iter = runs_.find(hash);
|
||||
if (iter == runs_.end()) {
|
||||
// We have not seen this subgraph before. Build the subgraph and
|
||||
// cache it.
|
||||
VLOG(1) << "Unseen hash " << hash << " for "
|
||||
<< BuildGraphOptionsString(*opts);
|
||||
ClientGraph* client_graph = nullptr;
|
||||
TF_RETURN_IF_ERROR(execution_state_->BuildGraph(*opts, &client_graph));
|
||||
auto entry = new ReffedClientGraph(handle_, *opts, client_graph,
|
||||
session_opts_.config.graph_options());
|
||||
iter = runs_.insert({hash, entry}).first;
|
||||
auto obs_iter = obsolete_.find(hash);
|
||||
if (obs_iter != obsolete_.end()) {
|
||||
to_unref = obs_iter->second;
|
||||
obsolete_.erase(obs_iter);
|
||||
}
|
||||
VLOG(1) << "Preparing to execute new graph";
|
||||
}
|
||||
*rcg = iter->second;
|
||||
(*rcg)->Ref();
|
||||
}
|
||||
if (to_unref) to_unref->Unref();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void MasterSession::ClearRunsTable(std::vector<ReffedClientGraph*>* to_unref,
|
||||
RCGMap* rcg_map) {
|
||||
VLOG(1) << "Discarding all reffed graphs";
|
||||
for (auto p : *rcg_map) {
|
||||
ReffedClientGraph* rcg = p.second;
|
||||
if (to_unref) {
|
||||
to_unref->push_back(rcg);
|
||||
} else {
|
||||
rcg->Unref();
|
||||
}
|
||||
}
|
||||
rcg_map->clear();
|
||||
}
|
||||
|
||||
Status MasterSession::Run(CallOptions* opts, const RunStepRequest* req,
|
||||
RunStepResponse* resp) {
|
||||
UpdateLastAccessTime();
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
++num_running_;
|
||||
}
|
||||
Status status = DoRunWithLocalExecution(opts, req, resp);
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
--num_running_;
|
||||
if (num_running_ == 0) {
|
||||
num_running_is_zero_.notify_all();
|
||||
}
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
Status MasterSession::DoRunWithLocalExecution(CallOptions* opts,
|
||||
const RunStepRequest* req,
|
||||
RunStepResponse* resp) {
|
||||
VLOG(2) << "DoRunWithLocalExecution "
|
||||
<< "req: " << req->DebugString();
|
||||
PerStepState pss;
|
||||
pss.start_micros = Env::Default()->NowMicros();
|
||||
|
||||
// Prepare.
|
||||
BuildGraphOptions bgopts;
|
||||
ReffedClientGraph* rcg = nullptr;
|
||||
int64 count = 0;
|
||||
TF_RETURN_IF_ERROR(StartStep(*req, &bgopts, &count, &rcg));
|
||||
|
||||
// Unref "rcg" when out of scope.
|
||||
core::ScopedUnref unref(rcg);
|
||||
|
||||
// Registers subgraphs if we haven't done so already.
|
||||
PartitionOptions popts;
|
||||
popts.node_to_loc = SplitByWorker;
|
||||
popts.new_name = [this](const string& prefix) {
|
||||
mutex_lock l(mu_);
|
||||
return strings::StrCat(prefix, "_S", next_node_id_++);
|
||||
};
|
||||
popts.get_incarnation = [this](const string& name) {
|
||||
Device* d = devices_.FindDeviceByName(name);
|
||||
if (d == nullptr) {
|
||||
return PartitionOptions::kIllegalIncarnation;
|
||||
} else {
|
||||
return d->attributes().incarnation();
|
||||
}
|
||||
};
|
||||
popts.control_flow_added = false;
|
||||
// TODO(mrry): Enable DT_BFLOAT16 casting.
|
||||
// TODO(mrry): Enable recv scheduling.
|
||||
TF_RETURN_IF_ERROR(rcg->RegisterPartitions(env_, popts, func_def_lib_));
|
||||
|
||||
// Force the highest 8 bits to 0x01: we reserve some bits of the
// step_id for future use.
|
||||
const uint64 step_id = (random::New64() & ((1uLL << 56) - 1)) | (1uLL << 56);
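// Editor's note (illustrative worked example, not part of the original
// change): if random::New64() returned 0xABCDEF0123456789, masking with
// ((1uLL << 56) - 1) keeps the low 56 bits (0x00CDEF0123456789) and OR-ing
// with (1uLL << 56) sets bit 56, giving step_id = 0x01CDEF0123456789 --
// so the top byte of every step_id is 0x01.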
|
||||
TRACEPRINTF("stepid %llu", step_id);
|
||||
|
||||
TF_RETURN_IF_ERROR(rcg->RunPartitions(
|
||||
env_, step_id, count, execution_state_.get(), &pss, opts, *req, resp));
|
||||
|
||||
pss.end_micros = Env::Default()->NowMicros();
|
||||
|
||||
// Schedule post-processing and cleanup to be done async.
|
||||
rcg->Ref();
|
||||
// TODO(tucker): We're doing the stats processing prior to returning
|
||||
// the response to the client. Ensure it's safe to do so, then schedule
|
||||
// in a closure.
|
||||
SchedClosure([this, rcg, step_id]() {
|
||||
Status s = rcg->CleanupPartitions(step_id);
|
||||
if (!s.ok()) {
|
||||
LOG(ERROR) << "Cleanup partition error: " << s;
|
||||
}
|
||||
rcg->Unref();
|
||||
});
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status MasterSession::Close() {
|
||||
std::vector<ReffedClientGraph*> to_unref;
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
while (num_running_ != 0) {
|
||||
num_running_is_zero_.wait(l);
|
||||
}
|
||||
ClearRunsTable(&to_unref, &runs_);
|
||||
ClearRunsTable(&to_unref, &obsolete_);
|
||||
}
|
||||
for (ReffedClientGraph* rcg : to_unref) rcg->Unref();
|
||||
delete this;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // end namespace
|
||||
|
||||
namespace internal {
|
||||
|
||||
MasterSessionInterface* NewMasterSession(const SessionOptions& options,
|
||||
const MasterEnv* env,
|
||||
std::vector<Device*>* remote_devs) {
|
||||
return new MasterSession(options, env, remote_devs);
|
||||
}
|
||||
|
||||
} // end namespace internal
|
||||
} // end namespace tensorflow
38  tensorflow/core/distributed_runtime/master_session.h  Normal file
@ -0,0 +1,38 @@
/* Copyright 2016 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_H_

#include <vector>

#include "tensorflow/core/public/session_options.h"

namespace tensorflow {

class Device;
class MasterEnv;
class MasterSessionInterface;

namespace internal {

MasterSessionInterface* NewMasterSession(const SessionOptions& options,
                                         const MasterEnv* env,
                                         std::vector<Device*>* remote_devs);

}  // namespace internal
}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_H_
76  tensorflow/core/distributed_runtime/master_session_interface.h  Normal file
@ -0,0 +1,76 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_INTERFACE_H_
|
||||
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_INTERFACE_H_
|
||||
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
class ThreadPool;
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class CallOptions;
|
||||
class GraphDef;
|
||||
class RunStepRequest;
|
||||
class RunStepResponse;
|
||||
class ExtendSessionRequest;
|
||||
class ExtendSessionResponse;
|
||||
|
||||
// A "master session" encapsulates a distributed graph computation
|
||||
// (resource allocation, placement, execution, etc.).
|
||||
class MasterSessionInterface {
|
||||
public:
|
||||
// Initializes the Session with "def". Must be called before Extend(),
|
||||
// Run(), or Close().
|
||||
//
|
||||
// The callee may clear "def".
|
||||
virtual Status Create(GraphDef* def) = 0;
|
||||
|
||||
// Returns the session handle.
|
||||
virtual const string& handle() const = 0;
|
||||
|
||||
// Returns the last access time (the number of micro-seconds since
|
||||
// some fixed point in time) of this session.
|
||||
virtual uint64 last_access_time_usec() const = 0;
|
||||
|
||||
// Attempt to extend the graph according to the given "req".
|
||||
// (See master.proto for details of valid extensions.)
|
||||
//
|
||||
// PRECONDITION: The current version of this session's graph
// is "req->current_graph_version".
//
// POSTCONDITION: The current version of this session's graph
// is "resp->new_graph_version".
|
||||
//
|
||||
// Extend() may block the caller thread for a long time.
|
||||
virtual Status Extend(const ExtendSessionRequest* req,
|
||||
ExtendSessionResponse* resp) = 0;
|
||||
|
||||
// Run one step.
|
||||
virtual Status Run(CallOptions* opts, const RunStepRequest* req,
|
||||
RunStepResponse* resp) = 0;
|
||||
|
||||
// Close this session and delete "*this". Returns OK if all known
|
||||
// states are cleaned up successfully.
|
||||
//
|
||||
// Close() may block the caller thread for a long time.
|
||||
virtual Status Close() = 0;
|
||||
};
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_INTERFACE_H_
423  tensorflow/core/distributed_runtime/master_test.cc  Normal file
@ -0,0 +1,423 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/master.h"
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
|
||||
#include "external/grpc/include/grpc++/grpc++.h"
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_testlib.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
|
||||
#include "tensorflow/core/framework/allocator.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/graph/testlib.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/notification.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/core/threadpool.h"
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/protobuf/master.pb.h"
|
||||
#include "tensorflow/core/protobuf/master_service.grpc.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class MasterTest : public ::testing::Test {
|
||||
protected:
|
||||
MasterTest() {
|
||||
std::vector<string> targets;
|
||||
SessionOptions options;
|
||||
(*options.config.mutable_device_count())["CPU"] = 1;
|
||||
(*options.config.mutable_device_count())["GPU"] = 0;
|
||||
TF_CHECK_OK(test::TestCluster::MakeTestCluster(options, 2, &cluster_));
|
||||
master_ = grpc::MasterService::NewStub(
|
||||
NewHostPortGrpcChannel(cluster_->targets()[0]));
|
||||
}
|
||||
|
||||
std::unique_ptr<test::TestCluster> cluster_;
|
||||
std::unique_ptr<grpc::MasterService::Stub> master_;
|
||||
|
||||
// Helpers for MasterService.{CreateSession,RunStep,CloseSession}
|
||||
// rpc calls.
|
||||
|
||||
Status CreateSession(const GraphDef& def, string* handle,
|
||||
int64* initial_version) {
|
||||
::grpc::ClientContext ctx;
|
||||
CreateSessionRequest req;
|
||||
*(req.mutable_graph_def()) = def;
|
||||
// Invokes placement frequently.
|
||||
req.mutable_config()->set_placement_period(1);
|
||||
CreateSessionResponse resp;
|
||||
const Status s = FromGrpcStatus(master_->CreateSession(&ctx, req, &resp));
|
||||
if (s.ok()) {
|
||||
*handle = resp.session_handle();
|
||||
*initial_version = resp.graph_version();
|
||||
}
|
||||
return s;
|
||||
}
|
||||
|
||||
Status ExtendSession(const string& handle, const GraphDef& def,
|
||||
int64 current_version, int64* new_version) {
|
||||
::grpc::ClientContext ctx;
|
||||
ExtendSessionRequest req;
|
||||
req.set_session_handle(handle);
|
||||
*(req.mutable_graph_def()) = def;
|
||||
req.set_current_graph_version(current_version);
|
||||
ExtendSessionResponse resp;
|
||||
const Status s = FromGrpcStatus(master_->ExtendSession(&ctx, req, &resp));
|
||||
if (s.ok()) {
|
||||
*new_version = resp.new_graph_version();
|
||||
}
|
||||
return s;
|
||||
}
|
||||
|
||||
Status RunStep(const string& handle,
|
||||
const std::vector<std::pair<string, const Tensor*> >& feed,
|
||||
const std::map<string, Tensor*>& fetch) {
|
||||
::grpc::ClientContext ctx;
|
||||
RunStepRequest req;
|
||||
req.set_session_handle(handle);
|
||||
for (const auto& p : feed) {
|
||||
const string& feed_name = p.first;
|
||||
const Tensor* feed_tensor = p.second;
|
||||
auto f = req.add_feed();
|
||||
f->set_name(feed_name);
|
||||
feed_tensor->AsProtoTensorContent(f->mutable_tensor());
|
||||
}
|
||||
for (const auto& p : fetch) {
|
||||
const string& fetch_name = p.first;
|
||||
req.add_fetch(fetch_name);
|
||||
}
|
||||
RunStepResponse resp;
|
||||
const Status s = FromGrpcStatus(master_->RunStep(&ctx, req, &resp));
|
||||
if (s.ok()) {
|
||||
for (const auto& fetch_resp : resp.tensor()) {
|
||||
auto it = fetch.find(fetch_resp.name());
|
||||
CHECK(it != fetch.end());
|
||||
CHECK(it->second->FromProto(fetch_resp.tensor()));
|
||||
}
|
||||
}
|
||||
return s;
|
||||
}
|
||||
|
||||
Status CloseSession(const string& handle) {
|
||||
::grpc::ClientContext ctx;
|
||||
CloseSessionRequest req;
|
||||
req.set_session_handle(handle);
|
||||
CloseSessionResponse resp;
|
||||
return FromGrpcStatus(master_->CloseSession(&ctx, req, &resp));
|
||||
}
|
||||
|
||||
Status Reset() {
|
||||
::grpc::ClientContext ctx;
|
||||
ResetRequest req;
|
||||
ResetResponse resp;
|
||||
return FromGrpcStatus(master_->Reset(&ctx, req, &resp));
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(MasterTest, CreateClose) {
|
||||
GraphDef def; // Empty.
|
||||
string handle;
|
||||
int64 initial_version;
|
||||
TF_ASSERT_OK(CreateSession(def, &handle, &initial_version));
|
||||
EXPECT_TRUE(errors::IsAborted(CloseSession("randombits")));
|
||||
EXPECT_TRUE(CloseSession(handle).ok());
|
||||
}
|
||||
|
||||
TEST_F(MasterTest, ListDevices) {
|
||||
::grpc::ClientContext ctx;
|
||||
ListDevicesRequest req;
|
||||
ListDevicesResponse resp;
|
||||
const Status s = FromGrpcStatus(master_->ListDevices(&ctx, req, &resp));
|
||||
TF_EXPECT_OK(s);
|
||||
EXPECT_EQ(1, resp.local_device_size());
|
||||
EXPECT_EQ("CPU", resp.local_device(0).device_type());
|
||||
}
|
||||
|
||||
TEST_F(MasterTest, Reset) {
|
||||
GraphDef def; // Empty.
|
||||
string s1, s2;
|
||||
int64 initial_version1, initial_version2;
|
||||
TF_ASSERT_OK(CreateSession(def, &s1, &initial_version1));
|
||||
TF_ASSERT_OK(CreateSession(def, &s2, &initial_version2));
|
||||
EXPECT_TRUE(Reset().ok());
|
||||
EXPECT_TRUE(errors::IsAborted(CloseSession(s1)));
|
||||
EXPECT_TRUE(errors::IsAborted(CloseSession(s2)));
|
||||
}
|
||||
|
||||
TEST_F(MasterTest, Extend) {
|
||||
GraphDef def_0; // Empty.
|
||||
string handle;
|
||||
int64 initial_version;
|
||||
TF_ASSERT_OK(CreateSession(def_0, &handle, &initial_version));
|
||||
|
||||
Tensor A_expected(DT_FLOAT, TensorShape({2, 2}));
|
||||
test::FillValues<float>(&A_expected, {3.0, 2.0, -1.0, 0.0});
|
||||
|
||||
Tensor x_expected(DT_FLOAT, TensorShape({2, 1}));
|
||||
test::FillValues<float>(&x_expected, {2.0, 2.0});
|
||||
|
||||
Graph graph_1(OpRegistry::Global());
|
||||
test::graph::Constant(&graph_1, A_expected, "A");
|
||||
GraphDef def_1;
|
||||
test::graph::ToGraphDef(&graph_1, &def_1);
|
||||
int64 version_1;
|
||||
TF_ASSERT_OK(ExtendSession(handle, def_1, initial_version, &version_1));
|
||||
EXPECT_GT(version_1, initial_version);
|
||||
Tensor A(DT_FLOAT, TensorShape({2, 2}));
|
||||
TF_ASSERT_OK(RunStep(handle, {}, {{"A:0", &A}}));
|
||||
test::ExpectTensorEqual<float>(A, A_expected);
|
||||
|
||||
Graph graph_2(OpRegistry::Global());
|
||||
test::graph::Constant(&graph_2, x_expected, "x");
|
||||
GraphDef def_2;
|
||||
test::graph::ToGraphDef(&graph_2, &def_2);
|
||||
int64 version_2;
|
||||
EXPECT_TRUE(errors::IsAborted(
|
||||
ExtendSession("randombits", def_2, version_1, &version_2)));
|
||||
TF_ASSERT_OK(ExtendSession(handle, def_2, version_1, &version_2));
|
||||
EXPECT_GT(version_2, version_1);
|
||||
|
||||
Tensor x(DT_FLOAT, TensorShape({2, 1}));
|
||||
TF_ASSERT_OK(RunStep(handle, {}, {{"A:0", &A}, {"x:0", &x}}));
|
||||
test::ExpectTensorEqual<float>(A, A_expected);
|
||||
test::ExpectTensorEqual<float>(x, x_expected);
|
||||
|
||||
TF_ASSERT_OK(CloseSession(handle));
|
||||
}
|
||||
|
||||
TEST_F(MasterTest, ExtendUpdateStatefulFails) {
|
||||
GraphDef def_0; // Empty.
|
||||
string handle;
|
||||
int64 initial_version;
|
||||
TF_ASSERT_OK(CreateSession(def_0, &handle, &initial_version));
|
||||
|
||||
Graph graph_1(OpRegistry::Global());
|
||||
test::graph::Var(&graph_1, DT_FLOAT, TensorShape({512}));
|
||||
GraphDef def_1;
|
||||
test::graph::ToGraphDef(&graph_1, &def_1);
|
||||
|
||||
int64 version_1, version_2;
|
||||
TF_ASSERT_OK(ExtendSession(handle, def_1, initial_version, &version_1));
|
||||
EXPECT_GT(version_1, initial_version);
|
||||
EXPECT_TRUE(errors::IsInvalidArgument(
|
||||
ExtendSession(handle, def_1, version_1, &version_2)));
|
||||
TF_ASSERT_OK(CloseSession(handle));
|
||||
}
|
||||
|
||||
TEST_F(MasterTest, ExtendTwiceFails) {
|
||||
GraphDef def_0; // Empty.
|
||||
string handle;
|
||||
int64 initial_version;
|
||||
TF_ASSERT_OK(CreateSession(def_0, &handle, &initial_version));
|
||||
|
||||
Graph graph_1(OpRegistry::Global());
|
||||
test::graph::Var(&graph_1, DT_FLOAT, TensorShape({512}));
|
||||
GraphDef def_1;
|
||||
test::graph::ToGraphDef(&graph_1, &def_1);
|
||||
|
||||
int64 version_1;
|
||||
TF_ASSERT_OK(ExtendSession(handle, def_1, initial_version, &version_1));
|
||||
EXPECT_GT(version_1, initial_version);
|
||||
EXPECT_TRUE(errors::IsAborted(
|
||||
ExtendSession(handle, def_1, initial_version, &version_1)));
|
||||
TF_ASSERT_OK(CloseSession(handle));
|
||||
}
|
||||
|
||||
TEST_F(MasterTest, ConcurrentExtendOnlyOneSucceeds) {
|
||||
GraphDef def_0; // Empty.
|
||||
string handle;
|
||||
int64 initial_version;
|
||||
TF_ASSERT_OK(CreateSession(def_0, &handle, &initial_version));
|
||||
|
||||
Graph graph_1(OpRegistry::Global());
|
||||
test::graph::Var(&graph_1, DT_FLOAT, TensorShape({512}));
|
||||
GraphDef def_1;
|
||||
test::graph::ToGraphDef(&graph_1, &def_1);
|
||||
|
||||
Notification n;
|
||||
mutex mu;
|
||||
int succeeded = 0;
|
||||
int failed = 0;
|
||||
auto extend_fn = [this, handle, def_1, initial_version, &n, &mu, &succeeded,
|
||||
&failed]() {
|
||||
n.WaitForNotification();
|
||||
int64 new_version;
|
||||
Status s = ExtendSession(handle, def_1, initial_version, &new_version);
|
||||
EXPECT_TRUE(s.ok() || errors::IsAborted(s));
|
||||
{
|
||||
mutex_lock l(mu);
|
||||
if (s.ok()) {
|
||||
++succeeded;
|
||||
} else {
|
||||
++failed;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Run 100 concurrent Extend calls and expect only one to succeed.
|
||||
{
|
||||
thread::ThreadPool thread_pool(Env::Default(), "extend_pool", 100);
|
||||
for (int i = 0; i < 100; ++i) {
|
||||
thread_pool.Schedule(extend_fn);
|
||||
}
|
||||
n.Notify();
|
||||
}
|
||||
|
||||
EXPECT_EQ(failed, 99);
|
||||
EXPECT_EQ(succeeded, 1);
|
||||
TF_ASSERT_OK(CloseSession(handle));
|
||||
}
|
||||
|
||||
TEST_F(MasterTest, ConcurrentExtendAndRun) {
|
||||
Graph graph_0(OpRegistry::Global());
|
||||
Tensor a_tensor(DT_FLOAT, TensorShape({2, 2}));
|
||||
test::FillValues<float>(&a_tensor, {3, 2, -1, 0});
|
||||
test::graph::Constant(&graph_0, a_tensor, "A");
|
||||
GraphDef def_0;
|
||||
test::graph::ToGraphDef(&graph_0, &def_0);
|
||||
|
||||
string handle;
|
||||
int64 initial_version;
|
||||
TF_ASSERT_OK(CreateSession(def_0, &handle, &initial_version));
|
||||
|
||||
Graph graph_1(OpRegistry::Global());
|
||||
Tensor b_tensor(DT_FLOAT, TensorShape({2, 2}));
|
||||
test::FillValues<float>(&b_tensor, {1, 0, 0, 1});
|
||||
test::graph::Constant(&graph_1, b_tensor, "B");
|
||||
GraphDef def_1;
|
||||
test::graph::ToGraphDef(&graph_1, &def_1);
|
||||
|
||||
Notification extend_done;
|
||||
Notification extend_can_start;
|
||||
|
||||
auto get_a_fn = [this, handle, &extend_done]() {
|
||||
Tensor A(DT_FLOAT, TensorShape({2, 2}));
|
||||
while (!extend_done.HasBeenNotified()) {
|
||||
TF_ASSERT_OK(RunStep(handle, {}, {{"A:0", &A}}));
|
||||
}
|
||||
// Run at least once after the Extend has completed.
|
||||
TF_ASSERT_OK(RunStep(handle, {}, {{"A:0", &A}}));
|
||||
};
|
||||
|
||||
auto get_a_and_b_fn = [this, handle, &extend_done, &extend_can_start]() {
|
||||
Tensor A(DT_FLOAT, TensorShape({2, 2}));
|
||||
Tensor B(DT_FLOAT, TensorShape({2, 2}));
|
||||
|
||||
// Run at least once before the Extend has completed.
|
||||
EXPECT_TRUE(
|
||||
errors::IsNotFound(RunStep(handle, {}, {{"A:0", &A}, {"B:0", &B}})));
|
||||
extend_can_start.Notify();
|
||||
|
||||
// Concurrent with the Extend, we will either fail (as above), or
|
||||
// succeed (as below).
|
||||
while (!extend_done.HasBeenNotified()) {
|
||||
Status s = RunStep(handle, {}, {{"A:0", &A}, {"B:0", &B}});
|
||||
EXPECT_TRUE(errors::IsNotFound(s) || s.ok());
|
||||
}
|
||||
|
||||
// Run at least once after the Extend has completed.
|
||||
TF_ASSERT_OK(RunStep(handle, {}, {{"A:0", &A}, {"B:0", &B}}));
|
||||
};
|
||||
|
||||
auto extend_fn = [this, handle, def_1, initial_version, &extend_done,
|
||||
&extend_can_start]() {
|
||||
extend_can_start.WaitForNotification();
|
||||
int64 version_1;
|
||||
TF_ASSERT_OK(ExtendSession(handle, def_1, initial_version, &version_1));
|
||||
extend_done.Notify();
|
||||
};
|
||||
|
||||
{
|
||||
thread::ThreadPool thread_pool(Env::Default(), "extend_pool", 3);
|
||||
thread_pool.Schedule(get_a_fn);
|
||||
thread_pool.Schedule(get_a_and_b_fn);
|
||||
thread_pool.Schedule(extend_fn);
|
||||
}
|
||||
|
||||
TF_ASSERT_OK(CloseSession(handle));
|
||||
}
|
||||
|
||||
TEST_F(MasterTest, EigenProblem) {
|
||||
// A = [3 2; -1 0]; x = rand(2, 1);
|
||||
// for i=1:100; x = A * x; end
|
||||
// We'll try to compute the largest eigenvalue for A.
|
||||
Graph graph(OpRegistry::Global());
|
||||
Tensor a_tensor(DT_FLOAT, TensorShape({2, 2}));
|
||||
// Store rows [3, 2] and [-1, 0] in row major format.
|
||||
test::FillValues<float>(&a_tensor, {3, 2, -1, 0});
|
||||
Node* a_node = test::graph::Constant(&graph, a_tensor);
|
||||
|
||||
// x is from the feed.
|
||||
Tensor x_tensor(DT_FLOAT, TensorShape({2, 1}));
|
||||
test::FillValues<float>(&x_tensor, {0, 0});
|
||||
Node* x_node = test::graph::Constant(&graph, x_tensor);
|
||||
|
||||
// y = A * x
|
||||
Node* y_node = test::graph::Matmul(&graph, a_node, x_node, false, false);
|
||||
|
||||
GraphDef def;
|
||||
test::graph::ToGraphDef(&graph, &def);
|
||||
|
||||
string handle;
|
||||
int64 initial_version;
|
||||
TF_CHECK_OK(CreateSession(def, &handle, &initial_version));
|
||||
|
||||
// Temps supporting the computation of the convergence condition.
|
||||
const Eigen::array<Eigen::DenseIndex, 1> sum_along_dim(0);
|
||||
const Eigen::array<Eigen::DenseIndex, 2> matrix_transpose({1, 0});
|
||||
Tensor x(DT_FLOAT, TensorShape({2, 1}));
|
||||
Tensor y(DT_FLOAT, TensorShape({2, 1}));
|
||||
Eigen::Tensor<float, 1, Eigen::RowMajor> y_square_sum;
|
||||
Eigen::Tensor<float, 2, Eigen::RowMajor> y_normalized(2, 1);
|
||||
y_normalized.setRandom();
|
||||
Eigen::Tensor<float, 1, Eigen::RowMajor> error_square_sum;
|
||||
float lambda;
|
||||
|
||||
// The computation loop.
|
||||
bool converged = false;
|
||||
while (!converged) {
|
||||
// Run one step of the graph.
|
||||
auto x_matrix = x.matrix<float>();
|
||||
x_matrix = y_normalized;
|
||||
TF_EXPECT_OK(
|
||||
RunStep(handle, {{x_node->name(), &x}}, {{y_node->name() + ":0", &y}}));
|
||||
auto y_matrix = y.matrix<float>();
|
||||
|
||||
// Client code computes the convergence condition.
|
||||
{
|
||||
lambda = y_matrix(0, 0) / x_matrix(0, 0);
|
||||
y_square_sum = y.matrix<float>().square().sum(sum_along_dim);
|
||||
const float norm = static_cast<float>(sqrt(y_square_sum(0)));
|
||||
y_normalized = y_matrix * (1 / norm);
|
||||
error_square_sum = (x_matrix - y_normalized).square().sum(sum_along_dim);
|
||||
VLOG(1) << "x = [" << x_matrix.shuffle(matrix_transpose) << "] y = ["
|
||||
<< y_matrix.shuffle(matrix_transpose) << "] lambda = " << lambda;
|
||||
converged = sqrt(error_square_sum(0)) < 1e-10;
|
||||
}
|
||||
}
|
||||
EXPECT_NEAR(lambda, 2.0, 0.01);
|
||||
TF_EXPECT_OK(CloseSession(handle));
|
||||
}
|
||||
|
||||
} // namespace tensorflow
69  tensorflow/core/distributed_runtime/process_util.cc  Normal file
@ -0,0 +1,69 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/process_util.h"
|
||||
|
||||
#include <string.h>
|
||||
|
||||
#include "tensorflow/core/lib/core/threadpool.h"
|
||||
#include "tensorflow/core/platform/host_info.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/tracing.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
static thread::ThreadPool* InitComputePool(const SessionOptions& options) {
|
||||
int32 inter_op_parallelism_threads =
|
||||
options.config.inter_op_parallelism_threads();
|
||||
if (inter_op_parallelism_threads == 0) {
|
||||
// Default to using the number of cores available in the process.
|
||||
inter_op_parallelism_threads = port::NumSchedulableCPUs();
|
||||
}
|
||||
|
||||
return new thread::ThreadPool(Env::Default(), "Compute",
|
||||
inter_op_parallelism_threads);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
thread::ThreadPool* ComputePool(const SessionOptions& options) {
|
||||
static thread::ThreadPool* compute_pool = InitComputePool(options);
|
||||
return compute_pool;
|
||||
}
|
||||
|
||||
void SchedClosure(std::function<void()> closure) {
|
||||
if (port::Tracing::IsActive()) {
|
||||
const uint64 id = port::Tracing::UniqueId();
|
||||
port::Tracing::RecordEvent(port::Tracing::EventCategory::kScheduleClosure,
|
||||
id);
|
||||
std::function<void()> wrapper = [closure, id]() {
|
||||
port::Tracing::ScopedActivity region(
|
||||
port::Tracing::EventCategory::kRunClosure, id);
|
||||
closure();
|
||||
};
|
||||
Env::Default()->SchedClosure(wrapper);
|
||||
} else {
|
||||
Env::Default()->SchedClosure(closure);
|
||||
}
|
||||
}
|
||||
|
||||
void SchedNonBlockingClosureAfter(int micros, std::function<void()> closure) {
|
||||
Env::Default()->SchedClosureAfter(micros, closure);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
39 tensorflow/core/distributed_runtime/process_util.h (new file)
@@ -0,0 +1,39 @@
/* Copyright 2016 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_PROCESS_UTIL_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_PROCESS_UTIL_H_

#include <functional>

#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {

// Returns a process-wide ThreadPool for scheduling compute operations
// using 'options'. The caller does not take ownership of the returned
// ThreadPool.
thread::ThreadPool* ComputePool(const SessionOptions& options);

// Schedule "closure" in the default thread queue.
void SchedClosure(std::function<void()> closure);

// Schedule "closure" after the given number of microseconds in the
// fixed-size ThreadPool used for non-blocking compute tasks.
void SchedNonBlockingClosureAfter(int micros, std::function<void()> closure);

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_PROCESS_UTIL_H_
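A minimal usage sketch of the two entry points above, not code from this change: it assumes a default-constructed SessionOptions and that thread::ThreadPool exposes Schedule(), as in tensorflow/core/lib/core/threadpool.h.

#include "tensorflow/core/distributed_runtime/process_util.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {

void ExampleScheduleWork() {
  SessionOptions options;  // inter_op_parallelism_threads == 0 -> one thread per core
  thread::ThreadPool* pool = ComputePool(options);  // process-wide; not owned here
  pool->Schedule([]() { /* compute-bound work */ });

  // Fire-and-forget work goes through SchedClosure, which also tags the
  // closure for tracing when tracing is active.
  SchedClosure([]() { /* bookkeeping off the compute pool */ });
}

}  // namespace tensorflow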
91 tensorflow/core/distributed_runtime/remote_device.cc (new file)
@@ -0,0 +1,91 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/remote_device.h"
|
||||
|
||||
#include <vector>
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/distributed_runtime/process_util.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_cache.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_interface.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/protobuf/worker.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
using std::placeholders::_1;
|
||||
|
||||
// TODO(zhifengc): We need to consolidate (full/partial) device name
|
||||
// parsing into one place.
|
||||
//
|
||||
// Parses and returns the local device part (e.g., cpu:0, gpu:4).
|
||||
string GetLocalDeviceName(StringPiece fullname) {
|
||||
auto pos = fullname.rfind('/');
|
||||
CHECK_NE(pos, StringPiece::npos);
|
||||
fullname.remove_prefix(pos + 1);
|
||||
return fullname.ToString();
|
||||
}
|
||||
|
||||
class RemoteDevice : public Device {
|
||||
public:
|
||||
RemoteDevice(Env* env, const DeviceAttributes& da, WorkerInterface* wi)
|
||||
: Device(env, da, nullptr),
|
||||
local_dev_name_(GetLocalDeviceName(da.name())),
|
||||
wi_(wi) {}
|
||||
|
||||
~RemoteDevice() override { delete wi_; }
|
||||
Status Sync() override { return Status::OK(); }
|
||||
Allocator* GetAllocator(AllocatorAttributes attr) override { return nullptr; }
|
||||
|
||||
private:
|
||||
const string local_dev_name_;
|
||||
WorkerInterface* wi_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(RemoteDevice);
|
||||
};
|
||||
|
||||
void NewRemoteDevices(Env* env, WorkerCacheInterface* worker_cache,
|
||||
const string& worker_name, NewRemoteDevicesDone done) {
|
||||
WorkerInterface* wi = worker_cache->CreateWorker(worker_name);
|
||||
if (wi == nullptr) {
|
||||
std::vector<Device*> empty;
|
||||
done(errors::NotFound("Device ", worker_name, " is not found."), &empty);
|
||||
return;
|
||||
}
|
||||
struct Call {
|
||||
GetStatusRequest req;
|
||||
GetStatusResponse resp;
|
||||
};
|
||||
Call* call = new Call;
|
||||
auto cb = [env, worker_cache, worker_name, done, wi, call](const Status& s) {
|
||||
std::vector<Device*> remote_devices;
|
||||
if (s.ok()) {
|
||||
remote_devices.reserve(call->resp.device_attributes_size());
|
||||
for (const DeviceAttributes& da : call->resp.device_attributes()) {
|
||||
auto d =
|
||||
new RemoteDevice(env, da, worker_cache->CreateWorker(worker_name));
|
||||
remote_devices.push_back(d);
|
||||
}
|
||||
}
|
||||
done(s, &remote_devices);
|
||||
delete wi;
|
||||
delete call;
|
||||
};
|
||||
wi->GetStatusAsync(&call->req, &call->resp, cb);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
48 tensorflow/core/distributed_runtime/remote_device.h (new file)
@@ -0,0 +1,48 @@
/* Copyright 2016 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_REMOTE_DEVICE_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_REMOTE_DEVICE_H_

#include <functional>
#include <string>
#include <vector>
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {
class Device;
class Env;
class WorkerCacheInterface;

// NewRemoteDevices discovers available devices on the
// 'remote_worker'. The implementation uses 'worker_cache' to
// discover how to communicate with the 'remote_worker' (via gRPC, for
// example).
//
// NewRemoteDevices does not block.
//
// On success, the 'done' callback is given the OK status and a vector
// of Device*. The caller should take ownership of these devices.
//
// Otherwise, the 'done' callback is given an error status and the
// vector is empty.
typedef std::function<void(const Status&, std::vector<Device*>*)>
    NewRemoteDevicesDone;
void NewRemoteDevices(Env* env, WorkerCacheInterface* worker_cache,
                      const string& remote_worker, NewRemoteDevicesDone done);

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_REMOTE_DEVICE_H_
89 tensorflow/core/distributed_runtime/remote_device_test.cc (new file)
@@ -0,0 +1,89 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/remote_device.h"
|
||||
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_testlib.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_interface.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/lib/core/notification.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/regexp.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
const char* const kSession = "remote_session";
|
||||
|
||||
class RemoteDeviceTest : public ::testing::Test {
|
||||
protected:
|
||||
string remote_name_;
|
||||
std::unique_ptr<WorkerCacheInterface> worker_cache_;
|
||||
std::unique_ptr<WorkerInterface> wi_;
|
||||
std::vector<Device*> devices_;
|
||||
std::unique_ptr<test::TestCluster> cluster_;
|
||||
|
||||
RemoteDeviceTest() {
|
||||
SessionOptions options;
|
||||
(*options.config.mutable_device_count())["CPU"] = 2;
|
||||
TF_CHECK_OK(test::TestCluster::MakeTestCluster(options, 1, &cluster_));
|
||||
const string& hostport = cluster_->targets()[0];
|
||||
string host;
|
||||
int port;
|
||||
CHECK(RE2::FullMatch(hostport, "(.+):(\\d+)", &host, &port));
|
||||
GrpcChannelSpec spec;
|
||||
spec.AddHostPortsJob("localhost", {hostport}, 1);
|
||||
worker_cache_.reset(NewGrpcWorkerCache(NewGrpcChannelCache(spec)));
|
||||
remote_name_ = strings::StrCat("/job:", host, "/replica:0/task:0");
|
||||
wi_.reset(worker_cache_->CreateWorker(remote_name_));
|
||||
}
|
||||
|
||||
void SetUp() override {
|
||||
Notification n;
|
||||
NewRemoteDevices(Env::Default(), worker_cache_.get(), remote_name_,
|
||||
[&n, this](const Status& s, std::vector<Device*>* found) {
|
||||
TF_CHECK_OK(s);
|
||||
devices_ = *found;
|
||||
n.Notify();
|
||||
});
|
||||
n.WaitForNotification();
|
||||
EXPECT_EQ(devices_.size(), 2);
|
||||
std::sort(devices_.begin(), devices_.end(), [](Device* a, Device* b) {
|
||||
return a->name().compare(b->name()) < 0;
|
||||
});
|
||||
}
|
||||
|
||||
void TearDown() override {
|
||||
for (auto d : devices_) delete d;
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(RemoteDeviceTest, GetStatus) {
|
||||
// We know what the testlib's fake server does.
|
||||
EXPECT_EQ(devices_[0]->name(), strings::StrCat(remote_name_, "/cpu:0"));
|
||||
EXPECT_EQ(devices_[0]->attributes().device_type(),
|
||||
DeviceType(DEVICE_CPU).type());
|
||||
EXPECT_EQ(devices_[0]->attributes().memory_limit(), 256 << 20);
|
||||
EXPECT_EQ(devices_[1]->name(), strings::StrCat(remote_name_, "/cpu:1"));
|
||||
EXPECT_EQ(devices_[1]->attributes().memory_limit(), 256 << 20);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
79 tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h (new file)
@@ -0,0 +1,79 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RENDEZVOUS_MGR_INTERFACE_H_
|
||||
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RENDEZVOUS_MGR_INTERFACE_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/worker_env.h"
|
||||
#include "tensorflow/core/framework/rendezvous.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// RendezvousMgr keeps track of a set of local rendezvous instances.
|
||||
// All tensors sent by this worker are buffered in a RendezvousMgr
|
||||
// until the tensor is received. Each global unique "step_id"
|
||||
// corresponds to one local rendezvous instance managed by a
|
||||
// RendezvousMgr.
|
||||
//
|
||||
// E.g.,
|
||||
// Rendezvous* rendez = worker_env->rendezvous_mgr->Find(0x8935);
|
||||
// fork execution of a graph executor using "rendez" on thread 1;
|
||||
// fork execution of another graph executor using "rendez" on thread 2;
|
||||
// ...
|
||||
// join threads 1 and 2;
|
||||
//
|
||||
// In the example above, execution in thread 1 and 2 communicates with
|
||||
// each other by send/recv operations through "rendez".
|
||||
//
|
||||
// Tensors sent and recved through rendezvous managed by this
|
||||
// RendezvousMgr must have keys generated by Rendezvous::CreateKey.
|
||||
class RendezvousMgrInterface {
|
||||
public:
|
||||
RendezvousMgrInterface() {}
|
||||
virtual ~RendezvousMgrInterface() {}
|
||||
|
||||
// Returns Rendezvous supporting send and recv among workers in the
|
||||
// "step_id". The caller takes ownership of one reference on the
|
||||
// returned Rendezvous instance.
|
||||
virtual Rendezvous* Find(int64 step_id) = 0;
|
||||
|
||||
// Finds the local rendezvous instance for the "step_id". Runs
|
||||
// "done" when the tensor for "key" is produced or an error occurs.
|
||||
//
|
||||
// This method is used by the rpc handler of RecvTensor.
|
||||
virtual void RecvLocalAsync(int64 step_id, const string& key,
|
||||
Rendezvous::DoneCallback done) = 0;
|
||||
|
||||
// Synchronous wrapper for RecvLocalAsync.
|
||||
virtual Status RecvLocal(int64 step_id, const string& key, Tensor* val,
|
||||
bool* is_dead) = 0;
|
||||
|
||||
// Removes rendezvous for "step_id".
|
||||
//
|
||||
// TODO(zhifengc): Have a background thread in worker that
|
||||
// periodically calls CleanupAll().
|
||||
virtual void Cleanup(int64 step_id) = 0;
|
||||
|
||||
// Removes all rendezvous.
|
||||
virtual void CleanupAll() = 0;
|
||||
};
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RENDEZVOUS_MGR_INTERFACE_H_
|
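A minimal sketch of the contract described in the comments above, not code from this change. It assumes `worker_env` is a populated WorkerEnv (the comment's own example uses worker_env->rendezvous_mgr), that `key` was produced by Rendezvous::CreateKey on the sending side, and that Rendezvous is reference counted as in tensorflow/core/framework/rendezvous.h.

#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/tensor.h"

namespace tensorflow {

Status ExampleRecvOneTensor(WorkerEnv* worker_env, int64 step_id,
                            const string& key, Tensor* out) {
  // One rendezvous instance exists per step_id; Find() hands back a
  // reference that the caller owns and must release.
  Rendezvous* rendez = worker_env->rendezvous_mgr->Find(step_id);
  bool is_dead = false;
  // RecvLocal is the synchronous wrapper; the RecvTensor RPC handler would
  // use RecvLocalAsync instead.
  Status s =
      worker_env->rendezvous_mgr->RecvLocal(step_id, key, out, &is_dead);
  rendez->Unref();
  // When the step finishes, its owner calls
  // worker_env->rendezvous_mgr->Cleanup(step_id).
  return s;
}

}  // namespace tensorflow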
341 tensorflow/core/distributed_runtime/rpc/BUILD (new file)
@@ -0,0 +1,341 @@
|
||||
# Description:
|
||||
# RPC communication interfaces and implementations for TensorFlow.
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
["**/*"],
|
||||
exclude = [
|
||||
"**/METADATA",
|
||||
"**/OWNERS",
|
||||
],
|
||||
),
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "c_srcs",
|
||||
data = glob([
|
||||
"**/*.cc",
|
||||
"**/*.h",
|
||||
]),
|
||||
)
|
||||
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"tf_cuda_library",
|
||||
"tf_cc_tests",
|
||||
)
|
||||
|
||||
# For platform specific build config
|
||||
load(
|
||||
"//tensorflow/core:platform/default/build_config.bzl",
|
||||
"tf_kernel_tests_linkstatic",
|
||||
)
|
||||
load(
|
||||
"//tensorflow/core:platform/default/build_config_root.bzl",
|
||||
"tf_cuda_tests_tags",
|
||||
)
|
||||
|
||||
package(default_visibility = [
|
||||
"//tensorflow:internal",
|
||||
])
|
||||
|
||||
cc_library(
|
||||
name = "grpc_util",
|
||||
srcs = [],
|
||||
hdrs = ["grpc_util.h"],
|
||||
deps = [
|
||||
"@grpc//:grpc++_unsecure",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "grpc_client_cq_tag",
|
||||
srcs = [],
|
||||
hdrs = ["grpc_client_cq_tag.h"],
|
||||
deps = [
|
||||
"@grpc//:grpc++_unsecure",
|
||||
":grpc_util",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "grpc_remote_worker",
|
||||
srcs = ["grpc_remote_worker.cc"],
|
||||
hdrs = ["grpc_remote_worker.h"],
|
||||
deps = [
|
||||
"@grpc//:grpc++_unsecure",
|
||||
":grpc_client_cq_tag",
|
||||
":grpc_util",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:worker_proto_cc",
|
||||
"//tensorflow/core:worker_service_proto_cc",
|
||||
"//tensorflow/core/distributed_runtime:process_util",
|
||||
"//tensorflow/core/distributed_runtime:worker_cache_logger",
|
||||
"//tensorflow/core/distributed_runtime:worker_interface",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "grpc_channel",
|
||||
srcs = ["grpc_channel.cc"],
|
||||
hdrs = ["grpc_channel.h"],
|
||||
deps = [
|
||||
"@grpc//:grpc++_unsecure",
|
||||
":grpc_util",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "grpc_call",
|
||||
srcs = [],
|
||||
hdrs = ["grpc_call.h"],
|
||||
deps = [
|
||||
"@grpc//:grpc++_unsecure",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "async_service_interface",
|
||||
srcs = [],
|
||||
hdrs = ["async_service_interface.h"],
|
||||
deps = [],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "grpc_worker_cache",
|
||||
srcs = ["grpc_worker_cache.cc"],
|
||||
hdrs = ["grpc_worker_cache.h"],
|
||||
deps = [
|
||||
":grpc_channel",
|
||||
":grpc_client_cq_tag",
|
||||
":grpc_remote_worker",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/distributed_runtime:worker_cache",
|
||||
"//tensorflow/core/distributed_runtime:worker_cache_logger",
|
||||
"//tensorflow/core/distributed_runtime:worker_cache_partial",
|
||||
"//tensorflow/core/distributed_runtime:worker_interface",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "grpc_worker_service",
|
||||
srcs = ["grpc_worker_service.cc"],
|
||||
hdrs = ["grpc_worker_service.h"],
|
||||
deps = [
|
||||
"@grpc//:grpc++_unsecure",
|
||||
":async_service_interface",
|
||||
":grpc_call",
|
||||
":grpc_util",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:gpu_runtime",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:tensorflow_opensource",
|
||||
"//tensorflow/core:worker_proto_cc",
|
||||
"//tensorflow/core:worker_service_proto_cc",
|
||||
"//tensorflow/core/distributed_runtime:graph_mgr",
|
||||
"//tensorflow/core/distributed_runtime:process_util",
|
||||
"//tensorflow/core/distributed_runtime:rendezvous_mgr_interface",
|
||||
"//tensorflow/core/distributed_runtime:worker_cache",
|
||||
"//tensorflow/core/distributed_runtime:worker_env",
|
||||
"//tensorflow/core/distributed_runtime:worker_interface",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "grpc_remote_master",
|
||||
srcs = ["grpc_remote_master.cc"],
|
||||
hdrs = ["grpc_remote_master.h"],
|
||||
deps = [
|
||||
"@grpc//:grpc++_unsecure",
|
||||
":grpc_util",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:master_proto_cc",
|
||||
"//tensorflow/core:master_service_proto_cc",
|
||||
"//tensorflow/core/distributed_runtime:master_interface",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "grpc_master_service",
|
||||
srcs = ["grpc_master_service.cc"],
|
||||
hdrs = ["grpc_master_service.h"],
|
||||
deps = [
|
||||
"@grpc//:grpc++_unsecure",
|
||||
":async_service_interface",
|
||||
":grpc_call",
|
||||
":grpc_util",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:master_proto_cc",
|
||||
"//tensorflow/core:master_service_proto_cc",
|
||||
"//tensorflow/core/distributed_runtime:master",
|
||||
"//tensorflow/core/distributed_runtime:master_interface",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "rpc_rendezvous_mgr",
|
||||
srcs = ["rpc_rendezvous_mgr.cc"],
|
||||
hdrs = ["rpc_rendezvous_mgr.h"],
|
||||
deps = [
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:tensorflow_opensource",
|
||||
"//tensorflow/core/distributed_runtime:base_rendezvous_mgr",
|
||||
"//tensorflow/core/distributed_runtime:process_util",
|
||||
"//tensorflow/core/distributed_runtime:worker_cache",
|
||||
"//tensorflow/core/distributed_runtime:worker_env",
|
||||
"//tensorflow/core/distributed_runtime:worker_interface",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "grpc_server_lib",
|
||||
srcs = [
|
||||
"grpc_server_lib.cc",
|
||||
],
|
||||
hdrs = ["grpc_server_lib.h"],
|
||||
deps = [
|
||||
"@grpc//:grpc++_unsecure",
|
||||
":async_service_interface",
|
||||
":grpc_channel",
|
||||
":grpc_master_service",
|
||||
":grpc_worker_cache",
|
||||
":grpc_worker_service",
|
||||
":rpc_rendezvous_mgr",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/distributed_runtime:graph_mgr",
|
||||
"//tensorflow/core/distributed_runtime:master_env",
|
||||
"//tensorflow/core/distributed_runtime:master_session",
|
||||
"//tensorflow/core/distributed_runtime:process_util",
|
||||
"//tensorflow/core/distributed_runtime:worker_env",
|
||||
],
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "grpc_tensorflow_server",
|
||||
srcs = [
|
||||
"grpc_tensorflow_server.cc",
|
||||
],
|
||||
deps = [
|
||||
"@grpc//:grpc++_unsecure",
|
||||
":grpc_server_lib",
|
||||
"//tensorflow/core:all_kernels",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
name = "grpc_testlib_ops",
|
||||
testonly = 1,
|
||||
srcs = ["grpc_testlib_ops.cc"],
|
||||
linkstatic = 1, # Seems to be needed since alwayslink is broken in bazel
|
||||
deps = [
|
||||
"//tensorflow/core:all_kernels",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "grpc_testlib_server",
|
||||
testonly = 1,
|
||||
srcs = [
|
||||
"grpc_testlib_server.cc",
|
||||
],
|
||||
deps = [
|
||||
"@grpc//:grpc++_unsecure",
|
||||
":grpc_server_lib",
|
||||
":grpc_testlib_ops",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
name = "grpc_testlib",
|
||||
testonly = 1,
|
||||
srcs = ["grpc_testlib.cc"],
|
||||
hdrs = ["grpc_testlib.h"],
|
||||
data = [
|
||||
":grpc_testlib_server",
|
||||
],
|
||||
deps = [
|
||||
":grpc_session",
|
||||
":grpc_testlib_ops",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:tensorflow_opensource",
|
||||
"//tensorflow/core:test",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "grpc_session",
|
||||
srcs = ["grpc_session.cc"],
|
||||
hdrs = ["grpc_session.h"],
|
||||
deps = [
|
||||
":grpc_channel",
|
||||
":grpc_remote_master",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:master_proto_cc",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:tensorflow",
|
||||
"//tensorflow/core/distributed_runtime:master_interface",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
tf_cc_tests(
|
||||
linkstatic = tf_kernel_tests_linkstatic(),
|
||||
tags = tf_cuda_tests_tags(),
|
||||
tests = [
|
||||
"grpc_channel_test.cc",
|
||||
"grpc_session_test.cc",
|
||||
"rpc_rendezvous_mgr_test.cc",
|
||||
],
|
||||
deps = [
|
||||
":grpc_channel",
|
||||
":grpc_session",
|
||||
":grpc_testlib",
|
||||
":rpc_rendezvous_mgr",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:master_proto_cc",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
"//tensorflow/core/distributed_runtime:process_util",
|
||||
],
|
||||
)
|
37 tensorflow/core/distributed_runtime/rpc/async_service_interface.h (new file)
@@ -0,0 +1,37 @@
/* Copyright 2016 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_ASYNC_SERVICE_INTERFACE_H_
#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_ASYNC_SERVICE_INTERFACE_H_

namespace tensorflow {

// Represents an abstract asynchronous service that handles incoming
// RPCs with a polling loop.
class AsyncServiceInterface {
 public:
  virtual ~AsyncServiceInterface() {}

  // A blocking method that should be called to handle incoming RPCs.
  // This method will block until the service is shut down, which
  // depends on the implementation of the service.
  virtual void HandleRPCsLoop() = 0;

  // TODO(mrry): Add a clean shutdown method?
};

}  // namespace tensorflow

#endif  // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_ASYNC_SERVICE_INTERFACE_H_
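A deliberately tiny, illustrative implementation of the interface above, not code from this change: NoOpAsyncService and ExampleRunService are made up. The real services (grpc_master_service, grpc_worker_service) poll a grpc::ServerCompletionQueue rather than an atomic flag.

#include <atomic>
#include <thread>

#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"

namespace tensorflow {

class NoOpAsyncService : public AsyncServiceInterface {
 public:
  // Blocks until Stop() is called; stands in for a completion-queue loop.
  void HandleRPCsLoop() override {
    while (!stopped_.load()) {
      // A real implementation would poll for incoming RPCs here.
    }
  }
  void Stop() { stopped_.store(true); }

 private:
  std::atomic<bool> stopped_{false};
};

// Typical use: run the blocking loop on its own thread.
void ExampleRunService(NoOpAsyncService* service) {
  std::thread loop([service]() { service->HandleRPCsLoop(); });
  service->Stop();
  loop.join();
}

}  // namespace tensorflow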
227 tensorflow/core/distributed_runtime/rpc/grpc_call.h (new file)
@@ -0,0 +1,227 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CALL_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CALL_H_
|
||||
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
|
||||
#include "external/grpc/include/grpc++/grpc++.h"
|
||||
#include "external/grpc/include/grpc++/server_builder.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// CALL STRUCTURES
|
||||
// ===============
|
||||
//
|
||||
// Each pending (incoming) request corresponds to a call object that
|
||||
// encapsulates the state of the call. Templates and
|
||||
// pointers-to-member functions are used to avoid boilerplate and
|
||||
// redundant closure creation. The class hierarchy is as follows:
|
||||
//
|
||||
// * `UntypedCall<Service>`: The base class represents a call that
|
||||
// could be associated with any of the methods on a service of type
|
||||
// `Service`. Also defines a `Tag` nested class that can be used as
|
||||
// the tag in a `grpc::CompletionQueue`. Each class that
|
||||
// instantiates `Service` should have a completion queue polling
|
||||
// loop that knows about `UntypedCall<Service>::Tag` objects, and
|
||||
// invokes their `OnCompleted()` method to continue processing.
|
||||
//
|
||||
// * `Call<Service, GrpcService, Req, Resp>`: This class extends
|
||||
// `UntypedCall<Service>` and is additionally parameterized by the
|
||||
// gRPC-generated asynchronous service class, and the request and
|
||||
// response message types. It defines the state associated with a
|
||||
// call (whose type depends on the message types), and stores a
|
||||
// pointer to a `Service::HandleFoo()` handler method. Each
|
||||
// `Service::HandleFoo()` method knows about the corresponding
|
||||
// `Call` type, in order to access its state, and invoke its
|
||||
// `SendResponse()` method.
|
||||
//
|
||||
// The lifecycle of a call object is as follows.
|
||||
//
|
||||
// 1. A `Service` creates a `Call` for a particular method and
|
||||
// enqueues it in its completion queue (via an
|
||||
// `UntypedCall<Service>::Tag`).
|
||||
//
|
||||
// 2. When the tag is returned from `cq_->Next()`, the
|
||||
// `UntypedCall::RequestReceived()` method is invoked and takes
|
||||
// ownership of the call object. This indirectly invokes the
|
||||
// appropriate handler method on `Service`.
|
||||
//
|
||||
// 3. After the response has been written (perhaps in another thread),
|
||||
// the `Call::SendResponse()` method is invoked. It transfers
|
||||
// ownership of the call object back to the completion queue (via
|
||||
// an `UntypedCall::Tag`).
|
||||
//
|
||||
// 4. When the response has been sent, the tag is returned from
|
||||
// `cq_->Next()`, and the call object is deleted.
|
||||
|
||||
// Represents a pending request with unknown message types.
|
||||
template <class Service>
|
||||
class UntypedCall : public core::RefCounted {
|
||||
public:
|
||||
virtual ~UntypedCall() {}
|
||||
|
||||
// The implementation of this method should use `service` to handle
|
||||
// an incoming request, and (perhaps asynchronously) send the
|
||||
// response.
|
||||
//
|
||||
// One reference on `this` is transferred to the callee, and the
|
||||
// callee is responsible for releasing it (typically via
|
||||
// `Call::SendResponse()`).
|
||||
//
|
||||
// `ok` is true if the request was received in a "regular event",
|
||||
// otherwise false.
|
||||
virtual void RequestReceived(Service* service, bool ok) = 0;
|
||||
|
||||
// This method will be called when the response has been sent by
|
||||
// `service` and the call is no longer used.
|
||||
//
|
||||
// `ok` is true if the response sending completed as a "regular
|
||||
// event", otherwise it is false.
|
||||
void ResponseSent(Service* service, bool ok) {}
|
||||
|
||||
// This method will be called either (i) when the server is notified
|
||||
// that the request has been cancelled, or (ii) when the request completes
|
||||
// normally. The implementation should distinguish these cases by querying
|
||||
// the `grpc::ServerContext` associated with the request.
|
||||
virtual void RequestCancelled(Service* service, bool ok) = 0;
|
||||
|
||||
// Associates a tag in a `::grpc::CompletionQueue` with a callback
|
||||
// for an incoming RPC. A Tag owns a reference on the corresponding
|
||||
// Call object.
|
||||
class Tag {
|
||||
public:
|
||||
using Callback = void (UntypedCall::*)(Service*, bool);
|
||||
|
||||
// Creates a new `Tag` for the given `UntypedCall`. When the
|
||||
// request associated with this tag is complete, `callback` will
|
||||
// be called.
|
||||
Tag(UntypedCall* call, Callback callback)
|
||||
: call_(call), callback_(callback) {
|
||||
call_->Ref();
|
||||
}
|
||||
|
||||
~Tag() { call_->Unref(); }
|
||||
|
||||
// Calls the callback associated with this tag.
|
||||
//
|
||||
// The callback takes ownership of `this->call_`.
|
||||
void OnCompleted(Service* service, bool ok) {
|
||||
(call_->*callback_)(service, ok);
|
||||
}
|
||||
|
||||
private:
|
||||
UntypedCall* call_; // `this` owns one reference.
|
||||
Callback callback_;
|
||||
};
|
||||
};
|
||||
|
||||
// Represents a pending call with known request and response message
|
||||
// types, and a known request-handling method.
|
||||
template <class Service, class GrpcService, class RequestMessage,
|
||||
class ResponseMessage>
|
||||
class Call : public UntypedCall<Service> {
|
||||
public:
|
||||
// Represents the generic signature of a generated
|
||||
// `GrpcService::RequestFoo()` method, where `Foo` is the name of an
|
||||
// RPC method.
|
||||
using EnqueueFunction = void (GrpcService::*)(
|
||||
::grpc::ServerContext*, RequestMessage*,
|
||||
::grpc::ServerAsyncResponseWriter<ResponseMessage>*,
|
||||
::grpc::CompletionQueue*, ::grpc::ServerCompletionQueue*, void*);
|
||||
|
||||
// Represents the generic signature of a `Service::HandleFoo()`
|
||||
// method, where `Foo` is the name of an RPC method.
|
||||
using HandleRequestFunction = void (Service::*)(
|
||||
Call<Service, GrpcService, RequestMessage, ResponseMessage>*);
|
||||
|
||||
Call(HandleRequestFunction handle_request_function)
|
||||
: handle_request_function_(handle_request_function), responder_(&ctx_) {}
|
||||
|
||||
virtual ~Call() {}
|
||||
|
||||
void RequestReceived(Service* service, bool ok) override {
|
||||
if (ok) {
|
||||
this->Ref();
|
||||
(service->*handle_request_function_)(this);
|
||||
}
|
||||
}
|
||||
|
||||
void SendResponse(::grpc::Status status) {
|
||||
responder_.Finish(response, status,
|
||||
new typename UntypedCall<Service>::Tag(
|
||||
this, &UntypedCall<Service>::ResponseSent));
|
||||
this->Unref();
|
||||
}
|
||||
|
||||
void RequestCancelled(Service* service, bool ok) override {
|
||||
if (ctx_.IsCancelled()) {
|
||||
mutex_lock l(mu_);
|
||||
if (cancel_callback_) {
|
||||
cancel_callback_();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Registers `callback` as the function that should be called if and when this
|
||||
// call is cancelled by the client.
|
||||
void SetCancelCallback(std::function<void()> callback) {
|
||||
mutex_lock l(mu_);
|
||||
cancel_callback_ = callback;
|
||||
}
|
||||
|
||||
// Clears any cancellation callback that has been registered for this call.
|
||||
void ClearCancelCallback() {
|
||||
mutex_lock l(mu_);
|
||||
cancel_callback_ = nullptr;
|
||||
}
|
||||
|
||||
// Enqueues a new request for the given service on the given
|
||||
// completion queue, using the given `enqueue_function`.
|
||||
//
|
||||
// The request will be handled with the given
|
||||
// `handle_request_function`.
|
||||
static void EnqueueRequest(GrpcService* grpc_service,
|
||||
::grpc::ServerCompletionQueue* cq,
|
||||
EnqueueFunction enqueue_function,
|
||||
HandleRequestFunction handle_request_function) {
|
||||
auto call = new Call<Service, GrpcService, RequestMessage, ResponseMessage>(
|
||||
handle_request_function);
|
||||
|
||||
call->ctx_.AsyncNotifyWhenDone(new typename UntypedCall<Service>::Tag(
|
||||
call, &UntypedCall<Service>::RequestCancelled));
|
||||
|
||||
(grpc_service->*enqueue_function)(
|
||||
&call->ctx_, &call->request, &call->responder_, cq, cq,
|
||||
new typename UntypedCall<Service>::Tag(
|
||||
call, &UntypedCall<Service>::RequestReceived));
|
||||
call->Unref();
|
||||
}
|
||||
|
||||
RequestMessage request;
|
||||
ResponseMessage response;
|
||||
|
||||
private:
|
||||
HandleRequestFunction handle_request_function_;
|
||||
::grpc::ServerContext ctx_;
|
||||
::grpc::ServerAsyncResponseWriter<ResponseMessage> responder_;
|
||||
mutex mu_;
|
||||
std::function<void()> cancel_callback_ GUARDED_BY(mu_);
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CALL_H_
|
314 tensorflow/core/distributed_runtime/rpc/grpc_channel.cc (new file)
@@ -0,0 +1,314 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
#include "external/grpc/include/grpc++/create_channel.h"
|
||||
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/regexp.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
RE2* kTargetRE = new RE2("^/job:([^/]+)/replica:([0-9]+)/task:([0-9]+)$");
|
||||
RE2* kHostPortRE = new RE2("([^:/]+):(\\d+)");
|
||||
RE2* kSparseHostPortRE = new RE2("(\\d+):([^:/]+):(\\d+)");
|
||||
|
||||
string MakeAddress(const string& job, int replica, int task) {
|
||||
return strings::StrCat("/job:", job, "/replica:", replica, "/task:", task);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
SharedGrpcChannelPtr NewHostPortGrpcChannel(const string& target) {
|
||||
// TODO(mrry): Implement secure channels.
|
||||
return ::grpc::CreateChannel(target, ::grpc::InsecureChannelCredentials());
|
||||
}
|
||||
|
||||
Status GrpcChannelSpec::AddHostPortsJob(const string& job_id,
|
||||
const std::vector<string>& host_ports,
|
||||
int tasks_per_replica) {
|
||||
if (!job_ids_.insert(job_id).second) {
|
||||
return errors::InvalidArgument(
|
||||
"Duplicate job ID in cluster specification: ", job_id);
|
||||
}
|
||||
HostPortsJob job;
|
||||
job.job_id = job_id;
|
||||
for (const string& host_port : host_ports) {
|
||||
string host;
|
||||
int port;
|
||||
if (!RE2::FullMatch(host_port, *kHostPortRE, &host, &port)) {
|
||||
return errors::InvalidArgument("Could not interpret \"", host_port,
|
||||
"\" as a host-port pair.");
|
||||
}
|
||||
}
|
||||
job.host_ports = host_ports;
|
||||
job.tasks_per_replica = tasks_per_replica;
|
||||
host_ports_jobs_.push_back(job);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
GrpcChannelCache* NewGrpcChannelCache(const GrpcChannelSpec& spec) {
|
||||
const int num_jobs = spec.host_ports_jobs().size();
|
||||
if (!num_jobs) {
|
||||
LOG(ERROR) << "Empty channel spec.";
|
||||
return nullptr;
|
||||
}
|
||||
std::vector<GrpcChannelCache*> caches;
|
||||
caches.reserve(num_jobs);
|
||||
for (const GrpcChannelSpec::HostPortsJob& job : spec.host_ports_jobs()) {
|
||||
caches.push_back(NewHostPortsGrpcChannelCache(job.job_id, job.host_ports,
|
||||
job.tasks_per_replica));
|
||||
}
|
||||
return caches.size() == 1 ? caches[0] : NewMultiGrpcChannelCache(caches);
|
||||
}
|
||||
|
||||
// GrpcChannelCache that caches results to FindWorkerChannel() calls.
|
||||
class CachingGrpcChannelCache : public GrpcChannelCache {
|
||||
public:
|
||||
CachingGrpcChannelCache() {}
|
||||
|
||||
~CachingGrpcChannelCache() override {}
|
||||
|
||||
SharedGrpcChannelPtr FindWorkerChannel(const string& target) override {
|
||||
SharedGrpcChannelPtr ch = nullptr;
|
||||
{
|
||||
mutex_lock l(mu_); // could use reader lock
|
||||
ch = gtl::FindPtrOrNull(channels_, target);
|
||||
if (ch) {
|
||||
return ch;
|
||||
}
|
||||
}
|
||||
ch = FindChannelOnce(target);
|
||||
if (ch) {
|
||||
mutex_lock l(mu_);
|
||||
channels_.insert({target, ch});
|
||||
}
|
||||
return ch;
|
||||
}
|
||||
|
||||
protected:
|
||||
// Find the ClientChannel for "target". Only called when no channel was
|
||||
// found in the channels_ cache for "target". A non nullptr result will be
|
||||
// cached in channels_.
|
||||
virtual SharedGrpcChannelPtr FindChannelOnce(const string& target) = 0;
|
||||
|
||||
private:
|
||||
// TODO(zhifengc): Eviction when the map becomes too big.
|
||||
mutex mu_;
|
||||
std::unordered_map<string, SharedGrpcChannelPtr> channels_ GUARDED_BY(mu_);
|
||||
};
|
||||
|
||||
// A ChannelCache that is the union of multiple ChannelCaches.
|
||||
// Takes ownership of the caches passed to the constructor.
|
||||
class MultiGrpcChannelCache : public CachingGrpcChannelCache {
|
||||
public:
|
||||
explicit MultiGrpcChannelCache(const std::vector<GrpcChannelCache*>& caches)
|
||||
: CachingGrpcChannelCache(), caches_(caches) {}
|
||||
|
||||
~MultiGrpcChannelCache() override {
|
||||
for (GrpcChannelCache* cache : caches_) {
|
||||
delete cache;
|
||||
}
|
||||
}
|
||||
|
||||
void ListWorkers(std::vector<string>* workers) override {
|
||||
for (GrpcChannelCache* cache : caches_) {
|
||||
cache->ListWorkers(workers);
|
||||
}
|
||||
}
|
||||
|
||||
string TranslateTask(const string& target) override {
|
||||
mutex_lock l(mu_); // could use reader lock
|
||||
GrpcChannelCache* cache = gtl::FindPtrOrNull(target_caches_, target);
|
||||
if (cache == nullptr) {
|
||||
for (GrpcChannelCache* c : caches_) {
|
||||
string r = c->TranslateTask(target);
|
||||
if (!r.empty()) {
|
||||
target_caches_.insert({target, c});
|
||||
cache = c;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
CHECK(cache) << "Could not find GrpcChannelCache holding channel for "
|
||||
<< target;
|
||||
return cache->TranslateTask(target);
|
||||
}
|
||||
|
||||
protected:
|
||||
SharedGrpcChannelPtr FindChannelOnce(const string& target) override {
|
||||
for (GrpcChannelCache* cache : caches_) {
|
||||
SharedGrpcChannelPtr ch(cache->FindWorkerChannel(target));
|
||||
if (ch) {
|
||||
mutex_lock l(mu_);
|
||||
target_caches_.insert({target, cache});
|
||||
return ch;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
private:
|
||||
// List of channels used by this MultiGrpcChannelCache.
|
||||
const std::vector<GrpcChannelCache*> caches_;
|
||||
|
||||
mutex mu_;
|
||||
// Cache of channels keyed by the target they are handling.
|
||||
// The same GrpcChannelCache can appear multiple times in the cache.
|
||||
std::unordered_map<string, GrpcChannelCache*> target_caches_ GUARDED_BY(mu_);
|
||||
};
|
||||
|
||||
GrpcChannelCache* NewMultiGrpcChannelCache(
|
||||
const std::vector<GrpcChannelCache*>& caches) {
|
||||
return new MultiGrpcChannelCache(caches);
|
||||
}
|
||||
|
||||
class HostPortsGrpcChannelCache : public CachingGrpcChannelCache {
|
||||
public:
|
||||
HostPortsGrpcChannelCache(const string& job_id,
|
||||
const std::vector<string>& host_ports,
|
||||
int tasks_per_replica)
|
||||
: job_id_(job_id),
|
||||
host_ports_(BuildDenseHostPortsList(host_ports, tasks_per_replica)),
|
||||
tasks_per_replica_(tasks_per_replica) {
|
||||
LOG(INFO) << "Initialize HostPortsGrpcChannelCache for job " << job_id
|
||||
<< " -> {" << str_util::Join(host_ports, ", ") << "}";
|
||||
}
|
||||
~HostPortsGrpcChannelCache() override {}
|
||||
|
||||
void ListWorkers(std::vector<string>* workers) override {
|
||||
int num_host_ports = 0;
|
||||
for (size_t i = 0; i < host_ports_.size(); ++i) {
|
||||
if (!host_ports_[i].empty()) {
|
||||
++num_host_ports;
|
||||
}
|
||||
}
|
||||
workers->reserve(workers->size() + num_host_ports);
|
||||
for (size_t i = 0; i < host_ports_.size(); ++i) {
|
||||
if (!host_ports_[i].empty()) {
|
||||
workers->emplace_back(MakeAddress(job_id_, i / tasks_per_replica_,
|
||||
i % tasks_per_replica_));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
string TranslateTask(const string& target) override {
|
||||
RegexpStringPiece job;
|
||||
int32 replica;
|
||||
int32 task;
|
||||
if (!RE2::FullMatch(target, *kTargetRE, &job, &replica, &task)) {
|
||||
LOG(WARNING) << "Invalid target: " << target;
|
||||
return "";
|
||||
}
|
||||
if (job != job_id_) {
|
||||
return "";
|
||||
}
|
||||
if (task >= tasks_per_replica_) {
|
||||
LOG(WARNING) << "Task out of bounds for job " << job_id_ << ": " << task;
|
||||
return "";
|
||||
}
|
||||
const size_t i = replica * tasks_per_replica_ + task;
|
||||
if (i >= host_ports_.size()) {
|
||||
LOG(WARNING) << "Replica/task out of bounds for job " << job_id_ << ": "
|
||||
<< target;
|
||||
return "";
|
||||
}
|
||||
if (host_ports_[i].empty()) {
|
||||
LOG(WARNING) << "Replica/task not in sparse index:host:port list for job "
|
||||
<< job_id_ << ": " << target;
|
||||
return "";
|
||||
}
|
||||
return host_ports_[i];
|
||||
}
|
||||
|
||||
protected:
|
||||
static std::vector<string> BuildDenseHostPortsList(
|
||||
const std::vector<string>& host_ports, int tasks_per_replica) {
|
||||
std::map<int, string> sparse_host_ports;
|
||||
for (const string& host_port : host_ports) {
|
||||
int i = -1;
|
||||
string host;
|
||||
int port = -1;
|
||||
if (RE2::FullMatch(host_port, *kSparseHostPortRE, &i, &host, &port)) {
|
||||
CHECK_LE(0, i);
|
||||
CHECK_LE(0, port);
|
||||
CHECK(sparse_host_ports.find(i) == sparse_host_ports.end())
|
||||
<< "Duplicate index " << i << ": {"
|
||||
<< str_util::Join(host_ports, ", ") << "}";
|
||||
sparse_host_ports[i] = strings::StrCat(host, ":", port);
|
||||
} else {
|
||||
CHECK(RE2::FullMatch(host_port, *kHostPortRE, &host, &port))
|
||||
<< host_port
|
||||
<< " does not look like a host:port or an index:host:port";
|
||||
}
|
||||
}
|
||||
|
||||
if (sparse_host_ports.empty()) {
|
||||
// The input is a dense list; return it directly.
|
||||
return host_ports;
|
||||
}
|
||||
|
||||
// The input is a sparse list. Convert it to a dense list.
|
||||
CHECK_EQ(host_ports.size(), sparse_host_ports.size())
|
||||
<< "Mix of host:port and index:host:port: {"
|
||||
<< str_util::Join(host_ports, ", ") << "}";
|
||||
int num_tasks = sparse_host_ports.rbegin()->first + 1;
|
||||
if (num_tasks % tasks_per_replica != 0) {
|
||||
num_tasks = (num_tasks / tasks_per_replica + 1) * tasks_per_replica;
|
||||
}
|
||||
std::vector<string> dense_host_ports;
|
||||
dense_host_ports.resize(num_tasks);
|
||||
for (const auto& p : sparse_host_ports) {
|
||||
dense_host_ports[p.first] = p.second;
|
||||
}
|
||||
return dense_host_ports;
|
||||
}
|
||||
|
||||
SharedGrpcChannelPtr FindChannelOnce(const string& target) override {
|
||||
const string host_port = TranslateTask(target);
|
||||
if (host_port.empty()) {
|
||||
LOG(WARNING) << "Could not find channel for target: " << target;
|
||||
return nullptr;
|
||||
}
|
||||
return NewHostPortGrpcChannel(host_port);
|
||||
}
|
||||
|
||||
private:
|
||||
const string job_id_;
|
||||
const std::vector<string> host_ports_;
|
||||
const int tasks_per_replica_;
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(HostPortsGrpcChannelCache);
|
||||
};
|
||||
|
||||
GrpcChannelCache* NewHostPortsGrpcChannelCache(
|
||||
const string& job_id, const std::vector<string>& host_ports,
|
||||
int tasks_per_replica) {
|
||||
return new HostPortsGrpcChannelCache(job_id, host_ports, tasks_per_replica);
|
||||
}
|
||||
|
||||
} // end namespace tensorflow
|
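To make the sparse-list handling in BuildDenseHostPortsList concrete, here is a worked example using the same inputs as the SparseHostPorts test further down (shown as comments rather than run against the class, since the function is a protected static member).

// Input (sparse "index:host:port" form), tasks_per_replica = 2:
//   {"0:a:1", "3:d:4", "4:e:5"}
// Highest index is 4, so num_tasks = 5, rounded up to a multiple of 2 -> 6.
// Resulting dense list (empty strings are unassigned tasks):
//   {"a:1", "", "", "d:4", "e:5", ""}
// TranslateTask("/job:mnist/replica:1/task:1") -> index 1 * 2 + 1 = 3 -> "d:4"
// ListWorkers reports only the assigned slots: replica 0 task 0,
// replica 1 task 1, and replica 2 task 0, matching the test's expectations.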
98 tensorflow/core/distributed_runtime/rpc/grpc_channel.h (new file)
@@ -0,0 +1,98 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CHANNEL_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CHANNEL_H_
|
||||
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "external/grpc/include/grpc++/grpc++.h"
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Consolidated parameter structure to ease use of generic interfaces.
|
||||
//
|
||||
// Each job_id requires:
|
||||
// - a list of host:port (or sparse list of index:host:port)
|
||||
// - the number of tasks per replica
|
||||
class GrpcChannelSpec {
|
||||
public:
|
||||
struct HostPortsJob {
|
||||
string job_id;
|
||||
std::vector<string> host_ports;
|
||||
int tasks_per_replica;
|
||||
};
|
||||
|
||||
Status AddHostPortsJob(const string& job_id,
|
||||
const std::vector<string>& host_ports,
|
||||
int tasks_per_replica);
|
||||
|
||||
const std::vector<HostPortsJob>& host_ports_jobs() const {
|
||||
return host_ports_jobs_;
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<HostPortsJob> host_ports_jobs_;
|
||||
std::set<string> job_ids_;
|
||||
};
|
||||
|
||||
class GrpcChannelCache {
|
||||
public:
|
||||
virtual ~GrpcChannelCache() {}
|
||||
|
||||
// Populates *workers with names of all workers which this object
|
||||
// was created to handle. Worker names are in the format
|
||||
// /job:<job identifier>/task:<task id>
|
||||
// e.g. /job:mnist/task:2
|
||||
virtual void ListWorkers(std::vector<string>* workers) = 0;
|
||||
|
||||
// If found, returns a gRPC channel that is connected to the remote
|
||||
// worker named by 'target'. 'target' is of the following
|
||||
// format: /job:<job identifier>/task:<task id>
|
||||
// E.g., /job:mnist/task:2
|
||||
virtual SharedGrpcChannelPtr FindWorkerChannel(const string& target) = 0;
|
||||
|
||||
// Translates a string in the form `/job:X/replica:Y/task:Z` into a host:port.
|
||||
virtual string TranslateTask(const string& task) = 0;
|
||||
};
|
||||
|
||||
GrpcChannelCache* NewGrpcChannelCache(const GrpcChannelSpec& p);
|
||||
|
||||
// Below here are internal-only functions.
|
||||
|
||||
SharedGrpcChannelPtr NewHostPortGrpcChannel(const string& target);
|
||||
|
||||
// Returns a ChannelCache that uses a set of known host:port pairs. E.g., say,
|
||||
// job_id = 'mnist', 'host_ports' = {"h0:0", "h1:1", ..., "h11:11", "h12:12"},
|
||||
// tasks_per_replica = 8, /job:mnist/replica:1/task:3 is mapped to host:port
|
||||
// "h11:11" (11 = 8 * 1 + 3).
|
||||
//
|
||||
// The caller takes ownership of the returned object.
|
||||
GrpcChannelCache* NewHostPortsGrpcChannelCache(
|
||||
const string& job_id, const std::vector<string>& host_ports,
|
||||
int tasks_per_replica);
|
||||
|
||||
// Returns a ChannelCache that is the union of a number of other ChannelCaches.
|
||||
GrpcChannelCache* NewMultiGrpcChannelCache(
|
||||
const std::vector<GrpcChannelCache*>& caches);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CHANNEL_H_
|
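A minimal sketch of building a channel cache from a cluster spec. The job name and addresses are made up; the call pattern follows the declarations above and the usage in grpc_channel_test.cc and remote_device_test.cc.

#include <memory>

#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

void ExampleBuildChannelCache() {
  GrpcChannelSpec spec;
  // Two replicas, two tasks per replica: index = replica * 2 + task.
  TF_CHECK_OK(spec.AddHostPortsJob("mnist", {"a:1", "b:2", "c:3", "d:4"},
                                   2 /* tasks_per_replica */));
  std::unique_ptr<GrpcChannelCache> cache(NewGrpcChannelCache(spec));
  // "/job:mnist/replica:1/task:0" maps to index 1 * 2 + 0 = 2 -> "c:3".
  SharedGrpcChannelPtr ch =
      cache->FindWorkerChannel("/job:mnist/replica:1/task:0");
}

}  // namespace tensorflow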
137 tensorflow/core/distributed_runtime/rpc/grpc_channel_test.cc (new file)
@@ -0,0 +1,137 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"

#include <string>
#include <vector>

#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {
#define IsSameAddrSp DeviceNameUtils::IsSameAddressSpace

TEST(GrpcChannelTest, IsSameAddressSpace) {
  // Same.
  EXPECT_TRUE(IsSameAddrSp("/job:mnist/replica:10/task:10/cpu:0",
                           "/job:mnist/replica:10/task:10/cpu:1"));
  EXPECT_TRUE(IsSameAddrSp("/job:mnist/replica:10/task:10/cpu:0",
                           "/job:mnist/replica:10/task:10/gpu:2"));
  EXPECT_TRUE(IsSameAddrSp("/job:mnist/replica:10/task:10",
                           "/job:mnist/replica:10/task:10/gpu:2"));
  EXPECT_TRUE(IsSameAddrSp("/job:mnist/replica:10/task:10/cpu:1",
                           "/job:mnist/replica:10/task:10"));

  // Different.
  EXPECT_FALSE(IsSameAddrSp("/job:mnist/replica:10/task:9/cpu:0",
                            "/job:mnist/replica:10/task:10/cpu:0"));
  EXPECT_FALSE(IsSameAddrSp("/job:mnist/replica:9/task:10/cpu:0",
                            "/job:mnist/replica:10/task:10/cpu:0"));
  EXPECT_FALSE(IsSameAddrSp("/job:MNIST/replica:10/task:10/cpu:0",
                            "/job:mnist/replica:10/task:10/cpu:0"));

  // Invalid names.
  EXPECT_FALSE(IsSameAddrSp("random_invalid_target", "random_invalid_target"));
  EXPECT_FALSE(IsSameAddrSp("/job:/replica:10/task:10/cpu:0",
                            "/job:/replica:10/task:10/cpu:1"));
  EXPECT_FALSE(IsSameAddrSp("/job:mnist/replica:xx/task:10/cpu:0",
                            "/job:mnist/replica:xx/task:10/cpu:1"));
  EXPECT_FALSE(IsSameAddrSp("/job:mnist/replica:10/task:yy/cpu:0",
                            "/job:mnist/replica:10/task:yy/cpu:1"));
}

TEST(GrpcChannelTest, HostPorts) {
  std::unique_ptr<GrpcChannelCache> cc(NewHostPortsGrpcChannelCache(
      "mnist", {"a:1", "b:2", "c:3", "d:4", "e:5", "f:6"}, 2));
  EXPECT_EQ(nullptr, cc->FindWorkerChannel("invalid_target"));
  EXPECT_EQ(nullptr, cc->FindWorkerChannel("/job:other/replica:0/task:0"));
  EXPECT_EQ(nullptr, cc->FindWorkerChannel("/job:mnist/replica:0/task:2"));
  EXPECT_EQ(nullptr, cc->FindWorkerChannel("/job:mnist/replica:3/task:0"));

  {
    // NOTE(mrry): The gRPC channel doesn't expose the target, so we
    // can't compare it for equality.
    auto a_1_1 = cc->FindWorkerChannel("/job:mnist/replica:0/task:0");
    auto a_1_2 = cc->FindWorkerChannel("/job:mnist/replica:0/task:0");

    auto d_4_1 = cc->FindWorkerChannel("/job:mnist/replica:1/task:1");
    auto d_4_2 = cc->FindWorkerChannel("/job:mnist/replica:1/task:1");

    auto e_5_1 = cc->FindWorkerChannel("/job:mnist/replica:2/task:0");
    auto e_5_2 = cc->FindWorkerChannel("/job:mnist/replica:2/task:0");

    EXPECT_EQ(a_1_1.get(), a_1_2.get());
    EXPECT_EQ(d_4_1.get(), d_4_2.get());
    EXPECT_EQ(e_5_1.get(), e_5_2.get());

    EXPECT_NE(a_1_1.get(), d_4_2.get());
    EXPECT_NE(a_1_1.get(), e_5_2.get());
    EXPECT_NE(d_4_1.get(), e_5_2.get());
  }

  std::vector<string> workers;
  cc->ListWorkers(&workers);
  EXPECT_EQ(std::vector<string>({"/job:mnist/replica:0/task:0",
                                 "/job:mnist/replica:0/task:1",
                                 "/job:mnist/replica:1/task:0",
                                 "/job:mnist/replica:1/task:1",
                                 "/job:mnist/replica:2/task:0",
                                 "/job:mnist/replica:2/task:1"}),
            workers);
}

TEST(GrpcChannelTest, SparseHostPorts) {
  std::unique_ptr<GrpcChannelCache> cc(
      NewHostPortsGrpcChannelCache("mnist", {"0:a:1", "3:d:4", "4:e:5"}, 2));
  EXPECT_EQ(nullptr, cc->FindWorkerChannel("invalid_target"));
  EXPECT_EQ(nullptr, cc->FindWorkerChannel("/job:other/replica:0/task:0"));
  EXPECT_EQ(nullptr, cc->FindWorkerChannel("/job:mnist/replica:0/task:1"));
  EXPECT_EQ(nullptr, cc->FindWorkerChannel("/job:mnist/replica:0/task:2"));
  EXPECT_EQ(nullptr, cc->FindWorkerChannel("/job:mnist/replica:1/task:0"));
  EXPECT_EQ(nullptr, cc->FindWorkerChannel("/job:mnist/replica:2/task:1"));
  EXPECT_EQ(nullptr, cc->FindWorkerChannel("/job:mnist/replica:3/task:0"));

  {
    // NOTE(mrry): The gRPC channel doesn't expose the target, so we
    // can't compare it for equality.
    auto a_1_1 = cc->FindWorkerChannel("/job:mnist/replica:0/task:0");
    auto a_1_2 = cc->FindWorkerChannel("/job:mnist/replica:0/task:0");

    auto d_4_1 = cc->FindWorkerChannel("/job:mnist/replica:1/task:1");
    auto d_4_2 = cc->FindWorkerChannel("/job:mnist/replica:1/task:1");

    auto e_5_1 = cc->FindWorkerChannel("/job:mnist/replica:2/task:0");
    auto e_5_2 = cc->FindWorkerChannel("/job:mnist/replica:2/task:0");

    EXPECT_EQ(a_1_1.get(), a_1_2.get());
    EXPECT_EQ(d_4_1.get(), d_4_2.get());
    EXPECT_EQ(e_5_1.get(), e_5_2.get());

    EXPECT_NE(a_1_1.get(), d_4_2.get());
    EXPECT_NE(a_1_1.get(), e_5_2.get());
    EXPECT_NE(d_4_1.get(), e_5_2.get());
  }

  std::vector<string> workers;
  cc->ListWorkers(&workers);
  EXPECT_EQ(std::vector<string>({"/job:mnist/replica:0/task:0",
                                 "/job:mnist/replica:1/task:1",
                                 "/job:mnist/replica:2/task:0"}),
            workers);
}

} // namespace tensorflow
56
tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h
Normal file
@ -0,0 +1,56 @@
/* Copyright 2016 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CLIENT_CQ_TAG_H_
#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CLIENT_CQ_TAG_H_

#include "external/grpc/include/grpc++/grpc++.h"

#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

// Represents a pending asynchronous client call as a tag that can be
// stored in a `grpc::CompletionQueue`.
class GrpcClientCQTag {
 public:
  GrpcClientCQTag(::grpc::ClientContext* context, StatusCallback cb)
      : context_(context), cb_(cb) {}
  ~GrpcClientCQTag() { delete context_; }

  void OnCompleted(bool ok) {
    if (!ok) {
      VLOG(2) << "Call returned with non-ok status: "
              << status_.error_message();
    }
    cb_(FromGrpcStatus(status_));
  }

  ::grpc::ClientContext* context() { return context_; }
  ::grpc::Status* status() { return &status_; }

 private:
  ::grpc::ClientContext* context_;
  ::grpc::Status status_;
  StatusCallback cb_;

  TF_DISALLOW_COPY_AND_ASSIGN(GrpcClientCQTag);
};

} // namespace tensorflow

#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CLIENT_CQ_TAG_H_
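A minimal sketch of how a client-side completion-queue polling loop might drive these tags; the function name and loop structure below are illustrative and are not part of this file (the real polling loop lives in the worker-cache implementation):

// Illustrative only: drain a client CompletionQueue and dispatch to the tags.
void PollClientCQ(::grpc::CompletionQueue* cq) {
  void* tag;
  bool ok;
  while (cq->Next(&tag, &ok)) {
    GrpcClientCQTag* cq_tag = static_cast<GrpcClientCQTag*>(tag);
    cq_tag->OnCompleted(ok);  // invokes the StatusCallback with FromGrpcStatus(status)
    delete cq_tag;            // the tag owns (and deletes) its ClientContext
  }
}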
181
tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc
Normal file
@ -0,0 +1,181 @@
/* Copyright 2016 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// GrpcMasterService implements the RPC service MasterService.
//
// A GrpcMasterService maintains the state of live graph computation
// sessions; each session orchestrates both local and remote devices
// to carry out the graph computation.
//
// A GrpcMasterService knows ahead of time which local devices are
// available as client devices.
//
// A GrpcMasterService discovers remote devices in the background and
// keeps track of statistics of those remote devices.
//
// Each session analyses the graph, places nodes across available
// devices, and ultimately drives the graph computation by initiating
// RunGraph on workers.
#include "tensorflow/core/distributed_runtime/rpc/grpc_master_service.h"

#include "external/grpc/include/grpc++/server_builder.h"

#include "tensorflow/core/distributed_runtime/master.h"
#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_call.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/protobuf/master.pb.h"
#include "tensorflow/core/protobuf/master_service.grpc.pb.h"

namespace tensorflow {

class GrpcMasterService : public AsyncServiceInterface {
 public:
  GrpcMasterService(MasterEnv* env, ::grpc::ServerBuilder* builder)
      : master_impl_(new Master(env, 0.0)) {
    builder->RegisterService(&master_service_);
    cq_ = builder->AddCompletionQueue().release();
  }

  ~GrpcMasterService() {
    delete cq_;
    delete master_impl_;
  }

  // This macro creates a new request for the given RPC method name
  // (e.g., `ENQUEUE_REQUEST(RunStep);`), and enqueues it on
  // `this->cq_`.
  //
  // This macro is invoked one or more times for each RPC method to
  // ensure that there are sufficient completion queue entries to
  // handle incoming requests without blocking.
  //
  // The implementation of the request handler for each RPC method
  // must ensure that it calls ENQUEUE_REQUEST() for that RPC method,
  // to keep accepting new requests.
#define ENQUEUE_REQUEST(method)                                              \
  do {                                                                       \
    Call<GrpcMasterService, grpc::MasterService::AsyncService,               \
         method##Request, method##Response>::                                \
        EnqueueRequest(&master_service_, cq_,                                \
                       &grpc::MasterService::AsyncService::Request##method,  \
                       &GrpcMasterService::method##Handler);                 \
  } while (0)

  void HandleRPCsLoop() {
    ENQUEUE_REQUEST(CreateSession);
    ENQUEUE_REQUEST(ExtendSession);
    for (int i = 0; i < 100; ++i) {
      ENQUEUE_REQUEST(RunStep);
    }
    ENQUEUE_REQUEST(CloseSession);
    ENQUEUE_REQUEST(ListDevices);
    ENQUEUE_REQUEST(Reset);

    void* tag;
    bool ok;
    while (cq_->Next(&tag, &ok)) {
      CHECK(ok);
      UntypedCall<GrpcMasterService>::Tag* callback_tag =
          static_cast<UntypedCall<GrpcMasterService>::Tag*>(tag);
      callback_tag->OnCompleted(this, ok);
      delete callback_tag;
    }
  }

 private:
  Master* master_impl_;                // Owned.
  ::grpc::ServerCompletionQueue* cq_;  // Owned.
  grpc::MasterService::AsyncService master_service_;

  template <class RequestMessage, class ResponseMessage>
  using MasterCall = Call<GrpcMasterService, grpc::MasterService::AsyncService,
                          RequestMessage, ResponseMessage>;

  // RPC handler for creating a session.
  void CreateSessionHandler(
      MasterCall<CreateSessionRequest, CreateSessionResponse>* call) {
    master_impl_->CreateSession(&call->request, &call->response,
                                [call](const Status& status) {
                                  call->SendResponse(ToGrpcStatus(status));
                                });
    ENQUEUE_REQUEST(CreateSession);
  }

  // RPC handler for extending a session.
  void ExtendSessionHandler(
      MasterCall<ExtendSessionRequest, ExtendSessionResponse>* call) {
    master_impl_->ExtendSession(&call->request, &call->response,
                                [call](const Status& status) {
                                  call->SendResponse(ToGrpcStatus(status));
                                });
    ENQUEUE_REQUEST(ExtendSession);
  }

  // RPC handler for running one step in a session.
  void RunStepHandler(MasterCall<RunStepRequest, RunStepResponse>* call) {
    CallOptions* call_opts = new CallOptions;
    call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
    master_impl_->RunStep(call_opts, &call->request, &call->response,
                          [call, call_opts](const Status& status) {
                            call->ClearCancelCallback();
                            delete call_opts;
                            call->SendResponse(ToGrpcStatus(status));
                          });
    ENQUEUE_REQUEST(RunStep);
  }

  // RPC handler for deleting a session.
  void CloseSessionHandler(
      MasterCall<CloseSessionRequest, CloseSessionResponse>* call) {
    master_impl_->CloseSession(&call->request, &call->response,
                               [call](const Status& status) {
                                 call->SendResponse(ToGrpcStatus(status));
                               });
    ENQUEUE_REQUEST(CloseSession);
  }

  // RPC handler for listing devices.
  void ListDevicesHandler(
      MasterCall<ListDevicesRequest, ListDevicesResponse>* call) {
    master_impl_->ListDevices(&call->request, &call->response,
                              [call](const Status& status) {
                                call->SendResponse(ToGrpcStatus(status));
                              });
    ENQUEUE_REQUEST(ListDevices);
  }

  // RPC handler for resetting all sessions.
  void ResetHandler(MasterCall<ResetRequest, ResetResponse>* call) {
    master_impl_->Reset(&call->request, &call->response,
                        [call](const Status& status) {
                          call->SendResponse(ToGrpcStatus(status));
                        });
    ENQUEUE_REQUEST(Reset);
  }
#undef ENQUEUE_REQUEST

  TF_DISALLOW_COPY_AND_ASSIGN(GrpcMasterService);
};

AsyncServiceInterface* NewGrpcMasterService(MasterEnv* env,
                                            ::grpc::ServerBuilder* builder) {
  CHECK(!env->local_devices.empty());
  return new GrpcMasterService(env, builder);
}

} // end namespace tensorflow
@ -0,0 +1,33 @@
/* Copyright 2016 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_MASTER_SERVICE_H_
#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_MASTER_SERVICE_H_

namespace grpc {
class ServerBuilder;
} // namespace grpc

namespace tensorflow {

class AsyncServiceInterface;
class MasterEnv;

AsyncServiceInterface* NewGrpcMasterService(MasterEnv* env,
                                            ::grpc::ServerBuilder* builder);

} // namespace tensorflow

#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_MASTER_SERVICE_H_
@ -0,0 +1,79 @@
/* Copyright 2016 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/rpc/grpc_remote_master.h"

#include "tensorflow/core/distributed_runtime/master_interface.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/master.pb.h"
#include "tensorflow/core/protobuf/master_service.grpc.pb.h"

namespace tensorflow {

// GrpcRemoteMaster is an implementation of the MasterInterface
// that uses gRPC to talk to the Master service.
class GrpcRemoteMaster : public MasterInterface {
 public:
  explicit GrpcRemoteMaster(SharedGrpcChannelPtr client_channel)
      : stub_(grpc::MasterService::NewStub(client_channel)) {}

  ~GrpcRemoteMaster() override {}

  Status CreateSession(const CreateSessionRequest* request,
                       CreateSessionResponse* response) override {
    ::grpc::ClientContext ctx;
    return FromGrpcStatus(stub_->CreateSession(&ctx, *request, response));
  }

  Status ExtendSession(const ExtendSessionRequest* request,
                       ExtendSessionResponse* response) override {
    ::grpc::ClientContext ctx;
    return FromGrpcStatus(stub_->ExtendSession(&ctx, *request, response));
  }

  Status RunStep(const RunStepRequest* request,
                 RunStepResponse* response) override {
    ::grpc::ClientContext ctx;
    return FromGrpcStatus(stub_->RunStep(&ctx, *request, response));
  }

  Status CloseSession(const CloseSessionRequest* request,
                      CloseSessionResponse* response) override {
    ::grpc::ClientContext ctx;
    return FromGrpcStatus(stub_->CloseSession(&ctx, *request, response));
  }

  Status ListDevices(const ListDevicesRequest* request,
                     ListDevicesResponse* response) override {
    ::grpc::ClientContext ctx;
    return FromGrpcStatus(stub_->ListDevices(&ctx, *request, response));
  }

  Status Reset(const ResetRequest* request, ResetResponse* response) override {
    ::grpc::ClientContext ctx;
    return FromGrpcStatus(stub_->Reset(&ctx, *request, response));
  }

 private:
  std::unique_ptr<grpc::MasterService::Stub> stub_;
};

MasterInterface* NewGrpcMaster(SharedGrpcChannelPtr channel) {
  return new GrpcRemoteMaster(channel);
}

} // namespace tensorflow
27
tensorflow/core/distributed_runtime/rpc/grpc_remote_master.h
Normal file
@ -0,0 +1,27 @@
/* Copyright 2016 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_MASTER_H_
#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_MASTER_H_

#include "tensorflow/core/distributed_runtime/master_interface.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"

namespace tensorflow {
// Returns a MasterInterface wrapped around the gRPC channel `channel`.
MasterInterface* NewGrpcMaster(SharedGrpcChannelPtr channel);
} // namespace tensorflow

#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_MASTER_H_
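A hedged sketch of how a client might use NewGrpcMaster directly to talk to a master task; the target address, graph_def variable, and error handling are illustrative assumptions, not part of this header:

// Illustrative only: connect to a master task and create a session.
MasterInterface* master =
    NewGrpcMaster(NewHostPortGrpcChannel("localhost:2222"));  // assumed address
CreateSessionRequest req;
*req.mutable_graph_def() = graph_def;  // a GraphDef built elsewhere
CreateSessionResponse resp;
Status s = master->CreateSession(&req, &resp);
if (s.ok()) {
  LOG(INFO) << "Created session with handle " << resp.session_handle();
}

In practice client programs go through GrpcSession (below), which wraps exactly this interface.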
203
tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc
Normal file
@ -0,0 +1,203 @@
/* Copyright 2016 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h"

#include "external/grpc/include/grpc++/grpc++.h"

#include "tensorflow/core/distributed_runtime/process_util.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h"
#include "tensorflow/core/distributed_runtime/worker_cache_logger.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/protobuf/worker.pb.h"
#include "tensorflow/core/protobuf/worker_service.grpc.pb.h"

namespace tensorflow {

class GrpcRemoteWorker : public WorkerInterface {
 public:
  explicit GrpcRemoteWorker(SharedGrpcChannelPtr channel,
                            ::grpc::CompletionQueue* completion_queue,
                            WorkerCacheLogger* logger)
      : stub_(grpc::WorkerService::NewStub(channel)),
        cq_(completion_queue),
        logger_(logger) {}

  ~GrpcRemoteWorker() override {}

  void GetStatusAsync(const GetStatusRequest* request,
                      GetStatusResponse* response,
                      StatusCallback done) override {
    IssueRequest(request, response, &grpc::WorkerService::Stub::AsyncGetStatus,
                 done);
  }

  void RegisterGraphAsync(const RegisterGraphRequest* request,
                          RegisterGraphResponse* response,
                          StatusCallback done) override {
    IssueRequest(request, response,
                 &grpc::WorkerService::Stub::AsyncRegisterGraph, done);
  }

  void DeregisterGraphAsync(const DeregisterGraphRequest* request,
                            DeregisterGraphResponse* response,
                            StatusCallback done) override {
    IssueRequest(request, response,
                 &grpc::WorkerService::Stub::AsyncDeregisterGraph, done);
  }

  void RunGraphAsync(CallOptions* call_opts, const RunGraphRequest* request,
                     RunGraphResponse* response, StatusCallback done) override {
    IssueRequest(request, response, &grpc::WorkerService::Stub::AsyncRunGraph,
                 done, call_opts);
  }

  void CleanupGraphAsync(const CleanupGraphRequest* request,
                         CleanupGraphResponse* response,
                         StatusCallback done) override {
    IssueRequest(request, response,
                 &grpc::WorkerService::Stub::AsyncCleanupGraph, done);
  }

  void CleanupAllAsync(const CleanupAllRequest* request,
                       CleanupAllResponse* response,
                       StatusCallback done) override {
    IssueRequest(request, response, &grpc::WorkerService::Stub::AsyncCleanupAll,
                 done);
  }

  void RecvTensorAsync(CallOptions* call_opts, const RecvTensorRequest* request,
                       RecvTensorResponse* response,
                       TensorBufAllocator allocator,
                       StatusCallback done) override {
    VLOG(1) << "RecvTensorAsync req: " << request->DebugString();
    int64 start_usec = Env::Default()->NowMicros();
    // Don't propagate dma_ok over gRPC.
    RecvTensorRequest* req_copy = nullptr;
    if (request->dma_ok()) {
      req_copy = new RecvTensorRequest;
      *req_copy = *request;
      req_copy->set_dma_ok(false);
    }
    // Type-specialized logging for this method.
    StatusCallback logging_callback = [this, request, req_copy, response, done,
                                       start_usec](Status s) {
      if (logger_->LoggingActive()) {
        int64 end_usec = Env::Default()->NowMicros();
        int64 step_id = request->step_id();
        int64 bytes = response->tensor().ByteSize();
        int64 send_start_usec = start_usec;
        // If a send start time was reported by the other side, use
        // that instead. Maybe we should mark the display if we're using
        // our local time instead of the remote start time?
        if (response->send_start_micros()) {
          // send_start_micros is the timestamp taken when the remote
          // machine began to send the RecvTensor response.
          // Due to clock skew between source and dest machines, it is
          // possible that send_start_micros can be larger than end_usec or
          // less than start_usec.
          // To respect causality, we enforce the invariants that the RecvTensor
          // response can not have been sent before the RecvTensor request, and
          // must have been sent before it was received.
          send_start_usec = std::max(start_usec, response->send_start_micros());
          send_start_usec = std::min(send_start_usec, end_usec - 1);
        }
        const string& key = request->rendezvous_key();
        std::vector<string> key_parts = str_util::Split(key, ';');
        if (key_parts.size() != 5) {
          LOG(WARNING) << "Bad key: " << key;
        } else {
          logger_->RecordRecvTensor(step_id, send_start_usec, end_usec,
                                    key_parts[3],  // tensor name
                                    key_parts[0],  // src_device
                                    key_parts[2],  // dst_device
                                    bytes);
        }
      }
      VLOG(2) << "done callback, req: " << request->DebugString()
              << " response " << response->DebugString();
      delete req_copy;
      done(s);
    };

    IssueRequest(req_copy ? req_copy : request, response,
                 &grpc::WorkerService::Stub::AsyncRecvTensor, logging_callback,
                 call_opts);
  }

  void LoggingAsync(const LoggingRequest* request, LoggingResponse* response,
                    StatusCallback done) override {
    IssueRequest(request, response, &grpc::WorkerService::Stub::AsyncLogging,
                 done);
  }

  void TracingAsync(const TracingRequest* request, TracingResponse* response,
                    StatusCallback done) override {
    IssueRequest(request, response, &grpc::WorkerService::Stub::AsyncTracing,
                 done);
  }

 private:
  template <class RequestMessage, class ResponseMessage>
  using AsyncMethod =
      std::unique_ptr<::grpc::ClientAsyncResponseReader<ResponseMessage>> (
          grpc::WorkerService::Stub::*)(::grpc::ClientContext*,
                                        const RequestMessage&,
                                        ::grpc::CompletionQueue*);

  // Utility method for issuing a generic asynchronous request. The
  // given callback, `done`, will be called when the RPC completes.
  template <class RequestMessage, class ResponseMessage>
  void IssueRequest(const RequestMessage* request, ResponseMessage* response,
                    AsyncMethod<RequestMessage, ResponseMessage> async_method,
                    StatusCallback done, CallOptions* call_opts = nullptr) {
    ::grpc::ClientContext* context = new ::grpc::ClientContext;
    if (call_opts) {
      call_opts->SetCancelCallback([context]() { context->TryCancel(); });
    }
    auto rpc = (stub_.get()->*async_method)(context, *request, cq_).release();
    GrpcClientCQTag* tag =
        new GrpcClientCQTag(context, [rpc, done, call_opts](Status s) {
          if (call_opts) {
            call_opts->ClearCancelCallback();
          }
          delete rpc;
          done(s);
        });
    rpc->Finish(response, tag->status(), tag);
  }

  std::unique_ptr<grpc::WorkerService::Stub> stub_;
  ::grpc::CompletionQueue* cq_;

  // Support for logging.
  WorkerCacheLogger* logger_;
  bool retry_unavailable_;

  TF_DISALLOW_COPY_AND_ASSIGN(GrpcRemoteWorker);
};

WorkerInterface* NewGrpcRemoteWorker(SharedGrpcChannelPtr channel,
                                     ::grpc::CompletionQueue* completion_queue,
                                     WorkerCacheLogger* logger) {
  return new GrpcRemoteWorker(channel, completion_queue, logger);
}

} // namespace tensorflow
38
tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h
Normal file
@ -0,0 +1,38 @@
/* Copyright 2016 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef THIRD_PARTY_TENSORFLOW_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_WORKER_H_
#define THIRD_PARTY_TENSORFLOW_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_WORKER_H_

#include <memory>

#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"

namespace grpc {
class CompletionQueue;
}

namespace tensorflow {

class WorkerCacheLogger;
class WorkerInterface;

WorkerInterface* NewGrpcRemoteWorker(SharedGrpcChannelPtr channel,
                                     ::grpc::CompletionQueue* completion_queue,
                                     WorkerCacheLogger* logger);

} // namespace tensorflow

#endif // THIRD_PARTY_TENSORFLOW_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_WORKER_H_
116
tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
Normal file
@ -0,0 +1,116 @@
/* Copyright 2016 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"

#include <memory>

#include "external/grpc/include/grpc++/grpc++.h"
#include "external/grpc/include/grpc++/security/credentials.h"
#include "external/grpc/include/grpc++/server_builder.h"

#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/distributed_runtime/graph_mgr.h"
#include "tensorflow/core/distributed_runtime/master_env.h"
#include "tensorflow/core/distributed_runtime/master_session.h"
#include "tensorflow/core/distributed_runtime/process_util.h"
#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_master_service.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/util/command_line_flags.h"

namespace tensorflow {

void StartTensorFlowServer(const GrpcServerOptions& options) {
  // The Thread destructor waits until the launched thread terminates
  // (i.e. forever in this case).
  std::unique_ptr<Thread> launcher_thread(Env::Default()->StartThread(
      ThreadOptions(), "TF_service_launcher", [options]() {
        // Configure the MasterEnv and WorkerEnv, which provide service-global
        // context for the master and worker services, respectively.

        // The master and worker share the same worker cache (for RPC
        // connections to other workers) and devices (so that the master
        // may run some ops locally as a "client" device). The master
        // requires a device to serve as a "client device", so that remote
        // devices can copy the feeds from the master.
        MasterEnv master_env;
        WorkerEnv worker_env;
        master_env.env = Env::Default();
        worker_env.env = Env::Default();

        // Configure shared devices between master and worker.
        string name_prefix =
            strings::StrCat("/job:", options.job_name, "/replica:0", "/task:",
                            options.task_index);
        DeviceFactory::AddDevices(options.default_session_options, name_prefix,
                                  &master_env.local_devices);
        worker_env.device_mgr = new DeviceMgr(master_env.local_devices);
        string unused;
        CHECK(DeviceNameUtils::SplitDeviceName(
            master_env.local_devices[0]->name(), &worker_env.worker_name,
            &unused));

        GrpcChannelCache* channel_cache =
            NewGrpcChannelCache(options.channel_spec);
        int port;
        const std::vector<string> host_port =
            str_util::Split(channel_cache->TranslateTask(name_prefix), ':');
        CHECK(str_util::NumericParse32(host_port[1], &port));

        worker_env.worker_cache = NewGrpcWorkerCache(channel_cache);

        // Finish setting up master environment.
        master_env.ops = OpRegistry::Global();
        master_env.worker_cache = worker_env.worker_cache;
        master_env.master_session_factory = internal::NewMasterSession;

        // Finish setting up worker environment.
        worker_env.graph_mgr = new GraphMgr(&worker_env);
        worker_env.rendezvous_mgr = new RpcRendezvousMgr(&worker_env);
        worker_env.compute_pool = ComputePool(options.default_session_options);

        // Build the gRPC server that will host both the master and the
        // worker services.
        ::grpc::ServerBuilder builder;
        builder.AddListeningPort(strings::StrCat("0.0.0.0:", port),
                                 ::grpc::InsecureServerCredentials());
        auto master_service = NewGrpcMasterService(&master_env, &builder);
        auto worker_service = NewGrpcWorkerService(&worker_env, &builder);
        auto server_ = builder.BuildAndStart();

        // Start threads to handle the incoming RPCs for the master and worker.
        // NOTE(mrry): The Thread destructor waits until the thread terminates
        // (i.e. forever in this case).
        std::unique_ptr<Thread> master_thread(Env::Default()->StartThread(
            ThreadOptions(), "TF_master_service",
            [master_service]() { master_service->HandleRPCsLoop(); }));
        std::unique_ptr<Thread> worker_thread(Env::Default()->StartThread(
            ThreadOptions(), "TF_worker_service",
            [worker_service]() { worker_service->HandleRPCsLoop(); }));
      }));
}

} // end namespace tensorflow
53
tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h
Normal file
@ -0,0 +1,53 @@
/* Copyright 2016 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_
#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_

#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {

// Defines the configuration for a single task (typically a process)
// that is part of a gRPC-based TensorFlow cluster.
struct GrpcServerOptions {
  // The identity of the job to which this task belongs. The names
  // of the devices in this task will be prefixed with
  // "/job:<job_name>/task:<task_index>".
  string job_name;
  int32 task_index = 0;

  // A channel specification, which defines (i) the set of jobs that
  // comprise the cluster, and (ii) within each job, the endpoints
  // exposed by each task. NOTE: This spec also defines the endpoint
  // on which this task will listen.
  GrpcChannelSpec channel_spec;

  // SessionOptions that will be used as defaults when configuring
  // sessions in this task. `default_session_options.target` is
  // ignored.
  SessionOptions default_session_options;
};

// Starts a gRPC-based TensorFlow server with the given options.
// This function will not return.
void StartTensorFlowServer(const GrpcServerOptions& options);

} // namespace tensorflow

#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_
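A minimal sketch of how a task binary might populate these options before calling StartTensorFlowServer. The job name, task index, and config tweak below are assumptions for illustration; the API for filling in channel_spec lives in grpc_channel.h and is not reproduced here:

// Illustrative only: configure and start a single task of an assumed "worker" job.
GrpcServerOptions opts;
opts.job_name = "worker";   // assumed job name
opts.task_index = 0;
// opts.channel_spec must describe every job in the cluster, including the
// endpoint this task itself listens on (see grpc_channel.h for the spec API).
opts.default_session_options.config.set_intra_op_parallelism_threads(0);
StartTensorFlowServer(opts);  // does not return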
233
tensorflow/core/distributed_runtime/rpc/grpc_session.cc
Normal file
@ -0,0 +1,233 @@
/* Copyright 2016 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/rpc/grpc_session.h"

#include <unordered_map>

#include "tensorflow/core/common_runtime/session_factory.h"
#include "tensorflow/core/distributed_runtime/master_interface.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_remote_master.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/protobuf/master.pb.h"

namespace tensorflow {

const size_t kSchemePrefix = sizeof("grpc://") - 1;

GrpcSession::GrpcSession(const SessionOptions& options)
    : options_(options),
      master_(NewGrpcMaster(
          NewHostPortGrpcChannel(options.target.substr(kSchemePrefix)))),
      current_graph_version_(-1) {}

GrpcSession::~GrpcSession() {}

namespace {
// Re-encodes constants represented in the tensor proto into
// tensor_content, which is slightly better (fewer copies and lower peak
// memory usage) when used with RPC subsystems.
void ReEncodeConsts(GraphDef* gdef) {
  for (NodeDef& ndef : *(gdef->mutable_node())) {
    if (ndef.op() == "Const") {
      TensorProto* proto = nullptr;
      for (auto& attr : *ndef.mutable_attr()) {
        if (attr.first == "value") {
          proto = attr.second.mutable_tensor();
        }
      }
      if (proto != nullptr && proto->tensor_content().empty() &&
          proto->ByteSize() > 64) {
        // If the constant is encoded with repeated proto fields and
        // it is moderately large, we re-encode it in tensor_content as
        // a Cord. This is mildly helpful for reducing the peak memory
        // usage on the server side where GraphDef/NodeDef are copied
        // quite often.
        Tensor parsed(proto->dtype());
        if (parsed.FromProto(*proto)) {
          parsed.AsProtoTensorContent(proto);
        }
      }
    }
  }
}
} // namespace

Status GrpcSession::Create(const GraphDef& graph) {
  if (!handle_.empty()) {
    return errors::InvalidArgument("A session is alive.");
  }
  CreateSessionRequest req;
  *req.mutable_config() = options_.config;
  *req.mutable_graph_def() = graph;
  ReEncodeConsts(req.mutable_graph_def());
  CreateSessionResponse resp;
  Status s = master_->CreateSession(&req, &resp);
  if (s.ok()) {
    mutex_lock l(mu_);
    swap(handle_, *(resp.mutable_session_handle()));
    current_graph_version_ = resp.graph_version();
  }
  return s;
}

Status GrpcSession::Extend(const GraphDef& graph) {
  if (handle_.empty()) {
    // Session was uninitialized, so simply initialize the session with 'graph'.
    return Create(graph);
  }
  mutex_lock l(mu_);
  ExtendSessionRequest req;
  req.set_session_handle(handle_);
  *req.mutable_graph_def() = graph;
  req.set_current_graph_version(current_graph_version_);
  ExtendSessionResponse resp;
  Status s = master_->ExtendSession(&req, &resp);
  if (s.ok()) {
    current_graph_version_ = resp.new_graph_version();
  }
  return s;
}

Status GrpcSession::Run(const std::vector<std::pair<string, Tensor>>& inputs,
                        const std::vector<string>& output_names,
                        const std::vector<string>& target_nodes,
                        std::vector<Tensor>* outputs) {
  // Convert to proto
  RunStepRequest req;
  RunStepResponse resp;

  for (const auto& it : inputs) {
    Tensor input_tensor = it.second;
    auto feed = req.add_feed();
    feed->set_name(it.first);
    TensorProto* proto = feed->mutable_tensor();
    input_tensor.AsProtoTensorContent(proto);
  }

  // Build an index from fetch tensor name to offset.
  std::unordered_map<string, int> output_name_to_offset;
  for (const string& output_name : output_names) {
    req.add_fetch(output_name);
    output_name_to_offset.insert(
        std::make_pair(output_name, output_name_to_offset.size()));
  }
  for (const string& target : target_nodes) {
    req.add_target(target);
  }

  TF_RETURN_IF_ERROR(RunProto(&req, &resp));

  if (!output_names.empty()) {
    outputs->resize(output_names.size());
  }

  // Convert response back to Tensors in the correct order.
  for (const NamedTensorProto& tensor : resp.tensor()) {
    auto fetch_it = output_name_to_offset.find(tensor.name());
    if (fetch_it == output_name_to_offset.end()) {
      return errors::Internal("Received response for unrequested fetch: ",
                              tensor.name());
    }

    Tensor output;
    if (!output.FromProto(tensor.tensor())) {
      return errors::InvalidArgument("Could not parse returned proto for ",
                                     tensor.name());
    }

    (*outputs)[fetch_it->second] = output;
  }

  return Status::OK();
}

Status GrpcSession::RunProto(RunStepRequest* req, RunStepResponse* resp) {
  if (handle_.empty()) {
    return errors::InvalidArgument("A session is not created yet....");
  }

  req->set_session_handle(handle_);
  return master_->RunStep(req, resp);
}

Status GrpcSession::PRunSetup(const std::vector<string>& input_names,
                              const std::vector<string>& output_names,
                              const std::vector<string>& target_nodes,
                              string* handle) {
  return errors::Internal("Partial run is not supported for remote session.");
}

Status GrpcSession::PRun(const string& handle,
                         const std::vector<std::pair<string, Tensor>>& inputs,
                         const std::vector<string>& output_names,
                         std::vector<Tensor>* outputs) {
  return errors::Internal("Partial run is not supported for remote session.");
}

Status GrpcSession::Close() {
  if (handle_.empty()) {
    return errors::InvalidArgument("A session is not created yet....");
  }
  CloseSessionRequest req;
  req.set_session_handle(handle_);
  handle_.clear();
  CloseSessionResponse resp;
  return master_->CloseSession(&req, &resp);
}

std::vector<DeviceAttributes> GrpcSession::ListDevices() {
  std::vector<DeviceAttributes> devices;

  ListDevicesRequest req;
  ListDevicesResponse resp;
  Status s = master_->ListDevices(&req, &resp);
  if (!s.ok()) {
    LOG(ERROR) << "Could not list devices: " << s;
    return devices;
  }

  for (const auto& device_attr : resp.local_device()) {
    devices.push_back(device_attr);
  }
  for (const auto& device_attr : resp.remote_device()) {
    devices.push_back(device_attr);
  }

  return devices;
}

class GrpcSessionFactory : public SessionFactory {
 public:
  bool AcceptsOptions(const SessionOptions& options) override {
    return StringPiece(options.target).starts_with("grpc://");
  }

  Session* NewSession(const SessionOptions& options) override {
    return new GrpcSession(options);
  }
};

class GrpcSessionRegistrar {
 public:
  GrpcSessionRegistrar() {
    SessionFactory::Register("GRPC_SESSION", new GrpcSessionFactory());
  }
};
static GrpcSessionRegistrar registrar;

} // namespace tensorflow
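Because GrpcSessionFactory registers itself for targets with the "grpc://" prefix, client code can obtain a GrpcSession through the ordinary NewSession entry point. A minimal, illustrative sketch; the server address, graph_def variable, and fetch name "out:0" are assumptions:

// Illustrative only: the "grpc://" prefix routes NewSession to GrpcSessionFactory.
SessionOptions options;
options.target = "grpc://localhost:2222";  // assumed server address
std::unique_ptr<Session> session(NewSession(options));
TF_CHECK_OK(session->Create(graph_def));   // a GraphDef built elsewhere
std::vector<Tensor> outputs;
TF_CHECK_OK(session->Run({}, {"out:0"}, {}, &outputs));  // assumed fetch name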
97
tensorflow/core/distributed_runtime/rpc/grpc_session.h
Normal file
@ -0,0 +1,97 @@
/* Copyright 2016 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SESSION_H_
#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SESSION_H_

#include <string>
#include <vector>

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/protobuf/master.pb.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {

class MasterInterface;

// A Session instance lets the caller drive a TensorFlow graph
// computation on potentially remote sets of devices. This is a thin
// wrapper around tensorflow::grpc::MasterService.
//
// Multiple threads must synchronize their accesses to a single
// session.
class GrpcSession : public Session {
 public:
  // Do not use; just present for easier swig wrapping.
  explicit GrpcSession(const SessionOptions& options);

  ~GrpcSession() override;

  // Creates a session with the "target". The session carries out
  // the graph computation defined by "graph", and will have version
  // number "initial_version".
  Status Create(const GraphDef& graph) override;

  Status Run(const std::vector<std::pair<string, Tensor> >& inputs,
             const std::vector<string>& output_names,
             const std::vector<string>& target_nodes,
             std::vector<Tensor>* outputs) override;

  Status Extend(const GraphDef& graph) override;
  Status Close() override;

  // NOTE: This API is still experimental and may change.
  ::tensorflow::Status PRunSetup(const std::vector<string>& input_names,
                                 const std::vector<string>& output_names,
                                 const std::vector<string>& target_nodes,
                                 string* handle) override;

  // NOTE: This API is still experimental and may change.
  ::tensorflow::Status PRun(
      const string& handle,
      const std::vector<std::pair<string, Tensor> >& inputs,
      const std::vector<string>& output_names,
      std::vector<Tensor>* outputs) override;

  std::vector<DeviceAttributes> ListDevices();

 private:
  SessionOptions options_;
  std::unique_ptr<MasterInterface> master_;
  mutex mu_;

  // handle_ returned by the master to identify this session.
  string handle_;

  // The current version of the graph.
  int64 current_graph_version_ GUARDED_BY(mu_);

  Status RunProto(RunStepRequest* req, RunStepResponse* resp);

  TF_DISALLOW_COPY_AND_ASSIGN(GrpcSession);
};

} // namespace tensorflow

#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SESSION_H_
750
tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc
Normal file
@ -0,0 +1,750 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_session.h"
|
||||
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_testlib.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/graph/default_device.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/testlib.h"
|
||||
#include "tensorflow/core/lib/core/error_codes.pb.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/init_main.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/protobuf/master.pb.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
#include "tensorflow/core/util/port.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
static SessionOptions Devices(int num_cpus, int num_gpus) {
|
||||
SessionOptions result;
|
||||
(*result.config.mutable_device_count())["CPU"] = num_cpus;
|
||||
(*result.config.mutable_device_count())["GPU"] = num_gpus;
|
||||
return result;
|
||||
}
|
||||
|
||||
void CreateGraphDef(GraphDef* graph_def, string node_names[3]) {
|
||||
Graph graph(OpRegistry::Global());
|
||||
|
||||
Tensor a_tensor(DT_FLOAT, TensorShape({1, 2}));
|
||||
test::FillValues<float>(&a_tensor, {1, 2});
|
||||
Node* a = test::graph::Constant(&graph, a_tensor);
|
||||
node_names[0] = a->name();
|
||||
|
||||
Tensor b_tensor(DT_FLOAT, TensorShape({2, 1}));
|
||||
test::FillValues<float>(&b_tensor, {2, 1});
|
||||
Node* b = test::graph::Constant(&graph, b_tensor);
|
||||
node_names[1] = b->name();
|
||||
|
||||
Node* c = test::graph::Matmul(&graph, a, b, false, false);
|
||||
node_names[2] = c->name();
|
||||
|
||||
test::graph::ToGraphDef(&graph, graph_def);
|
||||
}
|
||||
|
||||
// Asserts that "val" is a single float tensor. The only float is
|
||||
// "expected_val".
|
||||
static void IsSingleFloatValue(const Tensor& val, float expected_val) {
|
||||
ASSERT_EQ(val.dtype(), DT_FLOAT);
|
||||
ASSERT_EQ(val.NumElements(), 1);
|
||||
ASSERT_EQ(val.flat<float>()(0), expected_val);
|
||||
}
|
||||
|
||||
static SessionOptions Options(const string& target, int placement_period) {
|
||||
SessionOptions options;
|
||||
// NOTE(mrry): GrpcSession requires a grpc:// scheme prefix in the target
|
||||
// string.
|
||||
options.target = strings::StrCat("grpc://", target);
|
||||
options.config.set_placement_period(placement_period);
|
||||
return options;
|
||||
}
|
||||
|
||||
static Session* NewRemote(const SessionOptions& options) {
|
||||
return CHECK_NOTNULL(NewSession(options));
|
||||
}
|
||||
|
||||
TEST(GrpcSessionTest, BasicNonProtoAPI) {
|
||||
GraphDef graph;
|
||||
string node_names[3];
|
||||
// c = a * b
|
||||
CreateGraphDef(&graph, node_names);
|
||||
|
||||
std::unique_ptr<test::TestCluster> cluster;
|
||||
TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));
|
||||
|
||||
std::unique_ptr<Session> session(
|
||||
NewRemote(Options(cluster->targets()[0], 1)));
|
||||
ASSERT_TRUE(session != nullptr);
|
||||
|
||||
for (int iters = 0; iters < 25; ++iters) {
|
||||
TF_CHECK_OK(session->Create(graph));
|
||||
{
|
||||
std::vector<std::pair<string, Tensor>> inputs;
|
||||
TF_CHECK_OK(session->Run(inputs, {}, {}, {}));
|
||||
}
|
||||
{
|
||||
// Just run to target node
|
||||
std::vector<std::pair<string, Tensor>> inputs;
|
||||
std::vector<string> targets = {node_names[2]};
|
||||
TF_CHECK_OK(session->Run(inputs, {}, targets, nullptr));
|
||||
}
|
||||
{
|
||||
// Run to a target node and a real tensor
|
||||
std::vector<std::pair<string, Tensor>> inputs;
|
||||
std::vector<string> names = {node_names[2] + ":0"};
|
||||
std::vector<string> targets = {node_names[1]};
|
||||
std::vector<Tensor> outputs;
|
||||
TF_CHECK_OK(session->Run(inputs, names, targets, &outputs));
|
||||
ASSERT_TRUE(outputs[0].IsInitialized());
|
||||
ASSERT_EQ(4.0, outputs[0].flat<float>()(0));
|
||||
}
|
||||
|
||||
TF_CHECK_OK(session->Close());
|
||||
}
|
||||
}
|
||||
|
||||
TEST(GrpcSessionTest, BasicNonProtoAPIConsistentOrder) {
|
||||
GraphDef graph;
|
||||
string node_names[3];
|
||||
// c = a * b
|
||||
CreateGraphDef(&graph, node_names);
|
||||
|
||||
std::unique_ptr<test::TestCluster> cluster;
|
||||
TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));
|
||||
|
||||
std::unique_ptr<Session> session(
|
||||
NewRemote(Options(cluster->targets()[0], 1)));
|
||||
ASSERT_TRUE(session != nullptr);
|
||||
ASSERT_TRUE(session->Create(graph).ok());
|
||||
|
||||
// Test that the order of the output names matches the order of the
|
||||
// returned Tensors.
|
||||
std::vector<std::pair<string, Tensor>> inputs;
|
||||
std::vector<string> names = {node_names[2] + ":0", node_names[0] + ":0",
|
||||
node_names[1] + ":0"};
|
||||
|
||||
std::vector<string> target_ops = {node_names[1]};
|
||||
std::vector<Tensor> outputs;
|
||||
ASSERT_TRUE(session->Run(inputs, names, target_ops, &outputs).ok());
|
||||
ASSERT_TRUE(outputs[0].IsInitialized());
|
||||
ASSERT_EQ(4.0, outputs[0].flat<float>()(0));
|
||||
ASSERT_TRUE(outputs[1].IsInitialized());
|
||||
ASSERT_EQ(1.0, outputs[1].flat<float>()(0));
|
||||
ASSERT_TRUE(outputs[2].IsInitialized());
|
||||
ASSERT_EQ(2.0, outputs[2].flat<float>()(0));
|
||||
ASSERT_TRUE(session->Close().ok());
|
||||
}
|
||||
|
||||
TEST(GrpcSessionTest, NonLocalWithFilters) {
|
||||
GraphDef graph;
|
||||
string node_names[3];
|
||||
// c = a * b
|
||||
CreateGraphDef(&graph, node_names);
|
||||
|
||||
std::unique_ptr<test::TestCluster> cluster;
|
||||
TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));
|
||||
|
||||
SessionOptions options;
|
||||
options.target = strings::StrCat("grpc://", cluster->targets()[0]);
|
||||
options.config.add_device_filters(cluster->devices()[0].name());
|
||||
|
||||
std::unique_ptr<Session> session(NewRemote(options));
|
||||
ASSERT_TRUE(session != nullptr);
|
||||
|
||||
{
|
||||
GraphDef graph_copy(graph);
|
||||
graph::SetDefaultDevice(cluster->devices()[0].name(), &graph_copy);
|
||||
TF_CHECK_OK(session->Create(graph_copy));
|
||||
TF_CHECK_OK(session->Run({}, {}, {}, nullptr));
|
||||
TF_CHECK_OK(session->Close());
|
||||
}
|
||||
{
|
||||
GraphDef graph_copy(graph);
|
||||
graph::SetDefaultDevice(cluster->devices()[1].name(), &graph_copy);
|
||||
TF_CHECK_OK(session->Create(graph_copy));
|
||||
auto status = session->Run({}, {}, {}, nullptr);
|
||||
EXPECT_EQ(tensorflow::error::INVALID_ARGUMENT, status.code());
|
||||
TF_CHECK_OK(session->Close());
|
||||
}
|
||||
}
|
||||
|
||||
// A = [3 2; -1 0]; x = rand(2, 1). We want to compute the largest
// eigenvalue of A, which is 2.0 (the eigenvalues of A are 1 and 2).
// We use power iteration:
//   repeat x = y / y.norm(); y = A * x; end
// so that x converges to the dominant eigenvector and y ~= lambda * x.
// At the end, we expect "lambda" (computed below as y(0) / x(0)) to
// converge to 2.0.
|
||||
void FindMaxEigen(const string& target) {
|
||||
Graph graph(OpRegistry::Global());
|
||||
|
||||
Tensor a_tensor(DT_FLOAT, TensorShape({2, 2}));
|
||||
// Store rows [3, 2] and [-1, 0] in row major format.
|
||||
test::FillValues<float>(&a_tensor, {3, 2, -1, 0});
|
||||
Node* a = test::graph::Constant(&graph, a_tensor);
|
||||
|
||||
// x is from the feed.
|
||||
Tensor x_tensor(DT_FLOAT, TensorShape({2, 1}));
|
||||
test::FillValues<float>(&x_tensor, {0, 0});
|
||||
Node* x = test::graph::Constant(&graph, x_tensor);
|
||||
|
||||
// y = A * x
|
||||
Node* y = test::graph::Matmul(&graph, a, x, false, false);
|
||||
|
||||
// y2 = y.^2
|
||||
Node* y2 = test::graph::Unary(&graph, "Square", y);
|
||||
|
||||
// const tensor for reduction
|
||||
Tensor rdim_tensor(DT_INT32, TensorShape({}));
|
||||
rdim_tensor.scalar<int32>()() = 0;
|
||||
Node* rdim = test::graph::Constant(&graph, rdim_tensor);
|
||||
|
||||
// y2_sum = sum(y2)
|
||||
Node* y2_sum = test::graph::Reduce(&graph, "Sum", y2, rdim);
|
||||
|
||||
// y_norm = sqrt(y2_sum)
|
||||
Node* y_norm = test::graph::Unary(&graph, "Sqrt", y2_sum);
|
||||
|
||||
// y_normalized = y ./ y_norm
|
||||
Node* y_normalized = test::graph::Binary(&graph, "Div", y, y_norm);
|
||||
|
||||
GraphDef def;
|
||||
test::graph::ToGraphDef(&graph, &def);
|
||||
|
||||
std::unique_ptr<Session> session(NewRemote(Options(target, 1)));
|
||||
ASSERT_TRUE(session != nullptr);
|
||||
TF_CHECK_OK(session->Create(def));
|
||||
|
||||
// Setup feeds and fetches.
|
||||
float lambda;
|
||||
Tensor feed_value(DT_FLOAT, TensorShape({2, 1}));
|
||||
feed_value.matrix<float>()(0, 0) = -3.1415;
|
||||
feed_value.matrix<float>()(1, 0) = +2.7183;
|
||||
|
||||
for (int i = 0; i < 25; ++i) {
|
||||
std::vector<Tensor> outputs;
|
||||
TF_CHECK_OK(session->Run({{x->name(), feed_value}},
|
||||
{y->name(), y_normalized->name()}, {}, &outputs));
|
||||
const Tensor& y = outputs[0];
|
||||
const Tensor& y_normalized = outputs[1];
|
||||
// Print out lambda, x, and y.
|
||||
CHECK_EQ(2, feed_value.NumElements());
|
||||
CHECK_EQ(2, y.NumElements());
|
||||
lambda = y.flat<float>()(0) / feed_value.flat<float>()(0);
|
||||
printf("%06d lambda = %8.6f x = [%8.6f %8.6f] y = [%8.6f %8.6f]\n", i,
|
||||
lambda, feed_value.flat<float>()(0), feed_value.flat<float>()(1),
|
||||
y.flat<float>()(0), y.flat<float>()(1));
|
||||
// Use y_normalized as the next value fed for x.
|
||||
feed_value = y_normalized;
|
||||
}
|
||||
EXPECT_NEAR(2.0, lambda, 1e-6);
|
||||
}
|
||||
|
||||
TEST(FindMaxEigenTest, RemoteDevice) {
|
||||
std::unique_ptr<test::TestCluster> cluster;
|
||||
TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));
|
||||
FindMaxEigen(cluster->targets()[0]);
|
||||
}
|
||||
|
||||
void SetDevice(GraphDef* graph, const string& name, const string& dev) {
|
||||
for (int i = 0; i < graph->node_size(); ++i) {
|
||||
if (graph->node(i).name() == name) {
|
||||
graph->mutable_node(i)->set_device(dev);
|
||||
return;
|
||||
}
|
||||
}
|
||||
LOG(FATAL) << "Name '" << name << "' not found.";
|
||||
}
|
||||
|
||||
TEST(GrpcSessionTest, MultiDevices) {
|
||||
std::unique_ptr<test::TestCluster> cluster;
|
||||
TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));
|
||||
|
||||
Graph graph(OpRegistry::Global());
|
||||
const int kSize = 1048576;
|
||||
|
||||
// c = a * b = 2 * 3 * kSize
|
||||
Tensor a_tensor(DT_FLOAT, TensorShape({1, kSize}));
|
||||
Tensor b_tensor(DT_FLOAT, TensorShape({kSize, 1}));
|
||||
for (int i = 0; i < kSize; ++i) {
|
||||
a_tensor.flat<float>()(i) = 2;
|
||||
b_tensor.flat<float>()(i) = 3;
|
||||
}
|
||||
Node* a = test::graph::Constant(&graph, a_tensor);
|
||||
Node* b = test::graph::Constant(&graph, b_tensor);
|
||||
Node* c = test::graph::Matmul(&graph, a, b, false, false);
|
||||
|
||||
GraphDef def;
|
||||
test::graph::ToGraphDef(&graph, &def);
|
||||
|
||||
// In this test, we force each node (a, b, c) on every possible device.
|
||||
// We test all possible cases.
|
||||
for (const auto& a_dev : cluster->devices()) {
|
||||
for (const auto& b_dev : cluster->devices()) {
|
||||
for (const auto& c_dev : cluster->devices()) {
|
||||
LOG(INFO) << "a: " << a_dev.name() << " b: " << b_dev.name()
|
||||
<< " c: " << c_dev.name();
|
||||
|
||||
SetDevice(&def, a->name(), a_dev.name());
|
||||
SetDevice(&def, b->name(), b_dev.name());
|
||||
SetDevice(&def, c->name(), c_dev.name());
|
||||
|
||||
std::unique_ptr<Session> session(
|
||||
NewRemote(Options(cluster->targets()[0], 1000)));
|
||||
ASSERT_TRUE(session != nullptr);
|
||||
TF_CHECK_OK(session->Create(def));
|
||||
{
|
||||
std::vector<Tensor> outputs;
|
||||
TF_CHECK_OK(session->Run({}, {c->name()}, {}, &outputs));
|
||||
ASSERT_EQ(1, outputs.size());
|
||||
IsSingleFloatValue(outputs[0], 6.0 * kSize);
|
||||
}
|
||||
TF_CHECK_OK(session->Close());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(GrpcSessionTest, MultiDevices_String) {
|
||||
std::unique_ptr<test::TestCluster> cluster;
|
||||
TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 1), 2, &cluster));
|
||||
std::unique_ptr<Session> session(
|
||||
NewRemote(Options(cluster->targets()[0], 1000)));
|
||||
ASSERT_TRUE(session != nullptr);
|
||||
|
||||
// b = a
|
||||
Graph graph(OpRegistry::Global());
|
||||
Tensor a_tensor(DT_STRING, TensorShape({2, 2}));
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
a_tensor.flat<string>()(i) = "hello, world";
|
||||
}
|
||||
Node* a = test::graph::Constant(&graph, a_tensor);
|
||||
Node* b = test::graph::Identity(&graph, a);
|
||||
|
||||
GraphDef def;
|
||||
test::graph::ToGraphDef(&graph, &def);
|
||||
|
||||
// In this test, we force each node (a, b) on every possible device.
|
||||
// We test all possible cases.
|
||||
for (const auto& a_dev : cluster->devices()) {
|
||||
for (const auto& b_dev : cluster->devices()) {
|
||||
LOG(INFO) << "a: " << a_dev.name() << " b: " << b_dev.name();
|
||||
SetDevice(&def, a->name(), a_dev.name());
|
||||
SetDevice(&def, b->name(), b_dev.name());
|
||||
|
||||
TF_CHECK_OK(session->Create(def));
|
||||
{
|
||||
std::vector<Tensor> outputs;
|
||||
Status s = session->Run({}, {b->name()}, {}, &outputs);
|
||||
if (s.ok()) {
|
||||
ASSERT_EQ(1, outputs.size());
|
||||
ASSERT_EQ(outputs[0].dtype(), DT_STRING);
|
||||
ASSERT_EQ(outputs[0].NumElements(), 4);
|
||||
for (int i = 0; i < outputs[0].NumElements(); ++i) {
|
||||
EXPECT_EQ(outputs[0].flat<string>()(i), "hello, world");
|
||||
}
|
||||
} else {
|
||||
LOG(ERROR) << "Error: " << s;
|
||||
ASSERT_TRUE((a_dev.device_type() == DEVICE_GPU) ||
|
||||
(b_dev.device_type() == DEVICE_GPU));
|
||||
ASSERT_FALSE(s.ok());
|
||||
}
|
||||
}
|
||||
TF_CHECK_OK(session->Close());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(GrpcSessionTest, SendRecv_Node_Naming) {
|
||||
std::unique_ptr<test::TestCluster> cluster;
|
||||
TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 3, &cluster));
|
||||
std::unique_ptr<Session> session(
|
||||
NewRemote(Options(cluster->targets()[0], 1)));
|
||||
ASSERT_TRUE(session != nullptr);
|
||||
|
||||
// This test case needs at least 3 devices.
|
||||
CHECK_GE(cluster->devices().size(), 3);
|
||||
const DeviceAttributes& src = cluster->devices()[0];
|
||||
const DeviceAttributes& dst0 = cluster->devices()[1];
|
||||
const DeviceAttributes& dst1 = cluster->devices()[2];
|
||||
LOG(INFO) << "src = " << src.name() << " dst0 = " << dst0.name()
|
||||
<< " dst1 = " << dst1.name();
|
||||
|
||||
// Within the same session, we compute two subgraphs:
|
||||
// 1) a on 'src' sends to b on 'dst0';
|
||||
// 2) a on 'src' sends to c on 'dst1'.
|
||||
Graph graph(OpRegistry::Global());
|
||||
Tensor a_tensor(DT_FLOAT, TensorShape({1, 1}));
|
||||
a_tensor.flat<float>()(0) = 100;
|
||||
Node* a = test::graph::Constant(&graph, a_tensor);
|
||||
Node* b = test::graph::Identity(&graph, a);
|
||||
Node* c = test::graph::Identity(&graph, a);
|
||||
|
||||
GraphDef def;
|
||||
test::graph::ToGraphDef(&graph, &def);
|
||||
|
||||
// The base graph has a, b, and c assigned to devices explicitly.
|
||||
SetDevice(&def, a->name(), src.name());
|
||||
SetDevice(&def, b->name(), dst0.name());
|
||||
SetDevice(&def, c->name(), dst1.name());
|
||||
TF_CHECK_OK(session->Create(def));
|
||||
|
||||
// Run subgraph a -> b, and fetch b.
|
||||
{
|
||||
std::vector<Tensor> outputs;
|
||||
TF_CHECK_OK(session->Run({}, {b->name()}, {}, &outputs));
|
||||
ASSERT_EQ(1, outputs.size());
|
||||
IsSingleFloatValue(outputs[0], 100);
|
||||
}
|
||||
|
||||
// Run subgraph a -> c, and fetch c.
|
||||
{
|
||||
std::vector<Tensor> outputs;
|
||||
TF_CHECK_OK(session->Run({}, {c->name()}, {}, &outputs));
|
||||
ASSERT_EQ(1, outputs.size());
|
||||
IsSingleFloatValue(outputs[0], 100);
|
||||
}
|
||||
|
||||
TF_CHECK_OK(session->Close());
|
||||
}
|
||||
|
||||
TEST(GrpcSessionTest, Error) {
|
||||
std::unique_ptr<test::TestCluster> cluster;
|
||||
TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));
|
||||
const string& master = cluster->targets()[0];
|
||||
const string& dev_a = cluster->devices()[0].name();
|
||||
const string& dev_b = cluster->devices()[1].name();
|
||||
LOG(INFO) << "master " << master << "dev_a " << dev_a << "dev_b " << dev_b;
|
||||
GraphDef gdef;
|
||||
std::vector<string> fetches;
|
||||
{
|
||||
Graph g(OpRegistry::Global());
|
||||
|
||||
// a2 = a + error(a)
|
||||
//
|
||||
// Subgraph for "a" fails. The master will cancel the subgraph for
|
||||
// "b" and then returns the Session::Run.
|
||||
auto a = test::graph::Constant(&g, Tensor());
|
||||
a->set_assigned_device_name(dev_a);
|
||||
auto a_err = test::graph::Error(&g, a, "fantasia!");
|
||||
a_err->set_assigned_device_name(dev_a);
|
||||
auto a2 = test::graph::Add(&g, a, a_err);
|
||||
a2->set_assigned_device_name(dev_a);
|
||||
fetches.push_back(a2->name());
|
||||
|
||||
// b2 = b + delay(b)
|
||||
//
|
||||
// Subgraph for "b" sleeps at the node "b_delay". When the sleep
|
||||
// finishes, the subgraph "b" will continue execution till it
|
||||
// notices that it is cancelled. Meanwhile, subgraph's executor
|
||||
// and its related state (registered ops) should still be alive.
|
||||
auto b = test::graph::Constant(&g, Tensor());
|
||||
b->set_assigned_device_name(dev_b);
|
||||
auto b_delay = test::graph::Delay(&g, b, Microseconds(1000000));
|
||||
b_delay->set_assigned_device_name(dev_b);
|
||||
auto b2 = test::graph::Add(&g, b, b_delay);
|
||||
b2->set_assigned_device_name(dev_b);
|
||||
fetches.push_back(b2->name());
|
||||
test::graph::ToGraphDef(&g, &gdef);
|
||||
}
|
||||
std::unique_ptr<Session> session(NewRemote(Options(master, 1)));
|
||||
ASSERT_TRUE(session != nullptr);
|
||||
|
||||
TF_CHECK_OK(session->Create(gdef));
|
||||
{
|
||||
Status status = session->Run({}, fetches, {}, nullptr);
|
||||
EXPECT_FALSE(status.ok());
|
||||
EXPECT_NE(status.ToString().find("fantasia!"), string::npos);
|
||||
}
|
||||
// session->Close() shall clean up all state related to the session.
// E.g., it deregisters subgraphs with the workers, etc.
|
||||
TF_CHECK_OK(session->Close());
|
||||
|
||||
// Sleep a bit so that most of asynchronous works finishes before
|
||||
// the test process finishes.
|
||||
Env::Default()->SleepForMicroseconds(2000000);
|
||||
}
|
||||
|
||||
TEST(SessionTest, SharedVar) {
|
||||
std::unique_ptr<test::TestCluster> cluster;
|
||||
TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 1, &cluster));
|
||||
const string master = cluster->targets()[0];
|
||||
CHECK_EQ(cluster->devices().size(), 1);
|
||||
|
||||
GraphDef gdef;
|
||||
string init_name;
|
||||
string inc_name;
|
||||
string get_name;
|
||||
{
|
||||
Graph g(OpRegistry::Global());
|
||||
Tensor one(DT_FLOAT, TensorShape({}));
|
||||
one.scalar<float>()() = 1.0;
|
||||
Node* var = test::graph::Var(&g, DT_FLOAT, one.shape());
|
||||
Node* init = test::graph::Assign(&g, var, test::graph::Constant(&g, one));
|
||||
init_name = init->name();
|
||||
Node* update = test::graph::Assign(
|
||||
&g, var, test::graph::Add(&g, var, test::graph::Constant(&g, one)));
|
||||
inc_name = update->name();
|
||||
get_name = var->name();
|
||||
test::graph::ToGraphDef(&g, &gdef);
|
||||
}
|
||||
|
||||
// Init a variable
|
||||
{
|
||||
Session* sess = NewRemote(Options(master, 1));
|
||||
TF_CHECK_OK(sess->Create(gdef));
|
||||
std::vector<std::pair<string, Tensor>> inp;
|
||||
TF_CHECK_OK(sess->Run(inp, {}, {init_name}, nullptr));
|
||||
TF_CHECK_OK(sess->Close());
|
||||
delete sess;
|
||||
}
|
||||
|
||||
for (int rep = 1; rep < 10; ++rep) {
|
||||
// Update a variable
|
||||
{
|
||||
Session* sess = NewRemote(Options(master, 1));
|
||||
TF_CHECK_OK(sess->Create(gdef));
|
||||
std::vector<std::pair<string, Tensor>> inp;
|
||||
TF_CHECK_OK(sess->Run(inp, {}, {inc_name}, nullptr));
|
||||
TF_CHECK_OK(sess->Close());
|
||||
delete sess;
|
||||
}
|
||||
|
||||
// Gets the variable's value.
|
||||
{
|
||||
Session* sess = NewRemote(Options(master, 1));
|
||||
TF_CHECK_OK(sess->Create(gdef));
|
||||
std::vector<std::pair<string, Tensor>> inp;
|
||||
std::vector<Tensor> ret;
|
||||
TF_CHECK_OK(sess->Run(inp, {get_name}, {}, &ret));
|
||||
ASSERT_EQ(ret.size(), 1);
|
||||
EXPECT_EQ(ret[0].scalar<float>()(), 1.0 * (1 + rep));
|
||||
TF_CHECK_OK(sess->Close());
|
||||
delete sess;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void CreateInvalidGraph(const string& graph_def_ascii,
|
||||
const string& error_substring) {
|
||||
GraphDef graph;
|
||||
CHECK(protobuf::TextFormat::ParseFromString(graph_def_ascii, &graph));
|
||||
|
||||
std::unique_ptr<test::TestCluster> cluster;
|
||||
TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));
|
||||
|
||||
std::unique_ptr<Session> session(
|
||||
NewRemote(Options(cluster->targets()[0], 1)));
|
||||
Status s = session->Create(graph);
|
||||
|
||||
ASSERT_FALSE(s.ok());
|
||||
EXPECT_NE(s.error_message().find(error_substring), string::npos);
|
||||
}
|
||||
|
||||
TEST(SessionTest, InvalidOpName) {
|
||||
CreateInvalidGraph(R"(
|
||||
node {
|
||||
name: 'a:b' op: 'Const'
|
||||
attr { key: 'dtype' value { type: DT_FLOAT } }
|
||||
attr { key: 'value' value {
|
||||
tensor { dtype: DT_FLOAT tensor_shape { dim [{size:1}, {size:1}] }
|
||||
float_val: [100] }
|
||||
} }
|
||||
}
|
||||
)",
|
||||
"Illegal op name");
|
||||
|
||||
CreateInvalidGraph(R"(
|
||||
node {
|
||||
name: 'a:0' op: 'Const'
|
||||
attr { key: 'dtype' value { type: DT_FLOAT } }
|
||||
attr { key: 'value' value {
|
||||
tensor { dtype: DT_FLOAT tensor_shape { dim [{size:1}, {size:1}] }
|
||||
float_val: [100] }
|
||||
} }
|
||||
}
|
||||
)",
|
||||
"Illegal op name");
|
||||
|
||||
CreateInvalidGraph(R"(
|
||||
node {
|
||||
name: '_a' op: 'Const'
|
||||
attr { key: 'dtype' value { type: DT_FLOAT } }
|
||||
attr { key: 'value' value {
|
||||
tensor { dtype: DT_FLOAT tensor_shape { dim [{size:1}, {size:1}] }
|
||||
float_val: [100] }
|
||||
} }
|
||||
}
|
||||
)",
|
||||
"Illegal op name");
|
||||
}
|
||||
|
||||
TEST(SessionTest, InvalidOpInputName) {
|
||||
CreateInvalidGraph(R"(
|
||||
node {
|
||||
name: 'a' op: 'const'
|
||||
attr { key: 'dtype' value { type: DT_FLOAT } }
|
||||
attr { key: 'value' value {
|
||||
tensor { dtype: DT_FLOAT tensor_shape { dim [{size:1}, {size:1}] }
|
||||
float_val: [100] }
|
||||
} }
|
||||
}
|
||||
node {
|
||||
name:'b' op:'MatMul' input:'a:first' input:'a'
|
||||
attr { key: 'T' value { type: DT_FLOAT } }
|
||||
attr { key: 'transpose_a' value { b: false } }
|
||||
attr { key: 'transpose_b' value { b: false } }
|
||||
attr { key: '_kernel' value { s: 'eigen' } }
|
||||
}
|
||||
)",
|
||||
"Illegal op input name");
|
||||
|
||||
CreateInvalidGraph(R"(
|
||||
node {
|
||||
name: 'a' op: 'const'
|
||||
attr { key: 'dtype' value { type: DT_FLOAT } }
|
||||
attr { key: 'value' value {
|
||||
tensor { dtype: DT_FLOAT tensor_shape { dim [{size:1}, {size:1}] }
|
||||
float_val: [100] }
|
||||
} }
|
||||
}
|
||||
node {
|
||||
name:'b' op:'MatMul' input:'_a' input:'a'
|
||||
attr { key: 'T' value { type: DT_FLOAT } }
|
||||
attr { key: 'transpose_a' value { b: false } }
|
||||
attr { key: 'transpose_b' value { b: false } }
|
||||
attr { key: '_kernel' value { s: 'eigen' } }
|
||||
}
|
||||
)",
|
||||
"Illegal op input name");
|
||||
|
||||
CreateInvalidGraph(R"(
|
||||
node {
|
||||
name: 'a' op: 'const'
|
||||
attr { key: 'dtype' value { type: DT_FLOAT } }
|
||||
attr { key: 'value' value {
|
||||
tensor { dtype: DT_FLOAT tensor_shape { dim [{size:1}, {size:1}] }
|
||||
float_val: [100] }
|
||||
} }
|
||||
}
|
||||
node {
|
||||
name:'b' op:'MatMul' input:'_a:0' input:'a'
|
||||
attr { key: 'T' value { type: DT_FLOAT } }
|
||||
attr { key: 'transpose_a' value { b: false } }
|
||||
attr { key: 'transpose_b' value { b: false } }
|
||||
attr { key: '_kernel' value { s: 'eigen' } }
|
||||
}
|
||||
)",
|
||||
"Illegal op input name");
|
||||
|
||||
CreateInvalidGraph(R"(
|
||||
node {
|
||||
name: 'a' op: 'const'
|
||||
attr { key: 'dtype' value { type: DT_FLOAT } }
|
||||
attr { key: 'value' value {
|
||||
tensor { dtype: DT_FLOAT tensor_shape { dim [{size:1}, {size:1}] }
|
||||
float_val: [100] }
|
||||
} }
|
||||
}
|
||||
node {
|
||||
name:'b' op:'MatMul' input:'a:01' input:'a'
|
||||
attr { key: 'T' value { type: DT_FLOAT } }
|
||||
attr { key: 'transpose_a' value { b: false } }
|
||||
attr { key: 'transpose_b' value { b: false } }
|
||||
attr { key: '_kernel' value { s: 'eigen' } }
|
||||
}
|
||||
)",
|
||||
"Illegal op input name");
|
||||
}
|
||||
|
||||
TEST(SessionTest, ExtendValidation) {
|
||||
GraphDef graph;
|
||||
bool success = protobuf::TextFormat::ParseFromString(R"(
|
||||
node {
|
||||
name: 'a' op: 'Const'
|
||||
attr { key: 'dtype' value { type: DT_FLOAT } }
|
||||
attr { key: 'value' value {
|
||||
tensor { dtype: DT_FLOAT tensor_shape { dim [{size:1}, {size:1}] }
|
||||
float_val: [100] }
|
||||
} }
|
||||
}
|
||||
)",
|
||||
&graph);
|
||||
// NOTE(mrry): CHECK not done inline to avoid a compilation error in
|
||||
// open-source (due to a multi-line string in a macro argument).
|
||||
ASSERT_TRUE(success);
|
||||
|
||||
std::unique_ptr<test::TestCluster> cluster;
|
||||
TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));
|
||||
|
||||
std::unique_ptr<Session> session(
|
||||
NewRemote(Options(cluster->targets()[0], 1)));
|
||||
TF_CHECK_OK(session->Create(graph));
|
||||
|
||||
// 1. Fail with an unknown input name.
|
||||
GraphDef extension;
|
||||
success = protobuf::TextFormat::ParseFromString(R"(
|
||||
node {
|
||||
name:'b' op:'MatMul' input:'a:first' input:'a'
|
||||
attr { key: 'T' value { type: DT_FLOAT } }
|
||||
attr { key: 'transpose_a' value { b: false } }
|
||||
attr { key: 'transpose_b' value { b: false } }
|
||||
attr { key: '_kernel' value { s: 'eigen' } }
|
||||
}
|
||||
)",
|
||||
&extension);
|
||||
ASSERT_TRUE(success);
|
||||
|
||||
Status s = session->Extend(extension);
|
||||
ASSERT_FALSE(s.ok());
|
||||
EXPECT_NE(s.error_message().find("Illegal op input name"), string::npos);
|
||||
|
||||
// 2. Succeed with a valid node.
|
||||
success = protobuf::TextFormat::ParseFromString(R"(
|
||||
node {
|
||||
name:'b' op:'MatMul' input:'a' input:'a'
|
||||
attr { key: 'T' value { type: DT_FLOAT } }
|
||||
attr { key: 'transpose_a' value { b: false } }
|
||||
attr { key: 'transpose_b' value { b: false } }
|
||||
attr { key: '_kernel' value { s: 'eigen' } }
|
||||
}
|
||||
)",
|
||||
&extension);
|
||||
ASSERT_TRUE(success);
|
||||
TF_CHECK_OK(session->Extend(extension));
|
||||
|
||||
// 3. Fail with a duplicate node.
|
||||
success = protobuf::TextFormat::ParseFromString(R"(
|
||||
node {
|
||||
name:'b' op:'MatMul' input:'a' input:'a'
|
||||
attr { key: 'T' value { type: DT_FLOAT } }
|
||||
attr { key: 'transpose_a' value { b: false } }
|
||||
attr { key: 'transpose_b' value { b: false } }
|
||||
attr { key: '_kernel' value { s: 'eigen' } }
|
||||
}
|
||||
)",
|
||||
&extension);
|
||||
ASSERT_TRUE(success);
|
||||
s = session->Extend(extension);
|
||||
ASSERT_FALSE(s.ok());
|
||||
EXPECT_NE(s.error_message().find("'b', which was created by a previous call"),
|
||||
string::npos);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
@ -0,0 +1,98 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "external/grpc/include/grpc++/grpc++.h"
|
||||
#include "external/grpc/include/grpc++/security/credentials.h"
|
||||
#include "external/grpc/include/grpc++/server_builder.h"
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
|
||||
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/init_main.h"
|
||||
#include "tensorflow/core/public/session_options.h"
|
||||
#include "tensorflow/core/util/command_line_flags.h"
|
||||
|
||||
// This binary starts a TensorFlow server (master and worker).
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
Status ParseFlagsForTask(int argc, char* argv[], GrpcServerOptions* options) {
|
||||
string cluster_spec;
|
||||
const bool parse_result =
|
||||
ParseFlags(&argc, argv, {Flag("cluster_spec", &cluster_spec), //
|
||||
Flag("job_name", &options->job_name), //
|
||||
Flag("task_id", &options->task_index)});
|
||||
if (!parse_result) {
|
||||
return errors::InvalidArgument("Error parsing command-line flags");
|
||||
}
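  // cluster_spec is a ","-separated list of jobs, where each job has the form
  // <job_name>|<host:port>(;<host:port>)* (see the usage text in main()
  // below). An illustrative (hypothetical) value:
  //   "worker|localhost:2222;localhost:2223,ps|localhost:2224"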
|
||||
|
||||
size_t my_num_tasks = 0;
|
||||
for (const string& job_str : str_util::Split(cluster_spec, ',')) {
|
||||
// Split each entry in the flag into 2 pieces, separated by "|".
|
||||
const std::vector<string> job_pieces = str_util::Split(job_str, '|');
|
||||
CHECK_EQ(2, job_pieces.size()) << job_str;
|
||||
const string& job = job_pieces[0];
|
||||
// Does a bit more validation of the number of tasks in this job.
|
||||
const StringPiece spec = job_pieces[1];
|
||||
// job_str is of form <job_name>|<host_ports>.
|
||||
const std::vector<string> host_ports = str_util::Split(spec, ';');
|
||||
size_t num_tasks = host_ports.size();
|
||||
if (job == options->job_name) {
|
||||
my_num_tasks = num_tasks;
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
options->channel_spec.AddHostPortsJob(job, host_ports, num_tasks));
|
||||
LOG(INFO) << "Peer " << job << " " << num_tasks << " {"
|
||||
<< str_util::Join(host_ports, ", ") << "}";
|
||||
}
|
||||
if (my_num_tasks == 0) {
|
||||
return errors::InvalidArgument("Job name \"", options->job_name,
|
||||
"\" does not appear in the cluster spec");
|
||||
}
|
||||
if (options->task_index >= my_num_tasks) {
|
||||
return errors::InvalidArgument("Task index ", options->task_index,
|
||||
" is invalid (job \"", options->job_name,
|
||||
"\" contains ", my_num_tasks, " tasks");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
tensorflow::port::InitMain(argv[0], &argc, &argv);
|
||||
tensorflow::GrpcServerOptions options;
|
||||
tensorflow::Status s = tensorflow::ParseFlagsForTask(argc, argv, &options);
|
||||
if (!s.ok()) {
|
||||
std::cerr << "ERROR: " << s.error_message() << std::endl;
|
||||
std::cerr << "Usage: " << argv[0]
|
||||
<< " --cluster_spec=SPEC --job_name=NAME --task_id=ID"
|
||||
<< std::endl;
|
||||
std::cerr << "Where:" << std::endl;
|
||||
std::cerr << " SPEC is <JOB>(,<JOB>)*" << std::endl;
|
||||
std::cerr << " JOB is <NAME>|<HOST:PORT>(;<HOST:PORT>)*" << std::endl;
|
||||
std::cerr << " NAME is a valid job name ([a-z][0-9a-z]*)" << std::endl;
|
||||
std::cerr << " HOST is a hostname or IP address" << std::endl;
|
||||
std::cerr << " PORT is a port number" << std::endl;
|
||||
return -1;
|
||||
}
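  // An illustrative (hypothetical) invocation that matches the usage text
  // above:
  //   grpc_tensorflow_server \
  //     --cluster_spec='worker|localhost:2222;localhost:2223,ps|localhost:2224' \
  //     --job_name=worker --task_id=0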
|
||||
tensorflow::StartTensorFlowServer(options);
|
||||
}
|
@ -0,0 +1,123 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
|
||||
|
||||
#include "external/grpc/include/grpc++/grpc++.h"
|
||||
#include "external/grpc/include/grpc++/security/credentials.h"
|
||||
#include "external/grpc/include/grpc++/server_builder.h"
|
||||
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
#include "tensorflow/core/distributed_runtime/graph_mgr.h"
|
||||
#include "tensorflow/core/distributed_runtime/master_env.h"
|
||||
#include "tensorflow/core/distributed_runtime/master_session.h"
|
||||
#include "tensorflow/core/distributed_runtime/process_util.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_master_service.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_env.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/init_main.h"
|
||||
#include "tensorflow/core/public/session_options.h"
|
||||
#include "tensorflow/core/util/command_line_flags.h"
|
||||
|
||||
// This library implements a TensorFlow server (master and worker).
|
||||
namespace tensorflow {
|
||||
|
||||
struct GrpcTaskOptions {
|
||||
// This process belongs to the "job_name".
|
||||
string job_name;
|
||||
|
||||
// This process is the task-th task within the replica. 0th, 1st,
|
||||
// 2nd, etc.
|
||||
int32 task = 0;
|
||||
|
||||
// Specification of peers.
|
||||
GrpcChannelSpec channel_spec;
|
||||
|
||||
SessionOptions default_session_options;
|
||||
};
|
||||
|
||||
Status StartTensorFlowServer(const GrpcServerOptions& task_options) {
|
||||
thread::ThreadPool* thread_pool =
|
||||
new thread::ThreadPool(Env::Default(), "server", 1);
|
||||
  thread_pool->Schedule([&task_options]() {
|
||||
// This process provides both the worker service and the master
|
||||
// service. We let these two services share the same channel cache
|
||||
// (rpc connections) and cpu devices (used by the master as the
|
||||
// client device). These client devices require a worker service
|
||||
// so that remote devices can copy the feeds from the client
|
||||
// device in the master.
|
||||
tensorflow::MasterEnv master_env;
|
||||
string name_prefix =
|
||||
strings::StrCat("/job:", task_optionss.job_name, "/replica:0", "/task:",
|
||||
task_options.task);
|
||||
DeviceFactory::AddDevices(task_options.default_session_options, name_prefix,
|
||||
&master_env.local_devices);
|
||||
|
||||
// Create the DeviceMgr before initializing the RPC layer, because that
|
||||
// needs to know how many devices of each kind exist.
|
||||
WorkerEnv worker_env;
|
||||
worker_env.device_mgr = new DeviceMgr(master_env.local_devices);
|
||||
|
||||
// Finish setting up Env for Worker service.
|
||||
string donotcare;
|
||||
CHECK(DeviceNameUtils::SplitDeviceName(master_env.local_devices[0]->name(),
|
||||
&worker_env.worker_name,
|
||||
&donotcare));
|
||||
worker_env.env = Env::Default();
|
||||
|
||||
GrpcChannelCache* channel_cache =
|
||||
NewGrpcChannelCache(task_options.channel_spec);
|
||||
string server_address = channel_cache->TranslateTask(name_prefix);
|
||||
worker_env.worker_cache = NewGrpcWorkerCache(channel_cache);
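    // NewGrpcWorkerCache() takes ownership of channel_cache (see the comment
    // in grpc_worker_cache.h).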
|
||||
worker_env.graph_mgr = new GraphMgr(&worker_env);
|
||||
worker_env.rendezvous_mgr = new RpcRendezvousMgr(&worker_env);
|
||||
worker_env.compute_pool = ComputePool(task_options.default_session_options);
|
||||
|
||||
// Finish setting up Env for Master service.
|
||||
master_env.env = Env::Default();
|
||||
master_env.ops = OpRegistry::Global();
|
||||
master_env.worker_cache = worker_env.worker_cache;
|
||||
master_env.master_session_factory = internal::NewMasterSession;
|
||||
|
||||
::grpc::ServerBuilder builder;
|
||||
builder.AddListeningPort(server_address,
|
||||
::grpc::InsecureServerCredentials());
|
||||
auto master_service = NewGrpcMasterService(&master_env, &builder);
|
||||
auto worker_service = NewGrpcWorkerService(&worker_env, &builder);
|
||||
// Finally assemble the server.
|
||||
auto server_ = builder.BuildAndStart();
|
||||
|
||||
std::unique_ptr<Thread> master_thread(Env::Default()->StartThread(
|
||||
ThreadOptions(), "master_service_thread",
|
||||
[master_service]() { master_service->HandleRPCsLoop(); }));
|
||||
|
||||
std::unique_ptr<Thread> worker_thread(Env::Default()->StartThread(
|
||||
ThreadOptions(), "worker_service_thread",
|
||||
[worker_service]() { worker_service->HandleRPCsLoop(); }));
|
||||
});
|
||||
|
||||
// The ThreadPool destructor waits until all work is done (i.e. forever).
|
||||
delete thread_pool;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // end namespace tensorflow
|
84
tensorflow/core/distributed_runtime/rpc/grpc_testlib.cc
Normal file
@ -0,0 +1,84 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_testlib.h"
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_session.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/util/device_name_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace test {
|
||||
|
||||
Status TestCluster::MakeTestCluster(const SessionOptions& options, int n,
|
||||
std::unique_ptr<TestCluster>* out_cluster) {
|
||||
CHECK_GE(n, 1);
|
||||
std::unique_ptr<TestCluster> ret(new TestCluster);
|
||||
|
||||
ret->targets_.resize(n);
|
||||
|
||||
std::vector<int> port(n);
|
||||
for (int i = 0; i < n; ++i) {
|
||||
port[i] = testing::PickUnusedPortOrDie();
|
||||
ret->targets_[i] = strings::StrCat("localhost:", port[i]);
|
||||
}
|
||||
|
||||
const string tf_jobs = strings::StrCat("--tf_jobs=localhost|",
|
||||
str_util::Join(ret->targets_, ";"));
|
||||
|
||||
int num_cpus = 1;
|
||||
int num_gpus = 0;
|
||||
auto iter = options.config.device_count().find("CPU");
|
||||
if (iter != options.config.device_count().end()) {
|
||||
num_cpus = iter->second;
|
||||
}
|
||||
iter = options.config.device_count().find("GPU");
|
||||
if (iter != options.config.device_count().end()) {
|
||||
num_gpus = iter->second;
|
||||
}
|
||||
|
||||
for (int i = 0; i < n; ++i) {
|
||||
const std::vector<string> argv(
|
||||
{strings::StrCat(testing::TensorFlowSrcRoot(),
|
||||
"/core/distributed_runtime/rpc/grpc_testlib_server"),
|
||||
/* see grpc_testlib_server.cc for flags */
|
||||
tf_jobs, "--tf_job=localhost", strings::StrCat("--tf_task=", i),
|
||||
strings::StrCat("--num_cpus=", num_cpus),
|
||||
strings::StrCat("--num_gpus=", num_gpus)});
|
||||
ret->subprocesses_.emplace_back(testing::CreateSubProcess(argv));
|
||||
bool success = ret->subprocesses_[i]->Start();
|
||||
if (!success) {
|
||||
return errors::Internal("Could not start subprocess");
|
||||
}
|
||||
}
|
||||
|
||||
SessionOptions options_copy(options);
|
||||
options_copy.target = strings::StrCat("grpc://", ret->targets_[0]);
|
||||
std::unique_ptr<GrpcSession> session(new GrpcSession(options_copy));
|
||||
std::vector<DeviceAttributes> device_attributes;
|
||||
ret->devices_ = session->ListDevices();
|
||||
|
||||
*out_cluster = std::move(ret);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
TestCluster::~TestCluster() {
|
||||
for (auto& subprocess : subprocesses_) {
|
||||
subprocess->Kill(9);
|
||||
}
|
||||
}
|
||||
|
||||
} // end namespace test
|
||||
} // end namespace tensorflow
|
73
tensorflow/core/distributed_runtime/rpc/grpc_testlib.h
Normal file
@ -0,0 +1,73 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_TESTLIB_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_TESTLIB_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/device_attributes.pb.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/public/session_options.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class Device;
|
||||
|
||||
namespace test {
|
||||
|
||||
// Provides a handle to a set of TensorFlow servers (masters and
|
||||
// workers) for testing purposes.
|
||||
//
|
||||
// This class currently runs the servers in separate processes; the
|
||||
// lifetime of this object is coterminous with the lifetimes of those
|
||||
// processes.
|
||||
class TestCluster {
|
||||
public:
|
||||
// Creates a new test cluster based on the given `options` (which
|
||||
// configure the number of devices of each type) and a count of
|
||||
// processes `n`. On success, the test cluster is stored in
|
||||
// *out_cluster, and this function returns OK. Otherwise an error is
|
||||
// returned.
|
||||
static Status MakeTestCluster(const SessionOptions& options, int n,
|
||||
std::unique_ptr<TestCluster>* out_cluster);
|
||||
~TestCluster();
|
||||
|
||||
// Returns a vector of string "<hostname>:<port>" pairs that may be
|
||||
// used as targets to construct a GrpcSession.
|
||||
const std::vector<string>& targets() const { return targets_; }
|
||||
|
||||
// Returns a vector of devices available in this test cluster.
|
||||
const std::vector<DeviceAttributes>& devices() const { return devices_; }
|
||||
|
||||
private:
|
||||
TestCluster() = default;
|
||||
|
||||
std::vector<std::unique_ptr<testing::SubProcess>> subprocesses_;
|
||||
std::vector<string> targets_;
|
||||
std::vector<DeviceAttributes> devices_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(TestCluster);
|
||||
};
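// A minimal usage sketch, following grpc_session_test.cc (the Devices()
// helper that builds the SessionOptions is assumed to come from the test,
// not from this header):
//
//   std::unique_ptr<test::TestCluster> cluster;
//   TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));
//   SessionOptions options;
//   options.target = strings::StrCat("grpc://", cluster->targets()[0]);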
|
||||
|
||||
} // end namespace test
|
||||
} // end namespace tensorflow
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_TESTLIB_H_
|
91
tensorflow/core/distributed_runtime/rpc/grpc_testlib_ops.cc
Normal file
@ -0,0 +1,91 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace test {
|
||||
|
||||
// ErrorOp::Compute returns an error.
|
||||
REGISTER_OP("Error")
|
||||
.Input("in: T")
|
||||
.Output("out: T")
|
||||
.Attr("T: type")
|
||||
.Attr("message: string");
|
||||
class ErrorOp : public OpKernel {
|
||||
public:
|
||||
explicit ErrorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("message", &errmsg_));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
ctx->SetStatus(errors::Internal(errmsg_));
|
||||
}
|
||||
|
||||
private:
|
||||
string errmsg_;
|
||||
};
|
||||
REGISTER_KERNEL_BUILDER(Name("Error").Device(DEVICE_CPU), ErrorOp);
|
||||
|
||||
REGISTER_OP("InvalidRefType")
|
||||
.Output("out: Ref(TIn)")
|
||||
.Attr("TIn: type")
|
||||
.Attr("TOut: type");
|
||||
class InvalidRefType : public OpKernel {
|
||||
public:
|
||||
explicit InvalidRefType(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("TOut", &dtout_));
|
||||
output_ = Tensor(dtout_, TensorShape({}));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
ctx->set_output_ref(0, &mu_, &output_);
|
||||
}
|
||||
|
||||
private:
|
||||
DataType dtout_;
|
||||
mutex mu_;
|
||||
Tensor output_;
|
||||
};
|
||||
REGISTER_KERNEL_BUILDER(Name("InvalidRefType").Device(DEVICE_CPU),
|
||||
InvalidRefType);
|
||||
|
||||
// DelayOp::ComputeAsync waits for "micros" microseconds and then returns
// its input.
|
||||
REGISTER_OP("Delay")
|
||||
.Input("in: T")
|
||||
.Output("out: T")
|
||||
.Attr("T: type")
|
||||
.Attr("micros: int");
|
||||
class DelayOp : public AsyncOpKernel {
|
||||
public:
|
||||
explicit DelayOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("micros", µs_));
|
||||
}
|
||||
|
||||
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
|
||||
ctx->set_output(0, ctx->input(0));
|
||||
ctx->env()->SchedClosureAfter(micros_, done);
|
||||
}
|
||||
|
||||
private:
|
||||
int64 micros_;
|
||||
};
|
||||
REGISTER_KERNEL_BUILDER(Name("Delay").Device(DEVICE_CPU), DelayOp);
|
||||
|
||||
} // namespace test
|
||||
} // namespace tensorflow
|
@ -0,0 +1,92 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "external/grpc/include/grpc++/grpc++.h"
|
||||
#include "external/grpc/include/grpc++/security/credentials.h"
|
||||
#include "external/grpc/include/grpc++/server_builder.h"
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
|
||||
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/init_main.h"
|
||||
#include "tensorflow/core/public/session_options.h"
|
||||
#include "tensorflow/core/util/command_line_flags.h"
|
||||
|
||||
// This binary starts a TensorFlow server (master and worker) for test purposes.
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
Status ParseFlagsForTask(int argc, char* argv[], GrpcServerOptions* options) {
|
||||
string job_spec;
|
||||
int num_cpus = 1;
|
||||
int num_gpus = 0;
|
||||
const bool parse_result =
|
||||
ParseFlags(&argc, argv, {Flag("tf_jobs", &job_spec), //
|
||||
Flag("tf_job", &options->job_name), //
|
||||
Flag("tf_task", &options->task_index), //
|
||||
Flag("num_cpus", &num_cpus), //
|
||||
Flag("num_gpus", &num_gpus)});
|
||||
if (!parse_result) {
|
||||
return errors::InvalidArgument("Error parsing command-line flags");
|
||||
}
|
||||
|
||||
uint32 my_tasks_per_replica = 0;
|
||||
for (const string& job_str : str_util::Split(job_spec, ',')) {
|
||||
// Split each entry in the flag into 2 pieces, separated by "|".
|
||||
const std::vector<string> job_pieces = str_util::Split(job_str, '|');
|
||||
CHECK_EQ(2, job_pieces.size()) << job_str;
|
||||
const string& job = job_pieces[0];
|
||||
// Does a bit more validation of the tasks_per_replica.
|
||||
const StringPiece spec = job_pieces[1];
|
||||
// job_str is of form <job_name>|<host_ports>.
|
||||
const std::vector<string> host_ports = str_util::Split(spec, ';');
|
||||
uint32 tasks_per_replica = host_ports.size();
|
||||
if (job == options->job_name) {
|
||||
my_tasks_per_replica = tasks_per_replica;
|
||||
}
|
||||
TF_RETURN_IF_ERROR(options->channel_spec.AddHostPortsJob(
|
||||
job, host_ports, tasks_per_replica));
|
||||
LOG(INFO) << "Peer " << job << " " << tasks_per_replica << " {"
|
||||
<< str_util::Join(host_ports, ", ") << "}";
|
||||
}
|
||||
if (my_tasks_per_replica == 0) {
|
||||
return errors::InvalidArgument("Invalid job specification");
|
||||
}
|
||||
|
||||
(*options->default_session_options.config.mutable_device_count())["CPU"] =
|
||||
num_cpus;
|
||||
(*options->default_session_options.config.mutable_device_count())["GPU"] =
|
||||
num_gpus;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
tensorflow::port::InitMain(argv[0], &argc, &argv);
|
||||
tensorflow::GrpcServerOptions options;
|
||||
tensorflow::Status s = tensorflow::ParseFlagsForTask(argc, argv, &options);
|
||||
if (!s.ok()) {
|
||||
LOG(ERROR) << "Could not parse flags: " << s.error_message();
|
||||
return -1;
|
||||
}
|
||||
tensorflow::StartTensorFlowServer(options);
|
||||
// NOTE(mrry): Unreachable code.
|
||||
return 0;
|
||||
}
|
48
tensorflow/core/distributed_runtime/rpc/grpc_util.h
Normal file
@ -0,0 +1,48 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_UTIL_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_UTIL_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "external/grpc/include/grpc++/grpc++.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
inline Status FromGrpcStatus(const ::grpc::Status& s) {
|
||||
if (s.ok()) {
|
||||
return Status::OK();
|
||||
} else {
|
||||
return Status(static_cast<tensorflow::error::Code>(s.error_code()),
|
||||
s.error_message());
|
||||
}
|
||||
}
|
||||
|
||||
inline ::grpc::Status ToGrpcStatus(const ::tensorflow::Status& s) {
|
||||
if (s.ok()) {
|
||||
return ::grpc::Status::OK;
|
||||
} else {
|
||||
return ::grpc::Status(static_cast<::grpc::StatusCode>(s.code()),
|
||||
s.error_message());
|
||||
}
|
||||
}
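// Illustrative round trip: a service handler can report a tensorflow::Status
// to gRPC with ToGrpcStatus(s), and an RPC caller can convert the
// ::grpc::Status of a finished call back with FromGrpcStatus(grpc_status).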
|
||||
|
||||
typedef std::shared_ptr<::grpc::Channel> SharedGrpcChannelPtr;
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_UTIL_H_
|
85
tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc
Normal file
@ -0,0 +1,85 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_cache_logger.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_cache_partial.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_interface.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class GrpcWorkerCache : public WorkerCachePartial {
|
||||
public:
|
||||
explicit GrpcWorkerCache(GrpcChannelCache* channel_cache)
|
||||
: channel_cache_(channel_cache) {
|
||||
// TODO(mrry): Investigate possible performance improvements by
|
||||
// replacing this thread with a threadpool.
|
||||
polling_thread_ = Env::Default()->StartThread(
|
||||
ThreadOptions(), "grpc_worker_cache", [this]() {
|
||||
void* tag;
|
||||
bool ok;
|
||||
while (completion_queue_.Next(&tag, &ok)) {
|
||||
GrpcClientCQTag* callback_tag = static_cast<GrpcClientCQTag*>(tag);
|
||||
callback_tag->OnCompleted(ok);
|
||||
delete callback_tag;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Explicit destructor to control destruction order.
|
||||
~GrpcWorkerCache() override {
|
||||
completion_queue_.Shutdown();
|
||||
delete polling_thread_; // Blocks until thread exits.
|
||||
delete channel_cache_;
|
||||
}
|
||||
|
||||
void ListWorkers(std::vector<string>* workers) override {
|
||||
channel_cache_->ListWorkers(workers);
|
||||
}
|
||||
|
||||
WorkerInterface* CreateWorker(const string& target) override {
|
||||
SharedGrpcChannelPtr channel = channel_cache_->FindWorkerChannel(target);
|
||||
CHECK(channel) << "Channel was null";
|
||||
if (!channel) return nullptr;
|
||||
WorkerInterface* ret =
|
||||
NewGrpcRemoteWorker(channel, &completion_queue_, &logger_);
|
||||
return ret;
|
||||
}
|
||||
|
||||
void SetLogging(bool v) override { logger_.SetLogging(v); }
|
||||
|
||||
void ClearLogs() override { logger_.ClearLogs(); }
|
||||
|
||||
bool RetrieveLogs(int64 step_id, StepStats* ss) override {
|
||||
return logger_.RetrieveLogs(step_id, ss);
|
||||
}
|
||||
|
||||
private:
|
||||
GrpcChannelCache* channel_cache_; // Owned.
|
||||
::grpc::CompletionQueue completion_queue_;
|
||||
Thread* polling_thread_; // Owned.
|
||||
WorkerCacheLogger logger_;
|
||||
};
|
||||
|
||||
WorkerCacheInterface* NewGrpcWorkerCache(GrpcChannelCache* cc) {
|
||||
return new GrpcWorkerCache(cc);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
28
tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h
Normal file
@ -0,0 +1,28 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_CACHE_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_CACHE_H_
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_cache.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// The returned WorkerCacheInterface object takes the ownership of "cc".
|
||||
WorkerCacheInterface* NewGrpcWorkerCache(GrpcChannelCache* cc);
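// Typical construction mirrors grpc_server_lib.cc (illustrative):
//   GrpcChannelCache* channels = NewGrpcChannelCache(channel_spec);
//   WorkerCacheInterface* workers = NewGrpcWorkerCache(channels);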
|
||||
|
||||
} // namespace tensorflow
|
||||
#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_CACHE_H_
|
415
tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc
Normal file
@ -0,0 +1,415 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
|
||||
|
||||
#include <deque>
|
||||
|
||||
#include "external/grpc/include/grpc++/server_builder.h"
|
||||
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
#include "tensorflow/core/common_runtime/dma_helper.h"
|
||||
#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
|
||||
#include "tensorflow/core/common_runtime/gpu/process_state.h"
|
||||
#include "tensorflow/core/common_runtime/gpu_device_context.h"
|
||||
#include "tensorflow/core/common_runtime/local_device.h"
|
||||
#include "tensorflow/core/common_runtime/step_stats_collector.h"
|
||||
#include "tensorflow/core/distributed_runtime/graph_mgr.h"
|
||||
#include "tensorflow/core/distributed_runtime/process_util.h"
|
||||
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_call.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_cache.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_env.h"
|
||||
#include "tensorflow/core/framework/cancellation.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/tracing.h"
|
||||
#include "tensorflow/core/protobuf/worker.pb.h"
|
||||
#include "tensorflow/core/protobuf/worker_service.grpc.pb.h"
|
||||
#include "tensorflow/core/protobuf/worker_service.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
static Tensor empty_tensor(DT_FLOAT);
|
||||
|
||||
class GrpcWorkerService : public AsyncServiceInterface {
|
||||
public:
|
||||
GrpcWorkerService(WorkerEnv* env, ::grpc::ServerBuilder* builder)
|
||||
: env_(env), cancellation_manager_(new CancellationManager) {
|
||||
builder->RegisterService(&worker_service_);
|
||||
cq_ = builder->AddCompletionQueue().release();
|
||||
}
|
||||
|
||||
~GrpcWorkerService() { delete cq_; }
|
||||
|
||||
// This macro creates a new request for the given RPC method name
|
||||
// (e.g., `ENQUEUE_REQUEST(GetStatus);`), and enqueues it on
|
||||
// `this->cq_`.
|
||||
//
|
||||
// This macro is invoked one or more times for each RPC method to
|
||||
// ensure that there are sufficient completion queue entries to
|
||||
// handle incoming requests without blocking.
|
||||
//
|
||||
// The implementation of the request handler for each RPC method
|
||||
// must ensure that it calls ENQUEUE_REQUEST() for that RPC method,
|
||||
// to keep accepting new requests.
|
||||
#define ENQUEUE_REQUEST(method) \
|
||||
do { \
|
||||
Call<GrpcWorkerService, grpc::WorkerService::AsyncService, \
|
||||
method##Request, method##Response>:: \
|
||||
EnqueueRequest(&worker_service_, cq_, \
|
||||
&grpc::WorkerService::AsyncService::Request##method, \
|
||||
&GrpcWorkerService::method##Handler); \
|
||||
} while (0)
|
||||
|
||||
// This method blocks forever handling requests from the completion queue.
|
||||
void HandleRPCsLoop() {
|
||||
// TODO(mrry): This may require performance engineering. We can
|
||||
// add more threads to service the completion queue, and enqueue more
// pending requests of the short and frequent request types.
|
||||
// Currently we allow unbounded numbers of pending calls for each
|
||||
// method, by re-enqueuing a request before the previous one
|
||||
// completes, and we may decide to bound some of the request
|
||||
// types.
|
||||
ENQUEUE_REQUEST(GetStatus);
|
||||
ENQUEUE_REQUEST(CleanupAll);
|
||||
ENQUEUE_REQUEST(RegisterGraph);
|
||||
ENQUEUE_REQUEST(DeregisterGraph);
|
||||
|
||||
// TODO(mrry): Consider enqueuing more of these request types.
|
||||
ENQUEUE_REQUEST(RecvTensor);
|
||||
ENQUEUE_REQUEST(RunGraph);
|
||||
|
||||
ENQUEUE_REQUEST(CleanupGraph);
|
||||
ENQUEUE_REQUEST(Logging);
|
||||
ENQUEUE_REQUEST(Tracing);
|
||||
|
||||
void* tag;
|
||||
bool ok;
|
||||
while (cq_->Next(&tag, &ok)) {
|
||||
UntypedCall<GrpcWorkerService>::Tag* callback_tag =
|
||||
static_cast<UntypedCall<GrpcWorkerService>::Tag*>(tag);
|
||||
callback_tag->OnCompleted(this, ok);
|
||||
delete callback_tag;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
WorkerEnv* env_; // Not owned.
|
||||
::grpc::ServerCompletionQueue* cq_; // Owned.
|
||||
|
||||
grpc::WorkerService::AsyncService worker_service_;
|
||||
|
||||
mutex mu_;
|
||||
CancellationManager* cancellation_manager_ GUARDED_BY(mu_);
|
||||
|
||||
// The following section contains one request handler method per
|
||||
// RPC. The `FooHandler` method is called (indirectly) by
|
||||
// `HandleRPCsLoop()` when the next Foo RPC is received. Each
|
||||
// `FooHandler` call schedules a closure on `env_->compute_pool`,
|
||||
// and is responsible for requesting the next Foo call by calling
|
||||
// `ENQUEUE_REQUEST(Foo)`.
|
||||
|
||||
template <class RequestMessage, class ResponseMessage>
|
||||
using WorkerCall = Call<GrpcWorkerService, grpc::WorkerService::AsyncService,
|
||||
RequestMessage, ResponseMessage>;
|
||||
|
||||
void GetStatusHandler(WorkerCall<GetStatusRequest, GetStatusResponse>* call) {
|
||||
env_->compute_pool->Schedule([this, call]() {
|
||||
DeviceMgr* dm = env_->device_mgr;
|
||||
std::vector<DeviceAttributes> devices;
|
||||
dm->ListDeviceAttributes(&devices);
|
||||
call->response.mutable_device_attributes()->Reserve(devices.size());
|
||||
for (size_t i = 0; i < devices.size(); i++) {
|
||||
call->response.add_device_attributes()->Swap(&devices[i]);
|
||||
}
|
||||
call->SendResponse(::grpc::Status::OK);
|
||||
});
|
||||
ENQUEUE_REQUEST(GetStatus);
|
||||
}
|
||||
|
||||
void CleanupAllHandler(
|
||||
WorkerCall<CleanupAllRequest, CleanupAllResponse>* call) {
|
||||
env_->compute_pool->Schedule([this, call]() {
|
||||
std::vector<string> containers;
|
||||
for (const auto& c : call->request.container()) containers.push_back(c);
|
||||
env_->device_mgr->ClearContainers(containers);
|
||||
call->SendResponse(::grpc::Status::OK);
|
||||
});
|
||||
ENQUEUE_REQUEST(CleanupAll);
|
||||
}
|
||||
|
||||
void RegisterGraphHandler(
|
||||
WorkerCall<RegisterGraphRequest, RegisterGraphResponse>* call) {
|
||||
env_->compute_pool->Schedule([this, call]() {
|
||||
Status s = env_->graph_mgr->Register(
|
||||
call->request.session_handle(), call->request.graph_def(),
|
||||
call->request.graph_options(), call->response.mutable_graph_handle());
|
||||
call->SendResponse(ToGrpcStatus(s));
|
||||
});
|
||||
ENQUEUE_REQUEST(RegisterGraph);
|
||||
}
|
||||
|
||||
void DeregisterGraphHandler(
|
||||
WorkerCall<DeregisterGraphRequest, DeregisterGraphResponse>* call) {
|
||||
env_->compute_pool->Schedule([this, call]() {
|
||||
Status s = env_->graph_mgr->Deregister(call->request.graph_handle());
|
||||
call->SendResponse(ToGrpcStatus(s));
|
||||
});
|
||||
ENQUEUE_REQUEST(DeregisterGraph);
|
||||
}
|
||||
|
||||
void RunGraphHandler(WorkerCall<RunGraphRequest, RunGraphResponse>* call) {
|
||||
env_->compute_pool->Schedule([this, call]() { DoRunGraph(call); });
|
||||
ENQUEUE_REQUEST(RunGraph);
|
||||
}
|
||||
|
||||
void RecvTensorHandler(
|
||||
WorkerCall<RecvTensorRequest, RecvTensorResponse>* call) {
|
||||
env_->compute_pool->Schedule([this, call]() { DoRecvTensor(call); });
|
||||
ENQUEUE_REQUEST(RecvTensor);
|
||||
}
|
||||
|
||||
void CleanupGraphHandler(
|
||||
WorkerCall<CleanupGraphRequest, CleanupGraphResponse>* call) {
|
||||
env_->compute_pool->Schedule([this, call]() {
|
||||
const int64 step_id = call->request.step_id();
|
||||
env_->rendezvous_mgr->Cleanup(step_id);
|
||||
call->SendResponse(::grpc::Status::OK);
|
||||
});
|
||||
ENQUEUE_REQUEST(CleanupGraph);
|
||||
}
|
||||
|
||||
void LoggingHandler(WorkerCall<LoggingRequest, LoggingResponse>* call) {
|
||||
env_->compute_pool->Schedule([this, call]() {
|
||||
Status s = DoLogging(call);
|
||||
call->SendResponse(ToGrpcStatus(s));
|
||||
});
|
||||
ENQUEUE_REQUEST(Logging);
|
||||
}
|
||||
|
||||
void TracingHandler(WorkerCall<TracingRequest, TracingResponse>* call) {
|
||||
SchedClosure([this, call]() {
|
||||
Status s = DoTracing(call);
|
||||
call->SendResponse(ToGrpcStatus(s));
|
||||
});
|
||||
ENQUEUE_REQUEST(Tracing);
|
||||
}
|
||||
#undef ENQUEUE_REQUEST
|
||||
|
||||
private:
|
||||
// The following section contains the implementations of RunGraph(),
|
||||
// RecvTensor(), Logging(), and Tracing(), which are the four
|
||||
// non-trivial and potentially long-running RPCs performed by a
|
||||
// TensorFlow worker.
|
||||
|
||||
void AbortStep(int64 step_id) {
|
||||
Rendezvous* rendez = env_->rendezvous_mgr->Find(step_id);
|
||||
SchedNonBlockingClosureAfter(1000000, [rendez, step_id]() {
|
||||
// Delay a bit before aborting the step. This way, the root
|
||||
// cause may return first back to the client instead of this
|
||||
// cancellation generated abort error.
|
||||
rendez->StartAbort(errors::Aborted("Step ", step_id));
|
||||
rendez->Unref();
|
||||
});
|
||||
}
|
||||
|
||||
Status PrepareRunGraph(const RunGraphRequest& req, GraphMgr::NamedTensors* in,
|
||||
GraphMgr::NamedTensors* out) {
|
||||
if (req.send_size() > 0) {
|
||||
// TODO(zhifengc): Let the caller decide on which device to
|
||||
// allocate the tensor.
|
||||
Device* cpu_dev = nullptr;
|
||||
TF_RETURN_IF_ERROR(env_->device_mgr->LookupDevice("CPU:0", &cpu_dev));
|
||||
AllocatorAttributes alloc_attrs;
|
||||
Tensor val;
|
||||
for (const NamedTensor& entry : req.send()) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
cpu_dev->MakeTensorFromProto(entry.val(), alloc_attrs, &val));
|
||||
in->insert({entry.key(), val});
|
||||
}
|
||||
}
|
||||
for (const string& key : req.recv_key()) {
|
||||
out->insert({key, empty_tensor});
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void DoRunGraph(WorkerCall<RunGraphRequest, RunGraphResponse>* call) {
|
||||
const int64 step_id = call->request.step_id();
|
||||
TRACEPRINTF("RunGraph: %lld", step_id);
|
||||
GraphMgr::NamedTensors in;
|
||||
GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors;
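// `out` is heap-allocated so that it outlives this handler: it is filled
// in and deleted by the ExecuteAsync() callback below.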
|
||||
Status s = PrepareRunGraph(call->request, &in, out);
|
||||
if (!s.ok()) {
|
||||
delete out;
|
||||
call->SendResponse(ToGrpcStatus(s));
|
||||
return;
|
||||
}
|
||||
StepStatsCollector* collector = nullptr;
|
||||
// TODO(mrry): Collect results from a profiler if available.
|
||||
CancellationManager* cm = new CancellationManager;
|
||||
call->SetCancelCallback([this, cm, step_id]() {
|
||||
cm->StartCancel();
|
||||
AbortStep(step_id);
|
||||
});
|
||||
CancellationToken token;
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
token = cancellation_manager_->get_cancellation_token();
|
||||
cancellation_manager_->RegisterCallback(token,
|
||||
[cm]() { cm->StartCancel(); });
|
||||
}
|
||||
env_->graph_mgr->ExecuteAsync(
|
||||
call->request.graph_handle(), step_id, call->request.exec_opts(),
|
||||
collector, cm, in, out, [this, call, cm, out, token](Status s) {
|
||||
call->ClearCancelCallback();
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
cancellation_manager_->DeregisterCallback(token);
|
||||
}
|
||||
delete cm;
|
||||
|
||||
if (s.ok()) {
|
||||
for (const auto& p : *out) {
|
||||
const string& key = p.first;
|
||||
const Tensor& val = p.second;
|
||||
auto* recv = call->response.add_recv();
|
||||
recv->set_key(key);
|
||||
// TODO(zhifengc): Deal with gpu -> cpu copy.
|
||||
TensorProto* proto = recv->mutable_val();
|
||||
val.AsProtoField(proto);
|
||||
}
|
||||
}
|
||||
delete out;
|
||||
call->SendResponse(ToGrpcStatus(s));
|
||||
});
|
||||
}
|
||||
|
||||
// Helper for RecvTensor. Validates "key" and returns the source
|
||||
// device in "*src_dev".
|
||||
Status PrepareRecvTensor(const string& key, Device** src_dev) {
|
||||
// Validate the key.
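// A key produced by Rendezvous::CreateKey() is roughly of the form
// (components and values here are illustrative):
//   src_device;src_incarnation;dst_device;tensor_name;frame_id:iter_id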
|
||||
Rendezvous::ParsedKey parsed;
|
||||
TF_RETURN_IF_ERROR(Rendezvous::ParseKey(key, &parsed));
|
||||
|
||||
// Figures out which device the tensor is hosted on.
|
||||
TF_RETURN_IF_ERROR(
|
||||
env_->device_mgr->LookupDevice(parsed.src_device, src_dev));
|
||||
|
||||
// Does the device have the right incarnation number we expect?
|
||||
if ((*src_dev)->attributes().incarnation() != parsed.src_incarnation) {
|
||||
return errors::Aborted(
|
||||
"RecvTensor expects a different device incarnation: ",
|
||||
parsed.src_incarnation, " vs. ",
|
||||
(*src_dev)->attributes().incarnation(),
|
||||
". Your worker job was probably restarted. Check your "
|
||||
"worker job for the reason why it was restarted.");
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void DoRecvTensor(WorkerCall<RecvTensorRequest, RecvTensorResponse>* call) {
|
||||
const int64 step_id = call->request.step_id();
|
||||
const string& key = call->request.rendezvous_key();
|
||||
TRACEPRINTF("RecvTensor: %lld %s", step_id, key.c_str());
|
||||
Device* src_dev = nullptr;
|
||||
Status s = PrepareRecvTensor(key, &src_dev);
|
||||
if (!s.ok()) {
|
||||
call->SendResponse(ToGrpcStatus(s));
|
||||
return;
|
||||
}
|
||||
|
||||
// Request the tensor associated with the rendezvous key. Any time
|
||||
// while waiting for the tensor to be produced, up until the start
|
||||
// of execution of the callback lambda body below, an RPC
|
||||
// cancellation should abort the rendezvous.
|
||||
call->SetCancelCallback([this, step_id]() { AbortStep(step_id); });
|
||||
env_->rendezvous_mgr->RecvLocalAsync(
|
||||
step_id, key,
|
||||
[this, call, src_dev](const Status& status,
|
||||
const Rendezvous::Args& send_args,
|
||||
const Rendezvous::Args& recv_args,
|
||||
const Tensor& val, const bool is_dead) {
|
||||
call->ClearCancelCallback();
|
||||
Status s = status;
|
||||
if (s.ok()) {
|
||||
// DMA can only be used for Tensors that do not fall into
|
||||
// the following three odd edge cases: 1) a zero-size
|
||||
// buffer, 2) a dead tensor which has an uninit value, and
|
||||
// 3) the tensor has the on_host allocation attribute,
|
||||
// i.e. it's in CPU RAM *independent of its assigned
|
||||
// device type*.
|
||||
// const size_t bytes = is_dead ? 0 : val.TotalBytes();
|
||||
const bool on_host = send_args.alloc_attrs.on_host();
|
||||
const DeviceContext* send_dev_context = send_args.device_context;
|
||||
call->response.set_is_dead(is_dead);
|
||||
StatusCallback response_ready = [call](const Status& s) {
|
||||
// The value is now ready to be returned on the wire.
|
||||
call->response.set_send_start_micros(Env::Default()->NowMicros());
|
||||
call->SendResponse(ToGrpcStatus(s));
|
||||
};
|
||||
{
|
||||
// Non-DMA cases.
|
||||
if (src_dev->tensorflow_gpu_device_info() && (!on_host)) {
|
||||
CHECK(send_dev_context)
|
||||
<< "send dev name: " << src_dev->name()
|
||||
<< " gpu_info: " << src_dev->tensorflow_gpu_device_info();
|
||||
// "val" is on a GPU. Uses GPUUtil to fill the response proto.
|
||||
GPUUtil::SetProtoFromGPU(val, src_dev, send_dev_context,
|
||||
call->response.mutable_tensor(),
|
||||
is_dead, response_ready);
|
||||
} else {
|
||||
// "val" is in CPU memory.
|
||||
TensorProto* proto = call->response.mutable_tensor();
|
||||
val.AsProtoTensorContent(proto);
|
||||
response_ready(Status::OK());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// !s.ok()
|
||||
call->SendResponse(ToGrpcStatus(s));
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Status DoLogging(WorkerCall<LoggingRequest, LoggingResponse>* call) {
|
||||
// TODO(mrry): Platform-specific tracing support.
|
||||
return errors::Unimplemented("Logging");
|
||||
}
|
||||
|
||||
Status DoTracing(WorkerCall<TracingRequest, TracingResponse>* call) {
|
||||
// TODO(mrry): Platform-specific tracing support.
|
||||
return errors::Unimplemented("Tracing");
|
||||
}
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(GrpcWorkerService);
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
AsyncServiceInterface* NewGrpcWorkerService(WorkerEnv* env,
|
||||
::grpc::ServerBuilder* builder) {
|
||||
return new GrpcWorkerService(env, builder);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
@ -0,0 +1,34 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_H_
|
||||
|
||||
namespace grpc {
|
||||
class ServerBuilder;
|
||||
} // namespace grpc
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class AsyncServiceInterface;
|
||||
class WorkerEnv;
|
||||
|
||||
// Returns an implementation of WorkerService rpc service.
|
||||
AsyncServiceInterface* NewGrpcWorkerService(WorkerEnv* env,
|
||||
::grpc::ServerBuilder* builder);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_H_
|
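To show how the service above is meant to be hosted, here is a minimal, illustrative sketch (not part of this commit) of a binary that registers the worker service on a ::grpc::ServerBuilder and then pumps its completion queue. It assumes `env` has already been populated with devices, a graph manager and a rendezvous manager (as grpc_server_lib.cc does), and that AsyncServiceInterface exposes HandleRPCsLoop() as the implementation above suggests.

#include <memory>
#include <string>
#include <thread>

#include "external/grpc/include/grpc++/security/server_credentials.h"
#include "external/grpc/include/grpc++/server.h"
#include "external/grpc/include/grpc++/server_builder.h"
#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"

void ServeWorker(tensorflow::WorkerEnv* env, const std::string& host_port) {
  ::grpc::ServerBuilder builder;
  builder.AddListeningPort(host_port, ::grpc::InsecureServerCredentials());
  // The GrpcWorkerService constructor registers the async WorkerService and
  // adds a completion queue, so it must run before BuildAndStart().
  tensorflow::AsyncServiceInterface* worker_service =
      tensorflow::NewGrpcWorkerService(env, &builder);
  std::unique_ptr<::grpc::Server> server = builder.BuildAndStart();
  // HandleRPCsLoop() blocks forever, so a hosting binary typically runs it
  // on a dedicated thread.
  std::thread service_thread(
      [worker_service] { worker_service->HandleRPCsLoop(); });
  service_thread.join();
}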
196
tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc
Normal file
@ -0,0 +1,196 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
|
||||
|
||||
#include <unordered_set>
|
||||
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
#include "tensorflow/core/common_runtime/dma_helper.h"
|
||||
#include "tensorflow/core/distributed_runtime/process_util.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_cache.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_interface.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/strings/numbers.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
class RpcRemoteRendezvous : public BaseRemoteRendezvous {
|
||||
public:
|
||||
RpcRemoteRendezvous(const WorkerEnv* env, int64 step_id)
|
||||
: BaseRemoteRendezvous(env, step_id, false) {}
|
||||
|
||||
protected:
|
||||
void RecvFromRemoteAsync(const string& key,
|
||||
const Rendezvous::ParsedKey& parsed,
|
||||
const Rendezvous::Args& args,
|
||||
DoneCallback done) override;
|
||||
|
||||
private:
|
||||
~RpcRemoteRendezvous() override {}
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(RpcRemoteRendezvous);
|
||||
};
|
||||
|
||||
// Used only to retrieve tensors from remote processes.
|
||||
class RpcRecvTensorCall : public BaseRecvTensorCall {
|
||||
public:
|
||||
RpcRecvTensorCall(WorkerCacheInterface* wc, WorkerInterface* wi,
|
||||
int64 step_id, const string& key,
|
||||
const string& remote_dev, Allocator* allocator,
|
||||
Device* dst_device)
|
||||
: wi_(wi),
|
||||
wc_(wc),
|
||||
remote_dev_(remote_dev),
|
||||
allocator_(allocator),
|
||||
dst_(dst_device) {
|
||||
req_.set_step_id(step_id);
|
||||
req_.set_rendezvous_key(key);
|
||||
}
|
||||
|
||||
~RpcRecvTensorCall() override { delete wi_; }
|
||||
|
||||
void Start(std::function<void()> recv_done) override {
|
||||
StartRTCall(recv_done);
|
||||
}
|
||||
|
||||
void StartAbort(const Status& s) override {
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
status_.Update(s);
|
||||
}
|
||||
opts_.StartCancel();
|
||||
}
|
||||
|
||||
Status status() const override {
|
||||
mutex_lock l(mu_);
|
||||
return status_;
|
||||
}
|
||||
|
||||
const TensorProto& tensor_proto() const { return resp_.tensor(); }
|
||||
|
||||
const RecvTensorResponse& response() const { return resp_; }
|
||||
|
||||
bool is_dead() const { return resp_.is_dead(); }
|
||||
|
||||
private:
|
||||
// Start the main RecvTensor call, checking for an async abort.
|
||||
void StartRTCall(std::function<void()> recv_done) {
|
||||
wi_->RecvTensorAsync(&opts_, &req_, &resp_,
|
||||
nullptr /* TensorBufAllocator */,
|
||||
// done callback
|
||||
[this, recv_done](const Status& s) {
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
status_.Update(s);
|
||||
}
|
||||
recv_done();
|
||||
});
|
||||
}
|
||||
|
||||
WorkerInterface* wi_; // Owned.
|
||||
WorkerCacheInterface* wc_; // Not owned.
|
||||
string remote_dev_;
|
||||
Allocator* allocator_;
|
||||
Device* dst_;
|
||||
CallOptions opts_;
|
||||
RecvTensorRequest req_;
|
||||
RecvTensorResponse resp_;
|
||||
|
||||
mutable mutex mu_;
|
||||
Status status_ GUARDED_BY(mu_);
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(RpcRecvTensorCall);
|
||||
};
|
||||
|
||||
|
||||
void RpcRemoteRendezvous::RecvFromRemoteAsync(
|
||||
const string& key, const Rendezvous::ParsedKey& parsed,
|
||||
const Rendezvous::Args& recv_args, DoneCallback done) {
|
||||
Status s;
|
||||
|
||||
// key.src_device identifies a remote device.
|
||||
string src_worker;
|
||||
string src_rel_device;
|
||||
if (!DeviceNameUtils::SplitDeviceName(parsed.src_device, &src_worker,
|
||||
&src_rel_device)) {
|
||||
s = errors::Internal(parsed.src_device,
|
||||
" is invalid remote source device.");
|
||||
}
|
||||
WorkerCacheInterface* worker_cache = env_->worker_cache;
|
||||
if (s.ok() && worker_cache == nullptr) {
|
||||
s = errors::Internal("No remote worker cache available.");
|
||||
}
|
||||
// Only consult the cache if no error has occurred so far; this avoids
// dereferencing a null worker_cache or looking up a garbage worker name.
WorkerInterface* rwi = s.ok() ? worker_cache->CreateWorker(src_worker) : nullptr;
|
||||
if (s.ok() && rwi == nullptr) {
|
||||
s = errors::Internal("No worker known as ", src_worker);
|
||||
}
|
||||
|
||||
Device* dst_device;
|
||||
if (s.ok()) {
|
||||
s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_device);
|
||||
}
|
||||
if (!s.ok()) {
|
||||
done(s, Args(), recv_args, Tensor{}, false);
|
||||
return;
|
||||
}
|
||||
Allocator* allocator = dst_device->GetAllocator(recv_args.alloc_attrs);
|
||||
|
||||
// Prepare a RecvTensor call that can handle being aborted.
|
||||
RpcRecvTensorCall* call =
|
||||
new RpcRecvTensorCall(worker_cache, rwi, step_id_, key,
|
||||
parsed.src_device, allocator, dst_device);
|
||||
|
||||
// Record "call" in active_ so that it can be aborted cleanly.
|
||||
RegisterCall(call);
|
||||
|
||||
// Start "call".
|
||||
call->Start([this, call, parsed, recv_args, done]() {
|
||||
// Removes "call" from active_. Prevent StartAbort().
|
||||
DeregisterCall(call);
|
||||
// If StartAbort was called prior to DeregisterCall, then the
|
||||
// current status should be bad.
|
||||
Status s = call->status();
|
||||
Tensor val;
|
||||
if (s.ok()) {
|
||||
Device* dst_device;
|
||||
s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_device);
|
||||
if (s.ok()) {
|
||||
s = dst_device->MakeTensorFromProto(call->tensor_proto(),
|
||||
recv_args.alloc_attrs, &val);
|
||||
}
|
||||
}
|
||||
done(s, Args(), recv_args, val, call->is_dead());
|
||||
delete call;
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
BaseRemoteRendezvous* RpcRendezvousMgr::Create(int64 step_id,
|
||||
const WorkerEnv* worker_env) {
|
||||
return new RpcRemoteRendezvous(worker_env, step_id);
|
||||
}
|
||||
|
||||
|
||||
} // end namespace tensorflow
|
57
tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h
Normal file
@ -0,0 +1,57 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_RPC_RENDEZVOUS_MGR_H_
|
||||
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_RPC_RENDEZVOUS_MGR_H_
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_env.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// RendezvousMgr keeps track of a set of local rendezvous instances.
|
||||
// All tensors sent by this worker are buffered in a RendezvousMgr
|
||||
// until the tensor is received. Each global unique "step_id"
|
||||
// corresponds to one local rendezvous instance managed by a
|
||||
// RendezvousMgr.
|
||||
//
|
||||
// E.g.,
|
||||
// Rendezvous* rendez = worker_env->rendezvous_mgr->Find(0x8935);
|
||||
// fork execution of a graph executor using "rendez" on thread 1;
|
||||
// fork execution of another graph executor using "rendez" on thread 2;
|
||||
// ...
|
||||
// join threads 1 and 2;
|
||||
//
|
||||
// In the example above, execution in thread 1 and 2 communicates with
|
||||
// each other by send/recv operations through "rendez".
|
||||
//
|
||||
// Tensors sent and recved through rendezvous managed by this
|
||||
// RendezvousMgr must have keys generated by Rendezvous::CreateKey.
|
||||
class RpcRendezvousMgr : public BaseRendezvousMgr {
|
||||
public:
|
||||
explicit RpcRendezvousMgr(const WorkerEnv* env) : BaseRendezvousMgr(env) {}
|
||||
|
||||
protected:
|
||||
BaseRemoteRendezvous* Create(int64 step_id,
|
||||
const WorkerEnv* worker_env) override;
|
||||
|
||||
private:
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(RpcRendezvousMgr);
|
||||
};
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_RPC_RENDEZVOUS_MGR_H_
|
@ -0,0 +1,172 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/process_util.h"
|
||||
#include "tensorflow/core/framework/control_flow.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/notification.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// string -> Tensor<string>
|
||||
Tensor V(const string& content) {
|
||||
Tensor tensor(DT_STRING, TensorShape({}));
|
||||
tensor.scalar<string>()() = content;
|
||||
return tensor;
|
||||
}
|
||||
|
||||
// Tensor<string> -> string
|
||||
string V(const Tensor& tensor) {
|
||||
CHECK_EQ(tensor.dtype(), DT_STRING);
|
||||
CHECK(TensorShapeUtils::IsScalar(tensor.shape()));
|
||||
return tensor.scalar<string>()();
|
||||
}
|
||||
|
||||
TEST(RpcRendezvousMgrTest, LocalSendRecv) {
|
||||
WorkerEnv env;
|
||||
env.env = Env::Default();
|
||||
env.worker_name = "/job:mnist/replica:1/task:2";
|
||||
RpcRendezvousMgr rmgr(&env);
|
||||
const int64 step_id = 123;
|
||||
const string key = Rendezvous::CreateKey(
|
||||
"/job:mnist/replica:1/task:2/cpu:0", 7890,
|
||||
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0));
|
||||
{
|
||||
Rendezvous* rendez = rmgr.Find(step_id);
|
||||
core::ScopedUnref unref(rendez);
|
||||
Rendezvous::Args args;
|
||||
TF_ASSERT_OK(rendez->Send(key, args, V("peach"), false));
|
||||
}
|
||||
{
|
||||
Tensor val(DT_FLOAT);
|
||||
bool val_dead = false;
|
||||
TF_ASSERT_OK(rmgr.RecvLocal(step_id, key, &val, &val_dead));
|
||||
EXPECT_EQ(V(val), "peach");
|
||||
}
|
||||
rmgr.Cleanup(step_id);
|
||||
}
|
||||
|
||||
TEST(RpcRendezvousMgrTest, LocalAbort) {
|
||||
WorkerEnv env;
|
||||
env.env = Env::Default();
|
||||
env.worker_name = "/job:mnist/replica:1/task:2";
|
||||
RpcRendezvousMgr rmgr(&env);
|
||||
const string key = Rendezvous::CreateKey(
|
||||
"/job:mnist/replica:1/task:2/cpu:0", 7890,
|
||||
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0));
|
||||
{ // Explicit Abort().
|
||||
const int64 step_id = 123;
|
||||
Rendezvous* rendez = rmgr.Find(step_id);
|
||||
core::ScopedUnref unref(rendez);
|
||||
SchedClosure([env, rendez]() {
|
||||
env.env->SleepForMicroseconds(100 * 1000);
|
||||
rendez->StartAbort(errors::Aborted(""));
|
||||
});
|
||||
Tensor val(DT_STRING);
|
||||
bool val_dead = false;
|
||||
Rendezvous::Args args;
|
||||
EXPECT_TRUE(errors::IsAborted(rendez->Recv(key, args, &val, &val_dead)));
|
||||
}
|
||||
{ // Cleanup causes Abort().
|
||||
const int64 step_id = 321;
|
||||
Rendezvous* rendez = rmgr.Find(step_id);
|
||||
core::ScopedUnref unref(rendez);
|
||||
SchedClosure([env, &rmgr, step_id]() {
|
||||
env.env->SleepForMicroseconds(100 * 1000);
|
||||
rmgr.Cleanup(step_id);
|
||||
});
|
||||
Tensor val(DT_STRING);
|
||||
bool val_dead = false;
|
||||
Rendezvous::Args args;
|
||||
EXPECT_TRUE(errors::IsAborted(rendez->Recv(key, args, &val, &val_dead)));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(RpcRendezvousMgrTest, CleanupAll) {
|
||||
WorkerEnv env;
|
||||
env.env = Env::Default();
|
||||
env.worker_name = "/job:mnist/replica:1/task:2";
|
||||
RpcRendezvousMgr rmgr(&env);
|
||||
const string key = Rendezvous::CreateKey(
|
||||
"/job:mnist/replica:1/task:2/cpu:0", 7890,
|
||||
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0));
|
||||
{
|
||||
const int64 step_id = 123;
|
||||
Rendezvous* rendez = rmgr.Find(step_id);
|
||||
core::ScopedUnref unref(rendez);
|
||||
Rendezvous::Args args;
|
||||
TF_ASSERT_OK(rendez->Send(key, args, V("peach"), false));
|
||||
rmgr.CleanupAll();
|
||||
Tensor val(DT_STRING);
|
||||
bool val_dead = false;
|
||||
EXPECT_TRUE(errors::IsAborted(rendez->Recv(key, args, &val, &val_dead)));
|
||||
}
|
||||
}
|
||||
|
||||
class DummyDeviceContext : public DeviceContext {
|
||||
public:
|
||||
explicit DummyDeviceContext(int stream_id) : stream_id_(stream_id) {}
|
||||
~DummyDeviceContext() override {}
|
||||
int stream_id() const { return stream_id_; }
|
||||
|
||||
private:
|
||||
const int stream_id_;
|
||||
};
|
||||
|
||||
TEST(RpcRendezvousMgrTest, TransferDummyDeviceContext) {
|
||||
DummyDeviceContext* dc = new DummyDeviceContext(123);
|
||||
|
||||
WorkerEnv env;
|
||||
env.env = Env::Default();
|
||||
env.worker_name = "/job:mnist/replica:1/task:2";
|
||||
RpcRendezvousMgr rmgr(&env);
|
||||
const int64 step_id = 123;
|
||||
const string key = Rendezvous::CreateKey(
|
||||
"/job:mnist/replica:1/task:2/cpu:0", 7890,
|
||||
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0));
|
||||
{
|
||||
Rendezvous* rendez = rmgr.Find(step_id);
|
||||
core::ScopedUnref unref(rendez);
|
||||
Rendezvous::Args args;
|
||||
args.device_context = dc;
|
||||
TF_ASSERT_OK(rendez->Send(key, args, V("peach"), false));
|
||||
}
|
||||
{
|
||||
Notification n;
|
||||
rmgr.RecvLocalAsync(
|
||||
step_id, key, [&n](const Status& s, const Rendezvous::Args send_args,
|
||||
const Rendezvous::Args recv_args, const Tensor& val,
|
||||
bool is_dead) {
|
||||
auto send_dev_context =
|
||||
static_cast<DummyDeviceContext*>(send_args.device_context);
|
||||
CHECK_EQ(123, send_dev_context->stream_id());
|
||||
CHECK_EQ(V(val), "peach");
|
||||
n.Notify();
|
||||
});
|
||||
n.WaitForNotification();
|
||||
}
|
||||
rmgr.Cleanup(step_id);
|
||||
dc->Unref();
|
||||
}
|
||||
|
||||
// NOTE: Remote Send/Recv is better tested in worker_test.cc
|
||||
|
||||
} // namespace tensorflow
|
@ -0,0 +1,309 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/simple_graph_execution_state.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/simple_placer.h"
|
||||
#include "tensorflow/core/distributed_runtime/process_util.h"
|
||||
#include "tensorflow/core/framework/graph_def_util.h"
|
||||
#include "tensorflow/core/graph/costmodel.h"
|
||||
#include "tensorflow/core/graph/dot.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/graph/subgraph.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/strings/stringprintf.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/protobuf/worker.pb.h"
|
||||
#include "tensorflow/core/util/util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
string BuildGraphOptions::DebugString() const {
|
||||
string rv = "Feed endpoints: ";
|
||||
for (auto& s : feed_endpoints) {
|
||||
strings::StrAppend(&rv, s, ", ");
|
||||
}
|
||||
strings::StrAppend(&rv, "\nFetch endpoints: ");
|
||||
for (auto& s : fetch_endpoints) {
|
||||
strings::StrAppend(&rv, s, ", ");
|
||||
}
|
||||
strings::StrAppend(&rv, "\nTarget nodes: ");
|
||||
for (auto& s : target_nodes) {
|
||||
strings::StrAppend(&rv, s, ", ");
|
||||
}
|
||||
return rv;
|
||||
}
|
||||
|
||||
SimpleGraphExecutionState::SimpleGraphExecutionState(
|
||||
const OpRegistryInterface* ops,
|
||||
const SimpleGraphExecutionStateOptions& options)
|
||||
: ops_(ops),
|
||||
device_set_(options.device_set),
|
||||
session_options_(options.session_options),
|
||||
base_(nullptr),
|
||||
placed_(nullptr) {
|
||||
// TODO(mrry): Publish placement visualizations or handle the log
|
||||
// placement option.
|
||||
}
|
||||
|
||||
SimpleGraphExecutionState::~SimpleGraphExecutionState() {
|
||||
mutex_lock l(mu_);
|
||||
delete base_;
|
||||
delete placed_;
|
||||
}
|
||||
|
||||
Status SimpleGraphExecutionState::Create(GraphDef* graph_def) {
|
||||
if (original_graph_def_.node_size() > 0) {
|
||||
return errors::InvalidArgument(
|
||||
"Cannot call Create on SimpleGraphExecutionState twice");
|
||||
}
|
||||
|
||||
original_graph_def_.Swap(graph_def);
|
||||
VLOG(2) << "Incoming def: " << original_graph_def_.DebugString();
|
||||
return AddDefaultAttrsToGraphDef(&original_graph_def_, *ops_, 0);
|
||||
}
|
||||
|
||||
Status SimpleGraphExecutionState::Extend(
|
||||
const GraphDef& extension_def, SimpleGraphExecutionState** out) const {
|
||||
std::unordered_set<string> new_names;
|
||||
// 1. Build an index of the new node names.
|
||||
for (const NodeDef& node : extension_def.node()) {
|
||||
new_names.insert(node.name());
|
||||
}
|
||||
|
||||
// 2. Add the non-duplicates from the old graph to the new graph.
|
||||
// Return an error if the same node name appears in both the
|
||||
// old graph and the extension.
|
||||
GraphDef gdef;
|
||||
for (const NodeDef& node : original_graph_def_.node()) {
|
||||
if (new_names.count(node.name()) == 0) {
|
||||
*gdef.add_node() = node;
|
||||
} else {
|
||||
return errors::InvalidArgument(tensorflow::strings::Printf(
|
||||
"GraphDef argument to Extend includes node '%s', which was created "
|
||||
"by a previous call to Create or Extend in this session.",
|
||||
node.name().c_str()));
|
||||
}
|
||||
}
|
||||
|
||||
int old_node_size = gdef.node_size();
|
||||
gdef.mutable_node()->MergeFrom(extension_def.node());
|
||||
TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(&gdef, *ops_, old_node_size));
|
||||
|
||||
// 3. Add the extension.
|
||||
SimpleGraphExecutionStateOptions combined_options;
|
||||
combined_options.device_set = device_set_;
|
||||
|
||||
SimpleGraphExecutionState* new_execution_state =
|
||||
new SimpleGraphExecutionState(ops_, combined_options);
|
||||
Status new_execution_state_status = new_execution_state->Create(&gdef);
|
||||
if (!new_execution_state_status.ok()) {
|
||||
delete new_execution_state;
|
||||
return new_execution_state_status;
|
||||
}
|
||||
*out = new_execution_state;
|
||||
|
||||
// Ensure that any state created in the precursor is accessible in the
|
||||
// new graph.
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
for (const auto& placement : stateful_placements_) {
|
||||
(*out)->stateful_placements_.insert(placement);
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(mrry): This is likely to be used for non-throughput-sensitive
|
||||
// interactive workloads, but in future we may want to transfer other
|
||||
// parts of the placement and/or cost model.
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status SimpleGraphExecutionState::InitBaseGraph() {
|
||||
std::unique_ptr<Graph> new_base(new Graph(ops_));
|
||||
GraphConstructorOptions opts;
|
||||
TF_RETURN_IF_ERROR(
|
||||
ConvertGraphDefToGraph(opts, original_graph_def_, new_base.get()));
|
||||
for (const Node* n : new_base->nodes()) {
|
||||
VLOG(2) << "Mapping " << n->name() << " to " << n->cost_id();
|
||||
node_name_to_cost_id_map_[n->name()] = n->cost_id();
|
||||
}
|
||||
|
||||
Status status = PreliminaryPlace(*new_base);
|
||||
if (!status.ok()) {
|
||||
node_name_to_cost_id_map_.clear();
|
||||
return status;
|
||||
}
|
||||
base_ = new_base.release();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status SimpleGraphExecutionState::GlobalNodeDefByName(const string& name,
|
||||
NodeDef* out) {
|
||||
NodeNameToCostIdMap::const_iterator iter =
|
||||
node_name_to_cost_id_map_.find(name);
|
||||
if (iter != node_name_to_cost_id_map_.end()) {
|
||||
mutex_lock l(mu_); // could use reader lock
|
||||
const Node* node = placed_->FindNodeId(iter->second);
|
||||
if (node) {
|
||||
*out = node->def();
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
return errors::NotFound("Node name: ", name);
|
||||
}
|
||||
|
||||
Status SimpleGraphExecutionState::PreliminaryPlace(const Graph& base) {
|
||||
VLOG(1) << "PreliminaryPlace";
|
||||
Graph* ng = new Graph(ops_);
|
||||
|
||||
CopyGraph(base, ng);
|
||||
Status status = DoPlace(ng);
|
||||
if (!status.ok()) {
|
||||
delete ng;
|
||||
} else {
|
||||
delete placed_;
|
||||
placed_ = ng;
|
||||
FreezeStatefulNodes(true /*is_prelim*/);
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
void SimpleGraphExecutionState::FreezeStatefulNodes(bool is_prelim) {
|
||||
if (is_prelim) {
|
||||
// During the preliminary placement every stateful Node got placed
|
||||
// somewhere, and we need to remember where, so it doesn't move.
|
||||
for (Node* n : placed_->nodes()) {
|
||||
if (n->op_def().is_stateful()) {
|
||||
stateful_placements_[n->name()] = n->assigned_device_name();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// During later placements it's possible for new stateful nodes to
|
||||
// appear. They are noticed while we're pinning the pre-existing
|
||||
// stateful nodes to their prior positions, and after they've been
|
||||
// placed this function is entered to record their placements.
|
||||
for (Node* n : missing_stateful_placements_) {
|
||||
CHECK(n->op_def().is_stateful());
|
||||
stateful_placements_[n->name()] = n->assigned_device_name();
|
||||
}
|
||||
missing_stateful_placements_.clear();
|
||||
}
|
||||
}
|
||||
|
||||
void SimpleGraphExecutionState::PlaceStatefulNodes(Graph* graph) {
|
||||
for (Node* n : graph->nodes()) {
|
||||
if (n->op_def().is_stateful()) {
|
||||
PlaceMap::const_iterator iter = stateful_placements_.find(n->name());
|
||||
if (iter == stateful_placements_.end()) {
|
||||
// NOTE(tucker): I don't understand why this can occur. So far,
|
||||
// I've only seen it in eval instances, started from a checkpoint.
|
||||
missing_stateful_placements_.push_back(n);
|
||||
} else {
|
||||
n->set_assigned_device_name(iter->second);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Status SimpleGraphExecutionState::DoPlace(Graph* graph) {
|
||||
Status status;
|
||||
// TODO(mrry): Port other placement algorithms from whitepaper.
|
||||
return SimplePlacement(graph);
|
||||
}
|
||||
|
||||
Status SimpleGraphExecutionState::BuildGraph(const BuildGraphOptions& options,
|
||||
ClientGraph** out) {
|
||||
VLOG(1) << "BuildGraph";
|
||||
mutex_lock l(mu_);
|
||||
// Lazily initialize the base graph.
|
||||
if (base_ == nullptr) {
|
||||
TF_RETURN_IF_ERROR(InitBaseGraph());
|
||||
}
|
||||
|
||||
if (!base_ || !placed_) {
|
||||
return ::tensorflow::errors::Internal(
|
||||
"There was a problem building the graph.");
|
||||
}
|
||||
|
||||
std::unique_ptr<ClientGraph> cgraph(new ClientGraph(ops_));
|
||||
CopyGraph(*placed_, &cgraph->graph);
|
||||
|
||||
// Extract the subset of the graph that needs to be run, adding feed/fetch
|
||||
// ops as needed.
|
||||
TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
|
||||
&cgraph->graph, options.feed_endpoints, options.fetch_endpoints,
|
||||
options.target_nodes, device_set_->client_device()->attributes()));
|
||||
|
||||
// Copy the extracted graph in order to make its node ids dense,
|
||||
// since the local CostModel used to record its stats is sized by
|
||||
// the largest node id.
|
||||
{
|
||||
std::unique_ptr<ClientGraph> dense_copy(new ClientGraph(ops_));
|
||||
CopyGraph(cgraph->graph, &dense_copy->graph);
|
||||
cgraph = std::move(dense_copy);
|
||||
}
|
||||
|
||||
// TODO(vrv): We should check invariants of the graph here.
|
||||
|
||||
*out = cgraph.release();
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status SimpleGraphExecutionState::DeviceIsCompatible(
|
||||
Node* n, const Device* device) const {
|
||||
if (!n->def().device().empty()) {
|
||||
DeviceNameUtils::ParsedName pname;
|
||||
if (!DeviceNameUtils::ParseFullName(n->def().device(), &pname)) {
|
||||
return AttachDef(
|
||||
errors::InvalidArgument("Malformed device specification '",
|
||||
n->def().device(), "'"),
|
||||
n->def());
|
||||
}
|
||||
std::vector<Device*> devices;
|
||||
device_set_->FindMatchingDevices(pname, &devices);
|
||||
for (auto d : devices) {
|
||||
if (d == device) {
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
||||
return AttachDef(
|
||||
errors::InvalidArgument(
|
||||
"Specified device '", n->def().device(),
|
||||
"' not compatible with device of ref connection: ", device->name()),
|
||||
n->def());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status SimpleGraphExecutionState::SimplePlacement(Graph* graph) {
|
||||
SimplePlacer placer(graph, device_set_, &node_name_to_cost_id_map_,
|
||||
session_options_);
|
||||
// TODO(mrry): Consider making the SimplePlacer cancelable.
|
||||
return placer.Run();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
@ -0,0 +1,156 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SIMPLE_GRAPH_EXECUTION_STATE_H_
|
||||
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SIMPLE_GRAPH_EXECUTION_STATE_H_
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/device_set.h"
|
||||
#include "tensorflow/core/distributed_runtime/build_graph_options.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/step_stats.pb.h"
|
||||
#include "tensorflow/core/graph/costmodel.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
class SessionOptions;
|
||||
class StepStats;
|
||||
class Timeline;
|
||||
|
||||
struct SimpleGraphExecutionStateOptions {
|
||||
const DeviceSet* device_set = nullptr;
|
||||
const SessionOptions* session_options = nullptr;
|
||||
};
|
||||
|
||||
// A ClientGraph is simply a sub-graph of the full graph as induced by
|
||||
// BuildGraphOptions.
|
||||
struct ClientGraph {
|
||||
Graph graph;
|
||||
explicit ClientGraph(const OpRegistryInterface* ops) : graph(ops) {}
|
||||
int32 placement_version;
|
||||
};
|
||||
|
||||
// SimpleGraphExecutionState is responsible for generating an
|
||||
// executable ClientGraph from the original GraphDef that specifies
|
||||
// the complete graph and from BuildGraphOptions which specifies
|
||||
// input/output nodes.
|
||||
//
|
||||
// An executable Graph differs from a GraphDef by being Placed,
|
||||
// meaning that each Node is assigned to a single Device in the
|
||||
// available set.
|
||||
//
|
||||
// When SimpleGraphExecutionState is first constructed it instantiates
|
||||
// a full Graph from the provided GraphDef, and places it, using only
|
||||
// the static device assignments from the GraphDef. Nodes without a
// static assignment are currently placed in a very naive way. Since
// stateful Nodes cannot
|
||||
// be moved after initial placement, it is important that stateful
|
||||
// Nodes get sensible initial device assignments in the graph
|
||||
// definition.
|
||||
//
|
||||
// Subsequently, SimpleGraphExecutionState generates a ClientGraph on
|
||||
// demand, which is a sub-graph of the latest placement of the full
|
||||
// Graph. MasterSession uses such a ClientGraph to execute one or
|
||||
// more similar client requests.
|
||||
//
|
||||
// SimpleGraphExecutionState is thread-safe.
|
||||
|
||||
class SimpleGraphExecutionState {
|
||||
public:
|
||||
SimpleGraphExecutionState(const OpRegistryInterface* ops,
|
||||
const SimpleGraphExecutionStateOptions& options);
|
||||
|
||||
virtual ~SimpleGraphExecutionState();
|
||||
|
||||
// Initializes the SimpleGraphExecutionState with 'graph_def'. Can only be
|
||||
// called once on an original SimpleGraphExecutionState. Callee may modify
|
||||
// 'graph_def'.
|
||||
Status Create(GraphDef* graph_def);
|
||||
|
||||
// Creates a new SimpleGraphExecutionState representing the
|
||||
// concatenation of this graph, and the graph defined by
|
||||
// "extension_def". The same name may not be used to define a node
|
||||
// in both this graph and "extension_def".
|
||||
//
|
||||
// If successful, returns OK and the caller takes ownership of "*out".
|
||||
// Otherwise returns an error and does not modify "*out".
|
||||
//
|
||||
// NOTE(mrry): This method respects the placement of stateful nodes in
|
||||
// in *this, but currently does not transfer any other placement
|
||||
// or cost model information to the new graph.
|
||||
Status Extend(const GraphDef& extension_def,
|
||||
SimpleGraphExecutionState** out) const;
|
||||
|
||||
// Builds a ClientGraph (a sub-graph of the full graph as induced by
|
||||
// the Node set specified in "options"). If successful, returns OK
|
||||
// and the caller takes the ownership of "*out". Otherwise, returns
|
||||
// an error.
|
||||
Status BuildGraph(const BuildGraphOptions& options, ClientGraph** out);
|
||||
|
||||
// Returns OK if the named node is found in the placed full graph owned
|
||||
// by this execution_state, and sets *out to the NodeDef for that node.
|
||||
// It may not exist if name is of a Node added for a particular subgraph
|
||||
// execution, e.g. a send, recv or feed node.
|
||||
Status GlobalNodeDefByName(const string& name, NodeDef* out);
|
||||
|
||||
private:
|
||||
mutable mutex mu_;
|
||||
|
||||
Status InitBaseGraph() EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||
Status PreliminaryPlace(const Graph& graph) EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||
void FreezeStatefulNodes(bool is_prelim) EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||
void PlaceStatefulNodes(Graph* graph) EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||
Status DoPlace(Graph* graph);
|
||||
Status SimplePlacement(Graph* graph);
|
||||
// Return an OK status if "n" can be assigned to "device".
|
||||
Status DeviceIsCompatible(Node* n, const Device* device) const;
|
||||
|
||||
const OpRegistryInterface* const ops_; // Not owned
|
||||
GraphDef original_graph_def_; // Immutable after ctor.
|
||||
const DeviceSet* device_set_; // Not owned
|
||||
const SessionOptions* session_options_; // Not owned
|
||||
|
||||
// Original graph before we make any placement decisions.
|
||||
Graph* base_ GUARDED_BY(mu_);
|
||||
|
||||
// Full graph, placed on the complete set of devices, as a whole.
|
||||
Graph* placed_ GUARDED_BY(mu_);
|
||||
|
||||
// Map of placed stateful nodes, i.e. nodes for which is_stateful()
|
||||
// is true, such as "params" and "queue" nodes. Once placed these
|
||||
// nodes cannot be moved to a different device. Maps node names to
|
||||
// device names.
|
||||
typedef std::unordered_map<string, string> PlaceMap;
|
||||
PlaceMap stateful_placements_ GUARDED_BY(mu_);
|
||||
std::vector<Node*> missing_stateful_placements_ GUARDED_BY(mu_);
|
||||
|
||||
// Map from name to Node for the full graph in placed_.
|
||||
NodeNameToCostIdMap node_name_to_cost_id_map_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(SimpleGraphExecutionState);
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SIMPLE_GRAPH_EXECUTION_STATE_H_
|
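As an illustration of the Create()/BuildGraph() contract described above, the following sketch (not part of this commit) builds a ClientGraph for a single fetch. The device set, session options and incoming GraphDef are assumed to be supplied by the caller (e.g. by a master session), and the fetch name "out:0" is hypothetical.

#include "tensorflow/core/distributed_runtime/build_graph_options.h"
#include "tensorflow/core/distributed_runtime/simple_graph_execution_state.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/lib/core/errors.h"

tensorflow::Status BuildClientGraph(
    const tensorflow::DeviceSet* device_set,
    const tensorflow::SessionOptions* session_options,
    tensorflow::GraphDef* graph_def, tensorflow::ClientGraph** out) {
  tensorflow::SimpleGraphExecutionStateOptions state_options;
  state_options.device_set = device_set;
  state_options.session_options = session_options;
  tensorflow::SimpleGraphExecutionState execution_state(
      tensorflow::OpRegistry::Global(), state_options);
  // Create() consumes *graph_def and may only be called once per state.
  TF_RETURN_IF_ERROR(execution_state.Create(graph_def));

  tensorflow::BuildGraphOptions build_options;
  build_options.fetch_endpoints.push_back("out:0");  // hypothetical fetch
  // BuildGraph() lazily places the full graph and extracts the subgraph
  // induced by the requested feeds/fetches; on success the caller owns *out.
  return execution_state.BuildGraph(build_options, out);
}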
75
tensorflow/core/distributed_runtime/worker_cache.h
Normal file
@ -0,0 +1,75 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_CACHE_H_
|
||||
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_CACHE_H_
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/worker_interface.h" // for CallOptions
|
||||
#include "tensorflow/core/framework/device_attributes.pb.h" // for BusAdjacency
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
typedef std::function<void(const Status&)> StatusCallback;
|
||||
|
||||
class ChannelCache;
|
||||
class StepStats;
|
||||
class WorkerInterface;
|
||||
|
||||
class WorkerCacheInterface {
|
||||
public:
|
||||
virtual ~WorkerCacheInterface() {}
|
||||
|
||||
// Updates *workers with strings naming the remote worker tasks to
|
||||
// which open channels have been established.
|
||||
virtual void ListWorkers(std::vector<string>* workers) = 0;
|
||||
|
||||
// If "target" names a remote task for which an RPC channel exists
|
||||
// or can be constructed, returns a new WorkerInterface object
|
||||
// wrapping that channel. Ownership passes to the caller.
|
||||
// TODO(tucker): rename this to CreateWorker() or something that
|
||||
// makes it more obvious this is a constructor that transfers
|
||||
// ownership, not a cache lookup.
|
||||
virtual WorkerInterface* CreateWorker(const string& target) = 0;
|
||||
|
||||
// Set *ba with the BusAdjacency of the specified remote device
|
||||
// within its local environment. Returns true if the device bus
|
||||
// affinity was set, using only locally cached data. Returns false
|
||||
// if status data for that device was not available. Never blocks.
|
||||
// TODO(mrry,tucker): Maybe remove.
|
||||
virtual bool GetDeviceBusNonBlocking(const string& device,
|
||||
BusAdjacency* ba) = 0;
|
||||
|
||||
// Set *ba with the BusAdjacency of the specified remote device
|
||||
// within its local environment. Callback gets Status::OK if the
|
||||
// device bus affinity was set.
|
||||
// TODO(mrry,tucker): Maybe remove.
|
||||
virtual void GetDeviceBusAsync(const string& device, BusAdjacency* ba,
|
||||
StatusCallback done) = 0;
|
||||
|
||||
// Start/stop logging activity.
|
||||
virtual void SetLogging(bool active) {}
|
||||
|
||||
// Discard any saved log data.
|
||||
virtual void ClearLogs() {}
|
||||
|
||||
// Return logs for the identified step in *ss. Any returned data will no
|
||||
// longer be stored.
|
||||
virtual bool RetrieveLogs(int64 step_id, StepStats* ss) { return false; }
|
||||
};
|
||||
} // namespace tensorflow
|
||||
#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_CACHE_H_
|
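The ownership comment on CreateWorker() above is easy to get wrong, so here is a small illustrative sketch (not part of this commit): each returned WorkerInterface is a fresh object that the caller must delete.

#include <vector>

#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
#include "tensorflow/core/platform/types.h"

void VisitAllWorkers(tensorflow::WorkerCacheInterface* worker_cache) {
  std::vector<tensorflow::string> workers;
  worker_cache->ListWorkers(&workers);
  for (const tensorflow::string& target : workers) {
    // CreateWorker() transfers ownership of the returned object.
    tensorflow::WorkerInterface* worker = worker_cache->CreateWorker(target);
    if (worker == nullptr) continue;
    // ... issue asynchronous RPCs through `worker` here ...
    delete worker;
  }
}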
110
tensorflow/core/distributed_runtime/worker_cache_logger.cc
Normal file
@ -0,0 +1,110 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/worker_cache_logger.h"
|
||||
|
||||
#include "tensorflow/core/common_runtime/step_stats_collector.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/lib/strings/stringprintf.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
// Maximum number of step_ids for which RPC logs can be maintained.
|
||||
// TODO(mrry): Make this configurable if necessary.
|
||||
const int32 kWorkerCacheLoggerLimit = 1 << 10;
|
||||
} // namespace
|
||||
|
||||
void WorkerCacheLogger::SetLogging(bool v) {
|
||||
mutex_lock l(count_mu_);
|
||||
if (v) {
|
||||
++want_logging_count_;
|
||||
} else {
|
||||
--want_logging_count_;
|
||||
// If RPCs get cancelled, it may be possible for the count
|
||||
// to go negative. This should not be a fatal error, since
|
||||
// logging is non-critical.
|
||||
if (want_logging_count_ < 0) want_logging_count_ = 0;
|
||||
}
|
||||
}
|
||||
|
||||
void WorkerCacheLogger::ClearLogs() {
|
||||
mutex_lock l(mu_);
|
||||
ClearLogsWithLock();
|
||||
}
|
||||
|
||||
void WorkerCacheLogger::ClearLogsWithLock() {
|
||||
for (auto& iter : log_map_) {
|
||||
delete iter.second.collector;
|
||||
}
|
||||
log_map_.clear();
|
||||
}
|
||||
|
||||
bool WorkerCacheLogger::RetrieveLogs(int64 step_id, StepStats* ss) {
|
||||
mutex_lock l(mu_);
|
||||
LogMap::iterator iter = log_map_.find(step_id);
|
||||
if (iter != log_map_.end()) {
|
||||
iter->second.collector->Swap(ss);
|
||||
delete iter->second.collector;
|
||||
log_map_.erase(iter);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void WorkerCacheLogger::Save(const string& device, int64 step_id,
|
||||
NodeExecStats* ns) {
|
||||
mutex_lock l(mu_);
|
||||
StepLog* sl = &log_map_[step_id];
|
||||
if (!sl->collector) {
|
||||
sl->collector = new StepStatsCollector(&sl->step_stats);
|
||||
}
|
||||
sl->collector->Save(device, ns);
|
||||
if (log_map_.size() > kWorkerCacheLoggerLimit) {
|
||||
// Something's gone wrong. Just empty the cache.
|
||||
ClearLogsWithLock();
|
||||
}
|
||||
}
|
||||
|
||||
void WorkerCacheLogger::RecordRecvTensor(int64 step_id, int64 start_usecs,
|
||||
int64 end_usecs,
|
||||
const string& tensor_name,
|
||||
const string& src_device,
|
||||
const string& dst_device,
|
||||
int64 bytes) {
|
||||
NodeExecStats* ns = new NodeExecStats;
|
||||
ns->set_node_name("RecvTensor");
|
||||
string byte_string = strings::StrCat("[", bytes, "B] ");
|
||||
if (bytes >= 0.1 * 1048576.0) {
|
||||
byte_string = strings::Printf("[%.1fMB] ", bytes / 1048576.0);
|
||||
}
|
||||
ns->set_timeline_label(strings::StrCat(byte_string, tensor_name, " from ",
|
||||
src_device, " to ", dst_device));
|
||||
ns->set_all_start_micros(start_usecs);
|
||||
ns->set_op_start_rel_micros(0);
|
||||
ns->set_op_end_rel_micros(end_usecs - start_usecs);
|
||||
NodeOutput* no = ns->add_output();
|
||||
no->set_slot(0);
|
||||
// TODO(tucker): Maybe set the dimensions too, but then they'll
|
||||
// need to be passed in.
|
||||
no->mutable_tensor_description()
|
||||
->mutable_allocation_description()
|
||||
->set_requested_bytes(bytes);
|
||||
Save(dst_device, step_id, ns);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
tensorflow/core/distributed_runtime/worker_cache_logger.h (new file, 81 lines)
@@ -0,0 +1,81 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_CACHE_LOGGER_H_
|
||||
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_CACHE_LOGGER_H_
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "tensorflow/core/framework/step_stats.pb.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
class StepStatsCollector;
|
||||
|
||||
// WorkerCacheLogger is a thread-safe utility for use by a WorkerCache
|
||||
// to optionally log some selected RPC activity. A single instance
|
||||
// should be owned by a WorkerCache, for use by its RemoteWorker
|
||||
// instances.
|
||||
|
||||
class WorkerCacheLogger {
|
||||
public:
|
||||
// Start/Stop logging activity. This function increments/decrements
|
||||
// a counter so that if two separate steps turn logging on/off,
|
||||
// logging should be on for the union of the durations of both,
|
||||
// regardless of relative timing.
|
||||
void SetLogging(bool v);
|
||||
|
||||
// Discard any saved log data.
|
||||
void ClearLogs();
|
||||
|
||||
// Return logs for the identified step in *ss. Any returned data will no
|
||||
// longer be stored. Returns true iff *ss was modified.
|
||||
bool RetrieveLogs(int64 step_id, StepStats* ss);
|
||||
|
||||
// Return true if there is any outstanding request for logging on
|
||||
// the RPC channels.
|
||||
bool LoggingActive() {
|
||||
mutex_lock l(count_mu_);
|
||||
return want_logging_count_ > 0;
|
||||
}
|
||||
|
||||
// Generates a NodeExecStats record with the given data, and saves for
|
||||
// later retrieval by RetrieveLogs().
|
||||
void RecordRecvTensor(int64 step_id, int64 start_usecs, int64 end_usecs,
|
||||
const string& tensor_name, const string& src_device,
|
||||
const string& dst_device, int64 bytes);
|
||||
|
||||
private:
|
||||
mutex count_mu_;
|
||||
int32 want_logging_count_ GUARDED_BY(count_mu_);
|
||||
|
||||
struct StepLog {
|
||||
StepStats step_stats;
|
||||
StepStatsCollector* collector;
|
||||
};
|
||||
typedef std::unordered_map<int64, StepLog> LogMap;
|
||||
mutex mu_;
|
||||
LogMap log_map_ GUARDED_BY(mu_);
|
||||
|
||||
// Records "ns" in log_map_ under the given device and step.
|
||||
void Save(const string& device, int64 step_id, NodeExecStats* ns);
|
||||
|
||||
void ClearLogsWithLock() EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||
};
|
||||
} // namespace tensorflow
|
||||
#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_CACHE_LOGGER_H_
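For illustration only (not part of this commit): a minimal sketch of how a caller that owns a WorkerCacheLogger might wrap a RecvTensor RPC. The device names, tensor name, byte count, and the use of Env::Default()->NowMicros() for timing are invented for the example.

// Hypothetical usage sketch for WorkerCacheLogger.
#include "tensorflow/core/distributed_runtime/worker_cache_logger.h"
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/platform/env.h"

namespace tensorflow {

void ExampleLoggedRecv(WorkerCacheLogger* logger, int64 step_id) {
  logger->SetLogging(true);  // increments the want-logging counter
  const int64 start = Env::Default()->NowMicros();
  // ... issue the RecvTensor RPC for this step here ...
  const int64 end = Env::Default()->NowMicros();
  if (logger->LoggingActive()) {
    logger->RecordRecvTensor(step_id, start, end,
                             "edge_1_tensor",                       // tensor_name (made up)
                             "/job:worker/replica:0/task:0/cpu:0",  // src_device (made up)
                             "/job:worker/replica:0/task:1/cpu:0",  // dst_device (made up)
                             1 << 20 /* bytes */);
  }
  StepStats ss;
  if (logger->RetrieveLogs(step_id, &ss)) {
    // ss now holds the NodeExecStats recorded for this step; a second
    // retrieval for the same step_id would return false.
  }
  logger->SetLogging(false);  // matching decrement
}

}  // namespace tensorflow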
|
tensorflow/core/distributed_runtime/worker_cache_partial.cc (new file, 98 lines)
@@ -0,0 +1,98 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/worker_cache_partial.h"
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/process_util.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_interface.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/device_name_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
bool WorkerCachePartial::GetDeviceBusNonBlocking(const string& device_name,
|
||||
BusAdjacency* ba) {
|
||||
mutex_lock lock(mu_); // could use reader lock
|
||||
const auto& iter = device_status_cache_.find(device_name);
|
||||
if (iter != device_status_cache_.end()) {
|
||||
*ba = iter->second.bus_adjacency();
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void WorkerCachePartial::GetDeviceBusAsync(const string& device_name,
|
||||
BusAdjacency* ba,
|
||||
StatusCallback done) {
|
||||
if (!GetDeviceBusNonBlocking(device_name, ba)) {
|
||||
// If the cache entry was empty, make one attempt to fill it via RPC.
// NOTE: device_name is captured by value because the scheduled closure
// may outlive the caller's string.
SchedClosure([this, device_name, ba, done]() {
|
||||
Status s = RefreshDeviceStatus(device_name);
|
||||
if (s.ok()) {
|
||||
if (!GetDeviceBusNonBlocking(device_name, ba)) {
|
||||
mutex_lock lock(mu_);
|
||||
const auto& iter = device_status_cache_.find(device_name);
|
||||
if (iter == device_status_cache_.end()) {
|
||||
s = errors::Unavailable("No known remote device: ", device_name);
|
||||
} else {
|
||||
s = errors::Internal("Failed to find bus_adjacency for ",
|
||||
device_name);
|
||||
}
|
||||
}
|
||||
}
|
||||
done(s);
|
||||
});
|
||||
return;
|
||||
}
|
||||
done(Status::OK());
|
||||
}
|
||||
|
||||
Status WorkerCachePartial::RefreshDeviceStatus(const string& device_name) {
|
||||
string task;
|
||||
string device;
|
||||
Status s;
|
||||
if (!DeviceNameUtils::SplitDeviceName(device_name, &task, &device)) {
|
||||
s = errors::InvalidArgument("Bad device name to RefreshDeviceStatus: ",
|
||||
device_name);
|
||||
}
|
||||
std::unique_ptr<WorkerInterface> rwi(CreateWorker(task));
|
||||
if (s.ok() && !rwi.get()) {
|
||||
s = errors::Internal("RefreshDeviceStatus, unknown worker task: ", task);
|
||||
}
|
||||
|
||||
if (s.ok()) {
|
||||
GetStatusRequest req;
|
||||
GetStatusResponse resp;
|
||||
s = rwi->GetStatus(&req, &resp);
|
||||
if (s.ok()) {
|
||||
mutex_lock lock(mu_);
|
||||
for (auto& dev_attr : resp.device_attributes()) {
|
||||
device_status_cache_[dev_attr.name()] = dev_attr;
|
||||
}
|
||||
}
|
||||
}
|
||||
return s;
|
||||
}
|
||||
|
||||
void WorkerCachePartial::FlushStatusCache() {
|
||||
mutex_lock lock(mu_);
|
||||
device_status_cache_.clear();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
tensorflow/core/distributed_runtime/worker_cache_partial.h (new file, 56 lines)
@@ -0,0 +1,56 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_CACHE_PARTIAL_H_
|
||||
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_CACHE_PARTIAL_H_
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/worker_cache.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/protobuf/worker.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Implements the part of the interface that caches and returns remote
|
||||
// device status attributes.
|
||||
class WorkerCachePartial : public WorkerCacheInterface {
|
||||
public:
|
||||
bool GetDeviceBusNonBlocking(const string& device, BusAdjacency* ba) override;
|
||||
|
||||
void GetDeviceBusAsync(const string& device, BusAdjacency* ba,
|
||||
StatusCallback) override;
|
||||
|
||||
~WorkerCachePartial() override {}
|
||||
|
||||
// Clear all entries from the DeviceStatus cache.
|
||||
void FlushStatusCache();
|
||||
|
||||
private:
|
||||
mutex mu_;
|
||||
|
||||
// Initiate a GetStatusAsync to the remote task named by "task", and
|
||||
// update the cache with all the DeviceAttributes reported.
|
||||
Status RefreshDeviceStatus(const string& device_name);
|
||||
|
||||
typedef std::unordered_map<string, DeviceAttributes> StatusMap;
|
||||
StatusMap device_status_cache_ GUARDED_BY(mu_);
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_CACHE_PARTIAL_H_
|
tensorflow/core/distributed_runtime/worker_env.h (new file, 62 lines)
@@ -0,0 +1,62 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_ENV_H_
|
||||
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_ENV_H_
|
||||
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace thread {
|
||||
class ThreadPool;
|
||||
} // namespace thread
|
||||
|
||||
class DeviceMgr;
|
||||
class Env;
|
||||
class GraphMgr;
|
||||
class RendezvousMgrInterface;
|
||||
class WorkerCacheInterface;
|
||||
|
||||
// The worker environment class, which holds a bag of pointers to
|
||||
// per-worker singletons.
|
||||
//
|
||||
// WorkerEnv does not own its member pointers.
|
||||
struct WorkerEnv {
|
||||
Env* env = nullptr;
|
||||
|
||||
// The name of the worker. E.g., /job:mnist/replica:1/task:0.
|
||||
string worker_name;
|
||||
|
||||
// Object from which WorkerInterface instances can be obtained.
|
||||
WorkerCacheInterface* worker_cache = nullptr;
|
||||
|
||||
// device_mgr manages local devices (cpu and gpu). The WorkerService
|
||||
// is the network interface for managed devices.
|
||||
DeviceMgr* device_mgr = nullptr;
|
||||
|
||||
// graph_mgr keeps track of registered graphs of this worker.
|
||||
GraphMgr* graph_mgr = nullptr;
|
||||
|
||||
// A set of rendezvous keyed by step ids.
|
||||
RendezvousMgrInterface* rendezvous_mgr = nullptr;
|
||||
|
||||
// A pool of threads for scheduling compute work.
|
||||
thread::ThreadPool* compute_pool = nullptr;
|
||||
};
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_ENV_H_
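For illustration only (not part of this commit): a sketch of how a server might fill in a WorkerEnv before handing it to the worker service. Every concrete object below (worker_cache, device_mgr, graph_mgr, rendezvous_mgr) is assumed to be created and owned elsewhere, since WorkerEnv does not own its pointers.

// Hypothetical wiring of a WorkerEnv.
WorkerEnv worker_env;
worker_env.env = Env::Default();
worker_env.worker_name = "/job:mnist/replica:1/task:0";
worker_env.worker_cache = worker_cache;      // a WorkerCacheInterface* owned by the server
worker_env.device_mgr = device_mgr;          // a DeviceMgr* for the local devices
worker_env.graph_mgr = graph_mgr;            // a GraphMgr* tracking registered graphs
worker_env.rendezvous_mgr = rendezvous_mgr;  // a RendezvousMgrInterface* keyed by step id
worker_env.compute_pool = new thread::ThreadPool(Env::Default(), "compute", 4);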
|
tensorflow/core/distributed_runtime/worker_interface.h (new file, 129 lines)
@@ -0,0 +1,129 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_INTERFACE_H_
|
||||
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_INTERFACE_H_
|
||||
|
||||
#include <functional>
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/call_options.h"
|
||||
#include "tensorflow/core/lib/core/notification.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/protobuf/worker.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Status callback.
|
||||
typedef std::function<void(const Status&)> StatusCallback;
|
||||
|
||||
// Allocator callback for out-of-band transfers.
|
||||
class TensorShape;
|
||||
typedef std::function<void*(size_t, const DataType&, const TensorShape&)>
|
||||
TensorBufAllocator;
|
||||
|
||||
// Interface for talking with the TensorFlow Worker service.
|
||||
class WorkerInterface {
|
||||
public:
|
||||
virtual ~WorkerInterface() {}
|
||||
|
||||
virtual void GetStatusAsync(const GetStatusRequest* request,
|
||||
GetStatusResponse* response,
|
||||
StatusCallback done) = 0;
|
||||
|
||||
virtual void RegisterGraphAsync(const RegisterGraphRequest* request,
|
||||
RegisterGraphResponse* response,
|
||||
StatusCallback done) = 0;
|
||||
|
||||
virtual void DeregisterGraphAsync(const DeregisterGraphRequest* request,
|
||||
DeregisterGraphResponse* response,
|
||||
StatusCallback done) = 0;
|
||||
|
||||
virtual void RunGraphAsync(CallOptions* opts, const RunGraphRequest* request,
|
||||
RunGraphResponse* response,
|
||||
StatusCallback done) = 0;
|
||||
|
||||
virtual void CleanupGraphAsync(const CleanupGraphRequest* request,
|
||||
CleanupGraphResponse* response,
|
||||
StatusCallback done) = 0;
|
||||
|
||||
virtual void CleanupAllAsync(const CleanupAllRequest* request,
|
||||
CleanupAllResponse* response,
|
||||
StatusCallback done) = 0;
|
||||
|
||||
virtual void RecvTensorAsync(CallOptions* opts,
|
||||
const RecvTensorRequest* request,
|
||||
RecvTensorResponse* response,
|
||||
TensorBufAllocator allocator,
|
||||
StatusCallback done) = 0;
|
||||
|
||||
virtual void LoggingAsync(const LoggingRequest* request,
|
||||
LoggingResponse* response, StatusCallback done) = 0;
|
||||
|
||||
virtual void TracingAsync(const TracingRequest* request,
|
||||
TracingResponse* response, StatusCallback done) = 0;
|
||||
|
||||
Status GetStatus(const GetStatusRequest* request,
|
||||
GetStatusResponse* response) {
|
||||
return CallAndWait(&ME::GetStatusAsync, request, response);
|
||||
}
|
||||
|
||||
Status RegisterGraph(const RegisterGraphRequest* request,
|
||||
RegisterGraphResponse* response) {
|
||||
return CallAndWait(&ME::RegisterGraphAsync, request, response);
|
||||
}
|
||||
|
||||
Status DeregisterGraph(const DeregisterGraphRequest* request,
|
||||
DeregisterGraphResponse* response) {
|
||||
return CallAndWait(&ME::DeregisterGraphAsync, request, response);
|
||||
}
|
||||
|
||||
Status CleanupGraph(const CleanupGraphRequest* request,
|
||||
CleanupGraphResponse* response) {
|
||||
return CallAndWait(&ME::CleanupGraphAsync, request, response);
|
||||
}
|
||||
|
||||
Status CleanupAll(const CleanupAllRequest* request,
|
||||
CleanupAllResponse* response) {
|
||||
return CallAndWait(&ME::CleanupAllAsync, request, response);
|
||||
}
|
||||
|
||||
Status Logging(const LoggingRequest* request, LoggingResponse* response) {
|
||||
return CallAndWait(&ME::LoggingAsync, request, response);
|
||||
}
|
||||
|
||||
Status Tracing(const TracingRequest* request, TracingResponse* response) {
|
||||
return CallAndWait(&ME::TracingAsync, request, response);
|
||||
}
|
||||
|
||||
private:
|
||||
typedef WorkerInterface ME;
|
||||
|
||||
template <typename Method, typename Req, typename Resp>
|
||||
Status CallAndWait(Method func, const Req* req, Resp* resp) {
|
||||
Status ret;
|
||||
Notification n;
|
||||
(this->*func)(req, resp, [&ret, &n](const Status& s) {
|
||||
ret = s;
|
||||
n.Notify();
|
||||
});
|
||||
n.WaitForNotification();
|
||||
return ret;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_INTERFACE_H_
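For illustration only (not part of this commit): the synchronous methods above simply block on their *Async counterparts via CallAndWait, so a caller that obtains a WorkerInterface (for example from a worker cache's CreateWorker(), as WorkerCachePartial does) can use them directly. The target name below is made up.

// Hypothetical synchronous use of WorkerInterface; `cache` is assumed to be
// a WorkerCacheInterface that can create workers by target name.
std::unique_ptr<WorkerInterface> worker(
    cache->CreateWorker("/job:worker/replica:0/task:0"));
GetStatusRequest req;
GetStatusResponse resp;
Status s = worker->GetStatus(&req, &resp);  // blocks on a Notification internally
if (s.ok()) {
  for (const auto& dev : resp.device_attributes()) {
    LOG(INFO) << "remote device: " << dev.name();
  }
}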
|
@ -65,7 +65,7 @@ Status LoadLibrary(const char* library_filename, void** result,
|
||||
string str;
|
||||
GetOpList(&str);
|
||||
char* str_buf = reinterpret_cast<char*>(operator new(str.length()));
|
||||
strncpy(str_buf, str.data(), str.length());
|
||||
memcpy(str_buf, str.data(), str.length());
|
||||
*buf = str_buf;
|
||||
*len = str.length();
|
||||
|
||||
|
@ -25,37 +25,58 @@ def tf_deps(deps, suffix):
|
||||
|
||||
return tf_deps
|
||||
|
||||
def tf_proto_library(name, srcs = [], has_services = False,
|
||||
deps = [], visibility = [], testonly = 0,
|
||||
cc_api_version = 2, go_api_version = 2,
|
||||
java_api_version = 2,
|
||||
py_api_version = 2):
|
||||
def tf_proto_library_cc(name, srcs = [], has_services = None,
|
||||
deps = [], visibility = [], testonly = 0,
|
||||
cc_libs = [],
|
||||
cc_stubby_versions = None,
|
||||
cc_grpc_version = None,
|
||||
cc_api_version = 2, go_api_version = 2,
|
||||
java_api_version = 2,
|
||||
py_api_version = 2):
|
||||
native.filegroup(name=name + "_proto_srcs",
|
||||
srcs=srcs + tf_deps(deps, "_proto_srcs"),
|
||||
testonly=testonly,)
|
||||
|
||||
use_grpc_plugin = None
|
||||
if cc_grpc_version:
|
||||
use_grpc_plugin = True
|
||||
cc_proto_library(name=name + "_cc",
|
||||
srcs=srcs + tf_deps(deps, "_proto_srcs"),
|
||||
deps=deps + ["//google/protobuf:cc_wkt_protos"],
|
||||
cc_libs = ["//google/protobuf:protobuf"],
|
||||
cc_libs = cc_libs + ["//google/protobuf:protobuf"],
|
||||
use_grpc_plugin = use_grpc_plugin,
|
||||
testonly=testonly,
|
||||
visibility=visibility,)
|
||||
|
||||
py_proto_library(name=name + "_py",
|
||||
srcs=srcs + tf_deps(deps, "_proto_srcs"),
|
||||
srcs_version="PY2AND3",
|
||||
deps=deps + ["//google/protobuf:protobuf_python"],
|
||||
testonly=testonly,
|
||||
visibility=visibility,)
|
||||
|
||||
def tf_proto_library_py(name, srcs=[], deps=[], visibility=[], testonly=0):
|
||||
def tf_proto_library_py(name, srcs=[], deps=[], visibility=[], testonly=0,
|
||||
srcs_version="PY2AND3"):
|
||||
py_proto_library(name = name + "_py",
|
||||
srcs = srcs,
|
||||
srcs_version = "PY2AND3",
|
||||
srcs_version = srcs_version,
|
||||
deps = deps,
|
||||
visibility = visibility,
|
||||
testonly = testonly)
|
||||
|
||||
def tf_proto_library(name, srcs = [], has_services = None,
|
||||
deps = [], visibility = [], testonly = 0,
|
||||
cc_libs = [],
|
||||
cc_api_version = 2, go_api_version = 2,
|
||||
java_api_version = 2,
|
||||
py_api_version = 2):
|
||||
tf_proto_library_cc(name=name,
|
||||
srcs=srcs + tf_deps(deps, "_proto_srcs"),
|
||||
deps=deps,
|
||||
cc_libs=cc_libs,
|
||||
testonly=testonly,
|
||||
visibility=visibility,)
|
||||
|
||||
tf_proto_library_py(name=name,
|
||||
srcs=srcs + tf_deps(deps, "_proto_srcs"),
|
||||
srcs_version="PY2AND3",
|
||||
deps=deps + ["//google/protobuf:protobuf_python"],
|
||||
testonly=testonly,
|
||||
visibility=visibility,)
|
||||
|
||||
def tf_additional_lib_srcs():
|
||||
return [
|
||||
"platform/default/*.h",
|
||||
|
tensorflow/core/protobuf/master.proto (new file, 190 lines)
@@ -0,0 +1,190 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
syntax = "proto3";
|
||||
|
||||
package tensorflow;
|
||||
// option cc_enable_arenas = true;
|
||||
option java_outer_classname = "DistributedRuntimeProtos";
|
||||
option java_multiple_files = true;
|
||||
option java_package = "org.tensorflow.distruntime";
|
||||
|
||||
import "tensorflow/core/framework/config.proto";
|
||||
import "tensorflow/core/framework/device_attributes.proto";
|
||||
import "tensorflow/core/framework/graph.proto";
|
||||
import "tensorflow/core/framework/tensor.proto";
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// CreateSession method request/response protos.
|
||||
//
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
message CreateSessionRequest {
|
||||
// The initial graph definition.
|
||||
GraphDef graph_def = 1;
|
||||
|
||||
// Configuration options.
|
||||
ConfigProto config = 2;
|
||||
}
|
||||
|
||||
message CreateSessionResponse {
|
||||
// The session handle to be used in subsequent calls for the created session.
|
||||
//
|
||||
// The client must arrange to call CloseSession with this returned
|
||||
// session handle to close the session.
|
||||
string session_handle = 1;
|
||||
|
||||
// The initial version number for the graph, to be used in the next call
|
||||
// to ExtendSession.
|
||||
int64 graph_version = 2;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// ExtendSession method request/response protos.
|
||||
//
|
||||
// The "graph_def" specifies a set of nodes to be added to the session's graph.
|
||||
//
|
||||
// A typical "graph_def" will contain:
|
||||
//
|
||||
// * Zero or more new nodes with names that do not exist in the server-side
|
||||
// graph. These will be added to the graph.
|
||||
//
|
||||
// PRECONDITION: The server-side current version is req.current_version.
|
||||
// None of the names in req.graph_def appeared in previous successful calls to
|
||||
// CreateSession or ExtendSession with the same session_handle.
|
||||
// POSTCONDITION: The server-side current version is resp.new_version.
|
||||
//
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
message ExtendSessionRequest {
|
||||
// REQUIRED: session_handle must be returned by a CreateSession call
|
||||
// to the same master service.
|
||||
string session_handle = 1;
|
||||
|
||||
// REQUIRED: The nodes to be added to the session's graph. If any node has
|
||||
// the same name as an existing node, the operation will fail with
|
||||
// INVALID_ARGUMENT.
|
||||
GraphDef graph_def = 2;
|
||||
|
||||
// REQUIRED: The version number of the graph to be extended. This will be
|
||||
// tested against the current server-side version number, and the operation
|
||||
// will fail with FAILED_PRECONDITION if they do not match.
|
||||
int64 current_graph_version = 3;
|
||||
}
|
||||
|
||||
message ExtendSessionResponse {
|
||||
// TODO(mrry): Return something about the operation?
|
||||
|
||||
// The new version number for the extended graph, to be used in the next call
|
||||
// to ExtendSession.
|
||||
int64 new_graph_version = 4;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// RunStep method request/response protos.
|
||||
//
|
||||
// The caller should provide the feeds needed by the graph and specify
|
||||
// what nodes should be fetched.
|
||||
//
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// A pair of tensor name and tensor values.
|
||||
message NamedTensorProto {
|
||||
// Name of the tensor.
|
||||
string name = 1;
|
||||
|
||||
// The client can populate a TensorProto using a `tensorflow::Tensor`, or
|
||||
// directly using the protobuf field accessors.
|
||||
//
|
||||
// The client specifies whether the returned tensor values should be
|
||||
// filled tensor fields (float_val, int_val, etc.) or encoded in a
|
||||
// compact form in tensor.tensor_content.
|
||||
TensorProto tensor = 2;
|
||||
}
|
||||
|
||||
message RunStepRequest {
|
||||
// REQUIRED: session_handle must be returned by a CreateSession call
|
||||
// to the same master service.
|
||||
string session_handle = 1;
|
||||
|
||||
// Tensors to be fed in the step. Each feed is a named tensor.
|
||||
repeated NamedTensorProto feed = 2;
|
||||
|
||||
// Fetches. A list of tensor names. The caller expects a tensor to
|
||||
// be returned for each fetch[i] (see RunStepResponse.tensor). The
|
||||
// order of specified fetches does not change the execution order.
|
||||
repeated string fetch = 3;
|
||||
|
||||
// Target Nodes. A list of node names. The named nodes will be run,
// but their outputs will not be fetched.
|
||||
repeated string target = 4;
|
||||
}
|
||||
|
||||
message RunStepResponse {
|
||||
// NOTE: The order of the returned tensors may or may not match
|
||||
// the fetch order specified in RunStepRequest.
|
||||
repeated NamedTensorProto tensor = 1;
|
||||
|
||||
// TODO(mrry): Optionally aggregate StepStats in some form here.
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// CloseSession method request/response protos.
|
||||
//
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
message CloseSessionRequest {
|
||||
// REQUIRED: session_handle must be returned by a CreateSession call
|
||||
// to the same master service.
|
||||
string session_handle = 1;
|
||||
}
|
||||
|
||||
message CloseSessionResponse {
|
||||
}
|
||||
|
||||
message ResetRequest {
|
||||
// A list of container names, which may be empty.
|
||||
//
|
||||
// If 'container' is not empty, releases resources in the given
|
||||
// containers in all devices.
|
||||
//
|
||||
// If 'container' is empty, releases resources in the default
|
||||
// container in all devices.
|
||||
repeated string container = 1;
|
||||
}
|
||||
|
||||
message ResetResponse {
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// ListDevices method request/response protos.
|
||||
//
|
||||
// Returns information about the TensorFlow devices that are available
|
||||
// to this master.
|
||||
//
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
message ListDevicesRequest {
|
||||
}
|
||||
|
||||
message ListDevicesResponse {
|
||||
repeated DeviceAttributes local_device = 1;
|
||||
repeated DeviceAttributes remote_device = 2;
|
||||
}
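For illustration only (not part of this commit): a sketch of how a client might assemble a RunStepRequest with the generated C++ proto API, assuming the usual TensorFlow C++ headers are available. The session handle, tensor names, and target node name are made up.

// Hypothetical construction of a RunStepRequest; `handle` is assumed to have
// been returned by an earlier CreateSession call.
RunStepRequest req;
req.set_session_handle(handle);
NamedTensorProto* feed = req.add_feed();
feed->set_name("x:0");                    // made-up feed tensor name
Tensor x(DT_FLOAT, TensorShape({2}));
x.flat<float>()(0) = 1.0f;
x.flat<float>()(1) = 2.0f;
x.AsProtoField(feed->mutable_tensor());   // fills float_val; AsProtoTensorContent
                                          // would use the compact encoding instead
req.add_fetch("y:0");                     // tensor whose value should be returned
req.add_target("train_op");               // node to run without fetching its output
RunStepResponse resp;
// The request would then be sent to the master, e.g. through GrpcSession or a
// MasterService stub.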
|
tensorflow/core/protobuf/master_service.proto (new file, 105 lines)
@@ -0,0 +1,105 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
syntax = "proto3";
|
||||
|
||||
package tensorflow.grpc;
|
||||
option java_outer_classname = "MasterServiceProtos";
|
||||
option java_multiple_files = true;
|
||||
option java_package = "org.tensorflow.distruntime";
|
||||
|
||||
import "tensorflow/core/protobuf/master.proto";
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// MasterService defines a TensorFlow service with which a client can
|
||||
// interact to execute a distributed TensorFlow computation.
|
||||
//
|
||||
// A master service keeps track of multiple "master sessions". Each
|
||||
// session encapsulates a computation graph and its associated state,
|
||||
// and typically corresponds to a single "client session" (e.g. a
|
||||
// `tensorflow::Session` instance).
|
||||
//
|
||||
// A session is responsible for the following:
|
||||
// * assigning each node to a device (locally or remotely) using a
|
||||
// placement algorithm. This may make decisions based on collected
|
||||
// statistics from the workers in the system (e.g., memory usage,
|
||||
// bandwidth consumption, etc.)
|
||||
//
|
||||
// * inserting intermediate nodes and edges to support cross-device
|
||||
// and cross-process data flows and resource management.
|
||||
//
|
||||
// * issuing commands to workers to execute the subgraphs associated
|
||||
// with those workers.
|
||||
//
|
||||
// Typically, a client carries out an iterative computation
|
||||
// (e.g. training) by invoking RPCs against the master in a
|
||||
// client-side loop. The client first creates a client session that
|
||||
// connects to a particular master (using gRPC for example). The
|
||||
// master creates a corresponding master session that is hosted on
|
||||
// the master and caches state between the client's invocations.
|
||||
//
|
||||
// After the session is established, the master returns an opaque
|
||||
// handle to the client that can be used to associate the client and
|
||||
// master sessions.
|
||||
//
|
||||
// The client may send an initial graph to the master in the
|
||||
// CreateSession call, and add nodes to the graph using ExtendSession.
|
||||
//
|
||||
// The most frequent operation on a master is "RunStep", which implements
// the `Session::Run()` API. It supports feeding in arguments,
// executing a dataflow computation, and fetching outputs.
|
||||
//
|
||||
// Finally, when the client no longer needs the session, it should
|
||||
// close the session by invoking CloseSession, which allows the master
|
||||
// to reclaim resources associated with the session. The master may
|
||||
// implement a garbage collection scheme that closes sessions that
|
||||
// have been inactive for some time.
|
||||
//
|
||||
// For example, the following pseudo-code illustrates how a client
|
||||
// interacts with a master:
|
||||
//
|
||||
// stub = NewStub("/job:mnist/replica:0/task:0")
|
||||
// {handle} = stub->CreateSession({graph_def})
|
||||
// do {
|
||||
// stub->RunStep({handle, {feeds}, {fetches}})
|
||||
// // The client can evaluate a predicate locally, based on the
|
||||
// // result of `fetches`, to determine whether to terminate. For
|
||||
// // example, it might fetch the loss and evaluate whether it is less
|
||||
// // than some threshold.
|
||||
// } while (!should_stop({fetches}));
|
||||
// stub->CloseSession({handle})
|
||||
//
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
service MasterService {
|
||||
// Creates a session.
|
||||
rpc CreateSession(CreateSessionRequest) returns (CreateSessionResponse);
|
||||
|
||||
// Extends a session.
|
||||
rpc ExtendSession(ExtendSessionRequest) returns (ExtendSessionResponse);
|
||||
|
||||
// Drives the graph computation.
|
||||
rpc RunStep(RunStepRequest) returns (RunStepResponse);
|
||||
|
||||
// Closes a session.
|
||||
rpc CloseSession(CloseSessionRequest) returns (CloseSessionResponse);
|
||||
|
||||
// List the devices usable by the master.
|
||||
rpc ListDevices(ListDevicesRequest) returns (ListDevicesResponse);
|
||||
|
||||
// Close all existing sessions.
|
||||
rpc Reset(ResetRequest) returns (ResetResponse);
|
||||
}
|
tensorflow/core/protobuf/worker.proto (new file, 311 lines)
@@ -0,0 +1,311 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
syntax = "proto3";
|
||||
|
||||
package tensorflow;
|
||||
// option cc_enable_arenas = true;
|
||||
option java_outer_classname = "WorkerProtos";
|
||||
option java_multiple_files = true;
|
||||
option java_package = "org.tensorflow.distruntime";
|
||||
|
||||
import "google/protobuf/any.proto";
|
||||
import "tensorflow/core/framework/config.proto";
|
||||
import "tensorflow/core/framework/step_stats.proto";
|
||||
import "tensorflow/core/framework/device_attributes.proto";
|
||||
import "tensorflow/core/framework/graph.proto";
|
||||
import "tensorflow/core/framework/tensor.proto";
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// GetStatus method request/response messages
|
||||
//
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
message GetStatusRequest {
|
||||
}
|
||||
|
||||
message GetStatusResponse {
|
||||
repeated DeviceAttributes device_attributes = 1;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// RegisterGraph method request/response messages
|
||||
//
|
||||
// For each session, after the master placed every node on a device,
|
||||
// it partitions the whole graph into many subgraphs. All the nodes in
|
||||
// a subgraph were in the same worker, but potentially on many devices
|
||||
// owned by that worker (e.g. cpu0, plus gpu0, gpu1, ..., gpu7). The
|
||||
// master registers subgraphs for a worker before running any steps. A
|
||||
// successful registration returns a graph handle to be used in later
|
||||
// RunGraph requests.
|
||||
//
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
message RegisterGraphRequest {
|
||||
// Subgraphs are scoped within one session.
|
||||
string session_handle = 1;
|
||||
|
||||
// "graph_def" has the subgraph of nodes for this worker, with each node
|
||||
// having its device_name filled in.
|
||||
GraphDef graph_def = 2;
|
||||
|
||||
// True iff the graph (before partitioning) contains control flow nodes.
|
||||
//
|
||||
// As of 01/11/2015, this is no longer set by clients.
|
||||
bool has_control_flow = 3 [deprecated = true];
|
||||
|
||||
// Configuration options for the session in which this graph was created.
|
||||
GraphOptions graph_options = 4;
|
||||
}
|
||||
|
||||
message RegisterGraphResponse {
|
||||
// If the registration succeeds, returns an opaque graph_handle to
|
||||
// the master. The master calls RunGraph with graph_handle to
|
||||
// compute different steps.
|
||||
string graph_handle = 1;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// DeregisterGraph method request/response messages
|
||||
//
|
||||
// The master deregisters the given graph_handle when the graph is no
|
||||
// longer needed (e.g., the overall graph is re-scheduled and nodes
|
||||
// are re-placed).
|
||||
//
|
||||
// The worker deregisters a graph_handle automatically according to
// a TTL-based policy in case of master restarts.
|
||||
//
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
message DeregisterGraphRequest {
|
||||
// REQUIRED: graph_handle must be returned by a RegisterGraph call
|
||||
// to the same WorkerService.
|
||||
string graph_handle = 1;
|
||||
}
|
||||
|
||||
message DeregisterGraphResponse {
|
||||
// TODO(mrry): Optionally add summary stats for the graph.
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// CleanupAll method request/response messages
|
||||
//
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
message CleanupAllRequest {
|
||||
// A list of container names.
|
||||
//
|
||||
// If 'container' is not empty, releases resources in the given
|
||||
// containers in all devices.
|
||||
//
|
||||
// If 'container' is empty, releases resources in the default
|
||||
// container in all devices.
|
||||
repeated string container = 1;
|
||||
}
|
||||
|
||||
message CleanupAllResponse {
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// RunGraph request / response messages
|
||||
//
|
||||
// The worker executes all subgraphs registered under graph_handle.
|
||||
// RunGraph returns after the execution finishes or an error is
|
||||
// encountered.
|
||||
//
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// A pair of tensor name and tensor values.
|
||||
message NamedTensor {
|
||||
// The name of the named tensor.
|
||||
string key = 1;
|
||||
|
||||
// The value of the named tensor.
|
||||
TensorProto val = 2;
|
||||
}
|
||||
|
||||
// Options specific to the execution of a single step.
|
||||
message ExecutorOpts {
|
||||
bool record_costs = 1;
|
||||
bool record_timeline = 3;
|
||||
};
|
||||
|
||||
message RunGraphRequest {
|
||||
// REQUIRED: graph_handle must be returned by a RegisterGraph call
|
||||
// to the same WorkerService.
|
||||
string graph_handle = 1;
|
||||
|
||||
// A unique ID to distinguish different runs of the same graph.
|
||||
//
|
||||
// The master generates a globally unique `step_id` to distinguish
|
||||
// different runs of the graph computation. Subgraphs communicate
|
||||
// (e.g., send/recv ops) with each other using `step_id` to
|
||||
// distinguish tensors generated by different runs.
|
||||
int64 step_id = 2;
|
||||
|
||||
// Options for this step.
|
||||
ExecutorOpts exec_opts = 5;
|
||||
|
||||
// Runs the graph.
|
||||
//
|
||||
// Sends the tensors in "send" into the graph before the run and
|
||||
// fetches the keys into `RunGraphResponse.recv` after the run.
|
||||
repeated NamedTensor send = 3;
|
||||
repeated string recv_key = 4;
|
||||
}
|
||||
|
||||
message RunGraphResponse {
|
||||
// A list of tensors corresponding to those requested by
|
||||
// `RunGraphRequest.recv_key`.
|
||||
repeated NamedTensor recv = 1;
|
||||
|
||||
// If the request asked for execution stats, these are returned here.
|
||||
StepStats step_stats = 2;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// CleanupGraph method request/response messages
|
||||
//
|
||||
// After the master receives RunGraph responses from all workers, the
|
||||
// master instructs every worker to cleanup any remaining state of a
|
||||
// step (e.g. tensors buffered by a `Send` op but not picked up by
|
||||
// other workers). The master does not necessarily need to wait for
|
||||
// completion of CleanupGraph calls.
|
||||
//
|
||||
// Workers should cleanup step states automatically according to a
|
||||
// TTL-based policy in case of master restarts.
|
||||
//
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
message CleanupGraphRequest {
|
||||
int64 step_id = 1;
|
||||
}
|
||||
|
||||
message CleanupGraphResponse {
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// RecvTensor method request/response messages
|
||||
//
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
message RecvTensorRequest {
|
||||
// The step in which the tensor will be produced.
|
||||
//
|
||||
// REQUIRED: This must eventually correspond to the `step_id` passed
|
||||
// into a RunGraph call on the same WorkerService.
|
||||
int64 step_id = 1;
|
||||
|
||||
// A key that identifies the tensor to be received.
|
||||
string rendezvous_key = 2;
|
||||
|
||||
// If true, use an out-of-band DMA mechanism to transfer the
|
||||
// received tensor.
|
||||
bool dma_ok = 3;
|
||||
// NIC bus preference on the request originator side
|
||||
BusAdjacency client_bus_adjacency = 4;
|
||||
// NIC bus preference on the request receiver side
|
||||
BusAdjacency server_bus_adjacency = 5;
|
||||
}
|
||||
|
||||
message RecvTensorResponse {
|
||||
// The tensor as a proto.
|
||||
TensorProto tensor = 1;
|
||||
|
||||
// If true, this tensor was the output of a dead node, and the
|
||||
// content is invalid.
|
||||
bool is_dead = 2;
|
||||
|
||||
// The time at which tensor was available and started to be returned.
|
||||
int64 send_start_micros = 3;
|
||||
|
||||
// Optional additional information about how to receive the tensor,
|
||||
// in the event that `RecvTensorRequest.dma_ok` was true.
|
||||
google.protobuf.Any transport_options = 4;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Logging method request/response messages
|
||||
//
|
||||
// NOTE(mrry): This feature is not supported in the open-source
|
||||
// version, and these messages are expected to change.
|
||||
//
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Out-of-band request to begin or end logging, or
|
||||
// to retrieve logs for particular steps.
|
||||
message LoggingRequest {
|
||||
// If true, RPC logging will be activated.
|
||||
bool rpc_logging = 1;
|
||||
|
||||
// If true, discard any saved logging data (for all steps).
|
||||
bool clear = 2;
|
||||
|
||||
// When set, requests all saved log data pertaining to the step.
|
||||
// Any log data retrieved is eliminated from the store and cannot be
|
||||
// retrieved again.
|
||||
repeated int64 fetch_step_id = 3;
|
||||
}
|
||||
|
||||
message LabeledStepStats {
|
||||
int64 step_id = 1;
|
||||
StepStats step_stats = 2;
|
||||
}
|
||||
|
||||
message LoggingResponse {
|
||||
repeated LabeledStepStats step = 1;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Tracing method request/response messages
|
||||
//
|
||||
// NOTE(mrry): This feature is not supported in the open-source
|
||||
// version, and these messages are expected to change.
|
||||
//
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
message TraceOpts {
|
||||
// Length of the trace to be taken, in seconds.
|
||||
double duration = 1;
|
||||
// If true, capture step profile locally in each worker. Currently
|
||||
// unimplemented.
|
||||
bool use_step_profiler = 2;
|
||||
// If true, capture kernel events from each worker.
|
||||
bool use_kernel_profiler = 3;
|
||||
// If true, capture extended profiling events from TensorFlow process.
|
||||
bool use_extended_profiler = 4;
|
||||
// If true, capture GPU profiling events locally on each
|
||||
// machine. Currently unimplemented.
|
||||
bool use_gpu_profiler = 5;
|
||||
// If true, collect sampled profile events. Currently unimplemented.
|
||||
bool use_sample_profiler = 6;
|
||||
}
|
||||
|
||||
// Out-of-band request to configure distributed tracing.
|
||||
message TracingRequest {
|
||||
TraceOpts options = 1;
|
||||
}
|
||||
|
||||
message TracingResponse {
|
||||
}
|
tensorflow/core/protobuf/worker_service.proto (new file, 67 lines)
@@ -0,0 +1,67 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
syntax = "proto3";
|
||||
|
||||
package tensorflow.grpc;
|
||||
option java_outer_classname = "WorkerServiceProtos";
|
||||
option java_multiple_files = true;
|
||||
option java_package = "org.tensorflow.distruntime";
|
||||
|
||||
import "tensorflow/core/protobuf/worker.proto";
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// WorkerService defines a TensorFlow service that executes dataflow
|
||||
// graphs on a set of local devices, on behalf of a MasterService.
|
||||
//
|
||||
// A worker service keeps track of multiple "registered graphs". Each
|
||||
// registered graph is a subgraph of a client's graph, corresponding to
|
||||
// only the nodes that should execute on this worker (and any
|
||||
// additional nodes necessary for inter-process communication using
|
||||
// the `RecvTensor` method).
|
||||
//
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
service WorkerService {
|
||||
// See worker.proto for details.
|
||||
rpc GetStatus(GetStatusRequest) returns (GetStatusResponse);
|
||||
|
||||
// See worker.proto for details.
|
||||
rpc RegisterGraph(RegisterGraphRequest) returns (RegisterGraphResponse);
|
||||
|
||||
// See worker.proto for details.
|
||||
rpc DeregisterGraph(DeregisterGraphRequest) returns (DeregisterGraphResponse);
|
||||
|
||||
// See worker.proto for details.
|
||||
rpc RunGraph(RunGraphRequest) returns (RunGraphResponse);
|
||||
|
||||
// See worker.proto for details.
|
||||
rpc CleanupGraph(CleanupGraphRequest) returns (CleanupGraphResponse);
|
||||
|
||||
// See worker.proto for details.
|
||||
rpc CleanupAll(CleanupAllRequest) returns (CleanupAllResponse);
|
||||
|
||||
// See worker.proto for details.
|
||||
rpc RecvTensor(RecvTensorRequest) returns (RecvTensorResponse) {
|
||||
// RecvTensor Method
|
||||
}
|
||||
|
||||
// See worker.proto for details.
|
||||
rpc Logging(LoggingRequest) returns (LoggingResponse);
|
||||
|
||||
// See worker.proto for details.
|
||||
rpc Tracing(TracingRequest) returns (TracingResponse);
|
||||
}
|
@ -920,6 +920,7 @@ tf_py_wrap_cc(
|
||||
":py_record_writer_lib",
|
||||
":python_op_gen",
|
||||
":tf_session_helper",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_session",
|
||||
"//util/python:python_headers",
|
||||
],
|
||||
)
|
||||
|
@ -750,8 +750,9 @@ class SessionTest(test_util.TensorFlowTestCase):
|
||||
self.assertEqual(c_list[1], out[1].decode('utf-8'))
|
||||
|
||||
def testInvalidTargetFails(self):
|
||||
with self.assertRaisesRegexp(RuntimeError,
|
||||
'Registered factories are {DIRECT_SESSION}'):
|
||||
with self.assertRaisesRegexp(
|
||||
RuntimeError,
|
||||
'No session factory registered for the given session options.'):
|
||||
session.Session('INVALID_TARGET')
|
||||
|
||||
def testFetchByNameDifferentStringTypes(self):
|
||||
|
@ -135,7 +135,7 @@ have varying scale, and to aid generalization.
|
||||
@@l2_normalize
|
||||
@@local_response_normalization
|
||||
@@sufficient_statistics
|
||||
@@aggregate_moments
|
||||
@@normalize_moments
|
||||
@@moments
|
||||
|
||||
## Losses
|
||||
@ -561,7 +561,7 @@ def sufficient_statistics(x, axes, shift=True, keep_dims=False, name=None):
|
||||
return counts, m_ss, v_ss, shift_value
|
||||
|
||||
|
||||
def aggregate_moments(counts, mean_ss, variance_ss, shift, name=None):
|
||||
def normalize_moments(counts, mean_ss, variance_ss, shift, name=None):
|
||||
"""Calculate the mean and variance of based on the sufficient statistics.
|
||||
|
||||
Args:
|
||||
@ -577,7 +577,7 @@ def aggregate_moments(counts, mean_ss, variance_ss, shift, name=None):
|
||||
Returns:
|
||||
Two `Tensor` objects: `mean` and `variance`.
|
||||
"""
|
||||
with ops.op_scope([counts, mean_ss, variance_ss, shift], name, "aggregate"):
|
||||
with ops.op_scope([counts, mean_ss, variance_ss, shift], name, "normalize"):
|
||||
divisor = math_ops.inv(counts, name="divisor")
|
||||
if shift is not None:
|
||||
shifted_mean = math_ops.mul(mean_ss, divisor, name="shifted_mean")
|
||||
@ -620,7 +620,7 @@ def moments(x, axes, name=None, keep_dims=False):
|
||||
axes,
|
||||
keep_dims=keep_dims,
|
||||
name=name)
|
||||
return aggregate_moments(counts, m_ss, v_ss, shift, name=name)
|
||||
return normalize_moments(counts, m_ss, v_ss, shift, name=name)
|
||||
|
||||
|
||||
def batch_normalization(x,
|
||||
|
@ -826,19 +826,19 @@ class SufficientStatisticsTest(tf.test.TestCase):
|
||||
self._testSuffStats([1, 2, 3], [0, 2], shift, keep_dims, has_shape)
|
||||
|
||||
|
||||
class AggregateMomentsTest(tf.test.TestCase):
|
||||
class NormalizeMomentsTest(tf.test.TestCase):
|
||||
|
||||
def _npAggregateMoments(self, counts, mean_ss, variance_ss, shift):
|
||||
def _npNormalizeMoments(self, counts, mean_ss, variance_ss, shift):
|
||||
mean = mean_ss / counts
|
||||
variance = variance_ss / counts - mean * mean
|
||||
if shift is not None:
|
||||
mean += shift
|
||||
return mean, variance
|
||||
|
||||
def _opAggregateMoments(self, counts, mean_ss, variance_ss, shift):
|
||||
return tf.nn.aggregate_moments(counts, mean_ss, variance_ss, shift)
|
||||
def _opNormalizeMoments(self, counts, mean_ss, variance_ss, shift):
|
||||
return tf.nn.normalize_moments(counts, mean_ss, variance_ss, shift)
|
||||
|
||||
def _testAggregateMoments(self, shape, shift):
|
||||
def _testNormalizeMoments(self, shape, shift):
|
||||
counts = np.ones([1]).astype(np.float32)
|
||||
mean_ss = np.random.random_sample(shape).astype(np.float32)
|
||||
variance_ss = np.random.random_sample(shape).astype(np.float32)
|
||||
@ -847,7 +847,7 @@ class AggregateMomentsTest(tf.test.TestCase):
|
||||
shift_v = np.random.random_sample(shape).astype(np.float32)
|
||||
else:
|
||||
shift_v = None
|
||||
npm, npv = self._npAggregateMoments(counts, mean_ss, variance_ss, shift_v)
|
||||
npm, npv = self._npNormalizeMoments(counts, mean_ss, variance_ss, shift_v)
|
||||
for use_gpu in [True, False]:
|
||||
with self.test_session(use_gpu=use_gpu) as sess:
|
||||
tf_counts = tf.constant(counts, name="counts")
|
||||
@ -857,16 +857,16 @@ class AggregateMomentsTest(tf.test.TestCase):
|
||||
tf_shift_v = tf.constant(shift_v, name="shift")
|
||||
else:
|
||||
tf_shift_v = None
|
||||
opm, opv = self._opAggregateMoments(tf_counts, tf_mean_ss,
|
||||
opm, opv = self._opNormalizeMoments(tf_counts, tf_mean_ss,
|
||||
tf_variance_ss, tf_shift_v)
|
||||
tfm, tfv = sess.run([opm, opv])
|
||||
self.assertAllClose(npm, tfm, atol=0.000001)
|
||||
self.assertAllClose(npv, tfv, atol=0.000001)
|
||||
|
||||
def testAggregateMoments(self):
|
||||
def testNormalizeMoments(self):
|
||||
for shift in [True, False]:
|
||||
self._testAggregateMoments([3], shift)
|
||||
self._testAggregateMoments([2, 3], shift)
|
||||
self._testNormalizeMoments([3], shift)
|
||||
self._testNormalizeMoments([2, 3], shift)
|
||||
|
||||
|
||||
class MomentsTest(tf.test.TestCase):
|
||||
@ -971,15 +971,15 @@ class MomentsTest(tf.test.TestCase):
|
||||
"""Make sure the output names are stable."""
|
||||
with self.test_session():
|
||||
mean, var = tf.nn.moments(tf.constant([1]), [0], keep_dims=False)
|
||||
self.assertEquals(mean.op.name, "moments/aggregate/mean")
|
||||
self.assertEquals(var.op.name, "moments/aggregate/variance")
|
||||
self.assertEquals(mean.op.name, "moments/normalize/mean")
|
||||
self.assertEquals(var.op.name, "moments/normalize/variance")
|
||||
|
||||
def testOutputNamesKeep(self):
|
||||
"""Make sure the output names are stable."""
|
||||
with self.test_session():
|
||||
mean, var = tf.nn.moments(tf.constant([1]), [0], keep_dims=True)
|
||||
self.assertEquals(mean.op.name, "moments/aggregate/mean")
|
||||
self.assertEquals(var.op.name, "moments/aggregate/variance")
|
||||
self.assertEquals(mean.op.name, "moments/normalize/mean")
|
||||
self.assertEquals(var.op.name, "moments/normalize/variance")
|
||||
|
||||
|
||||
class ComputeSampledLogitsTest(tf.test.TestCase):
|
||||
|