Initial version of the open-source distributed TensorFlow runtime.

This includes a gRPC server (grpc_tensorflow_server) that can serve as both
the master of a distributed TensorFlow computation, and an individual worker
in the computation. The GrpcSession class is included to allow client programs
(including Python clients) to interact with a server.

See tensorflow/core/distributed_runtime/README.md for usage instructions.

This change partially addresses issue .
Change: 115634191
Author: Derek Murray (2016-02-25 20:10:09 -08:00), committed by TensorFlower Gardener
parent d27da251bc
commit 00986d48bb
79 changed files with 11222 additions and 37 deletions

WORKSPACE
@@ -15,6 +15,37 @@
load("//tensorflow:workspace.bzl", "tf_workspace")
tf_workspace()
# grpc expects //external:protobuf_clib and //external:protobuf_compiler
# to point to protobuf's compiler library.
bind(
name = "protobuf_clib",
actual = "//google/protobuf:protoc_lib",
)
bind(
name = "protobuf_compiler",
actual = "//google/protobuf:protoc_lib",
)
git_repository(
name = "grpc",
commit = "73979f4",
init_submodules = True,
remote = "https://github.com/grpc/grpc.git",
)
# protobuf expects //external:grpc_cpp_plugin to point to grpc's
# C++ plugin code generator.
bind(
name = "grpc_cpp_plugin",
actual = "@grpc//:grpc_cpp_plugin",
)
bind(
name = "grpc_lib",
actual = "@grpc//:grpc++_unsecure",
)
# TENSORBOARD_BOWER_AUTOGENERATED_BELOW_THIS_LINE_DO_NOT_EDIT
new_git_repository(

tensorflow/core/BUILD
@@ -61,6 +61,7 @@ load("//tensorflow:tensorflow.bzl", "tf_gpu_kernel_library")
load(
"//tensorflow/core:platform/default/build_config.bzl",
"tf_proto_library",
"tf_proto_library_cc",
"tf_additional_lib_srcs",
"tf_additional_stream_executor_srcs",
"tf_additional_test_deps",
@@ -77,7 +78,15 @@ load(
tf_proto_library(
name = "protos_all",
srcs = glob(["**/*.proto"]),
srcs = glob(
["**/*.proto"],
exclude = [
"protobuf/worker.proto",
"protobuf/worker_service.proto",
"protobuf/master.proto",
"protobuf/master_service.proto",
],
),
cc_api_version = 2,
go_api_version = 2,
java_api_version = 2,
@@ -85,6 +94,54 @@ tf_proto_library(
visibility = ["//visibility:public"],
)
tf_proto_library_cc(
name = "worker_proto",
srcs = ["protobuf/worker.proto"],
cc_api_version = 2,
cc_libs = [":protos_all_cc"],
visibility = [
"//tensorflow:internal",
],
)
tf_proto_library_cc(
name = "worker_service_proto",
srcs = ["protobuf/worker_service.proto"],
has_services = 1,
cc_api_version = 2,
cc_grpc_version = 1,
cc_libs = [":worker_proto_cc"],
cc_stubby_versions = ["2"],
visibility = [
"//tensorflow:internal",
],
)
tf_proto_library_cc(
name = "master_proto",
srcs = ["protobuf/master.proto"],
cc_api_version = 2,
cc_libs = [":protos_all_cc"],
py_api_version = 2,
visibility = [
"//tensorflow:internal",
],
)
tf_proto_library_cc(
name = "master_service_proto",
srcs = ["protobuf/master_service.proto"],
has_services = 1,
cc_api_version = 2,
cc_grpc_version = 1,
cc_libs = [":master_proto_cc"],
cc_stubby_versions = ["2"],
py_api_version = 2,
visibility = [
"//tensorflow:internal",
],
)
cc_library(
name = "lib",
hdrs = [

tensorflow/core/distributed_runtime/BUILD
@@ -0,0 +1,306 @@
# Description:
# A distributed runtime for TensorFlow, which allows graph execution
# to be distributed and performed in parallel across multiple
# processes.
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
visibility = ["//tensorflow:__subpackages__"],
)
filegroup(
name = "c_srcs",
data = glob([
"**/*.cc",
"**/*.h",
]),
)
load("//tensorflow:tensorflow.bzl", "tf_cc_tests")
# For platform specific build config
load(
"//tensorflow/core:platform/default/build_config.bzl",
"tf_kernel_tests_linkstatic",
)
load(
"//tensorflow/core:platform/default/build_config_root.bzl",
"tf_cuda_tests_tags",
)
package(default_visibility = [
"//tensorflow:internal",
])
cc_library(
name = "worker_env",
hdrs = ["worker_env.h"],
deps = [],
)
cc_library(
name = "worker_interface",
hdrs = ["worker_interface.h"],
deps = [
":call_options",
"//tensorflow/core:lib",
"//tensorflow/core:worker_proto_cc",
],
)
cc_library(
name = "call_options",
srcs = ["call_options.cc"],
hdrs = ["call_options.h"],
deps = [
"//tensorflow/core:lib",
],
)
cc_test(
name = "call_options_test",
size = "small",
srcs = ["call_options_test.cc"],
deps = [
":call_options",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
cc_library(
name = "worker_cache",
hdrs = ["worker_cache.h"],
deps = ["//tensorflow/core:protos_all_cc"],
)
cc_library(
name = "remote_device",
srcs = ["remote_device.cc"],
hdrs = ["remote_device.h"],
deps = [
":process_util",
":worker_cache",
":worker_interface",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:lib",
"//tensorflow/core:tensorflow_opensource",
"//tensorflow/core:worker_proto_cc",
],
)
cc_library(
name = "master_interface",
hdrs = ["master_interface.h"],
deps = [
"//tensorflow/core:lib",
"//tensorflow/core:master_proto_cc",
],
)
cc_library(
name = "master",
srcs = ["master.cc"],
hdrs = ["master.h"],
deps = [
":call_options",
":master_env",
":master_session_interface",
":process_util",
":remote_device",
":worker_cache",
":worker_interface",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:master_proto_cc",
"//tensorflow/core:tensorflow_opensource",
"//tensorflow/core:worker_proto_cc",
],
)
cc_library(
name = "master_session",
srcs = ["master_session.cc"],
hdrs = ["master_session.h"],
deps = [
":master_env",
":master_session_interface",
":process_util",
":simple_graph_execution_state",
":worker_cache",
":worker_interface",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:master_proto_cc",
"//tensorflow/core:protos_all_cc",
],
)
cc_library(
name = "build_graph_options",
srcs = ["build_graph_options.cc"],
hdrs = ["build_graph_options.h"],
deps = [
"//tensorflow/core:lib",
],
)
cc_library(
name = "simple_graph_execution_state",
srcs = ["simple_graph_execution_state.cc"],
hdrs = ["simple_graph_execution_state.h"],
deps = [
":build_graph_options",
":process_util",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:worker_proto_cc",
],
)
cc_library(
name = "rendezvous_mgr_interface",
srcs = [],
hdrs = ["rendezvous_mgr_interface.h"],
deps = [
"//tensorflow/core:framework",
],
)
cc_library(
name = "master_session_interface",
srcs = [],
hdrs = ["master_session_interface.h"],
deps = ["//tensorflow/core:lib"],
)
cc_library(
name = "base_rendezvous_mgr",
srcs = ["base_rendezvous_mgr.cc"],
hdrs = ["base_rendezvous_mgr.h"],
deps = [
":process_util",
":rendezvous_mgr_interface",
":worker_cache",
":worker_env",
":worker_interface",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:tensorflow_opensource",
],
)
cc_library(
name = "master_env",
hdrs = ["master_env.h"],
)
cc_library(
name = "graph_mgr",
srcs = ["graph_mgr.cc"],
hdrs = ["graph_mgr.h"],
deps = [
":process_util",
":rendezvous_mgr_interface",
":worker_env",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow_opensource",
"//tensorflow/core:worker_proto_cc",
],
)
cc_library(
name = "process_util",
srcs = ["process_util.cc"],
hdrs = ["process_util.h"],
deps = [
"//tensorflow/core:core_cpu",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:tensorflow_opensource",
],
)
cc_library(
name = "worker_cache_partial",
srcs = ["worker_cache_partial.cc"],
hdrs = ["worker_cache_partial.h"],
deps = [
":process_util",
":worker_cache",
":worker_interface",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:worker_proto_cc",
],
)
cc_library(
name = "worker_cache_logger",
srcs = ["worker_cache_logger.cc"],
hdrs = ["worker_cache_logger.h"],
deps = [
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
)
# TODO(mrry): Move executor_test.cc to ../common_runtime when it no longer depends
# on grpc_testlib.
tf_cc_tests(
linkstatic = tf_kernel_tests_linkstatic(),
tags = tf_cuda_tests_tags(),
tests = [
"executor_test.cc",
"master_test.cc",
"remote_device_test.cc",
],
deps = [
"@grpc//:grpc++_unsecure",
":master",
":process_util",
":remote_device",
":worker_interface",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:master_proto_cc",
"//tensorflow/core:master_service_proto_cc",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow_opensource",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/distributed_runtime/rpc:grpc_channel",
"//tensorflow/core/distributed_runtime/rpc:grpc_testlib",
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache",
],
)

tensorflow/core/distributed_runtime/README.md
@@ -0,0 +1,197 @@
# Distributed TensorFlow
This directory contains the initial open-source implementation of the
distributed TensorFlow runtime, using [gRPC](http://grpc.io) for inter-process
communication.
## Quick start
To get started, you will need to build the TensorFlow server binary
(`grpc_tensorflow_server`) and a gRPC-based client. Currently this is only
available using the source-based installation of TensorFlow, but it will be
included in future binary releases. You can build the server binary using one of
the following commands:
```shell
# CPU-only build.
$ bazel build -c opt //tensorflow/core/distributed_runtime/rpc:grpc_tensorflow_server
# GPU build.
$ bazel build -c opt --config=cuda //tensorflow/core/distributed_runtime/rpc:grpc_tensorflow_server
```
If you build the latest Python (PIP) package from source, it will contain a
gRPC-based client. If you are using a previous binary release, you may need to
rebuild and install an up-to-date PIP package by following
[these installation instructions](https://www.tensorflow.org/versions/master/get_started/os_setup.html#create-the-pip-package-and-install).
Once you have successfully built the distributed TensorFlow components, you can
test your installation by starting a server as follows:
```shell
# Start a TensorFlow server as a single-process "cluster".
$ bazel-bin/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server \
--cluster_spec='local|localhost:2222' --job_name=local --task_index=0 &
```
...then start a Python interpreter and create a remote session:
```python
$ python
>>> import tensorflow as tf
>>> c = tf.constant("Hello, distributed TensorFlow!")
>>> sess = tf.Session("grpc://localhost:2222")
>>> sess.run(c)
'Hello, distributed TensorFlow!'
```
## Cluster definition
The command-line arguments to `grpc_tensorflow_server` define the membership of a TensorFlow cluster. The `--cluster_spec` flag determines the set of processes in the cluster, as a list of *jobs*, each of which contains a list of *task* endpoints. All processes in the cluster must be started with the same `--cluster_spec`. Example values include:
<table>
<tr><th><code>--cluster_spec='...'</code></th><th>Available tasks</th>
<tr>
<td><code>local|localhost:2222</code></td><td><code>/job:local/task:0</code></td>
</tr>
<tr>
<td><code>local|localhost:2222;localhost:2223</code></td><td><code>/job:local/task:0</code><br/><code>/job:local/task:1</code></td>
</tr>
<tr>
<td><code>worker|worker0:2222;worker1:2222;worker2:2222,</code><br/><code>ps|ps0:2222;ps1:2222</code></td><td><code>/job:worker/task:0</code><br/><code>/job:worker/task:1</code><br/><code>/job:worker/task:2</code><br/><code>/job:ps/task:0</code><br/><code>/job:ps/task:1</code></td>
</tr>
</table>
The `--job_name` and `--task_index` flags indicate which task will run in this
process, out of the jobs and tasks defined in `--cluster_spec`. For example,
`--job_name=local --task_index=0` means that the process will be task
`/job:local/task:0`, and TensorFlow devices in the process will have names
starting with that prefix.
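For example, the two-task `local` cluster in the second row of the table above
could be started by running the following commands (a sketch; the ports and
backgrounding with `&` are illustrative):

```shell
# Start task 0 of job "local", listening on port 2222.
$ bazel-bin/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server \
    --cluster_spec='local|localhost:2222;localhost:2223' --job_name=local --task_index=0 &
# Start task 1 of job "local", listening on port 2223. Note that
# --cluster_spec is identical for both tasks; only --task_index differs.
$ bazel-bin/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server \
    --cluster_spec='local|localhost:2222;localhost:2223' --job_name=local --task_index=1 &
```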
**N.B.** Manually specifying these command lines can be tedious, especially for
large clusters. We are working on tools for launching tasks programmatically,
e.g. using a cluster manager like [Kubernetes](http://kubernetes.io). If there
are particular cluster managers for which you'd like to see support, please
raise a [GitHub issue](https://github.com/tensorflow/tensorflow/issues).
## Specifying distributed devices in your model
To place operations on a particular process, you can use the same
[`tf.device()`](https://www.tensorflow.org/versions/master/api_docs/python/framework.html#device)
function that is used to specify whether ops run on the CPU or GPU. For example:
```python
with tf.device("/job:ps/task:0"):
weights_1 = tf.Variable(...)
biases_1 = tf.Variable(...)
with tf.device("/job:ps/task:1"):
weights_2 = tf.Variable(...)
biases_2 = tf.Variable(...)
with tf.device("/job:worker/task:7"):
input, labels = ...
layer_1 = tf.nn.relu(tf.matmul(input, weights_1) + biases_1)
logits = tf.nn.relu(tf.matmul(layer_1, weights_2) + biases_2)
# ...
train_op = ...
with tf.Session("grpc://worker7:2222") as sess:
for _ in range(10000):
sess.run(train_op)
```
In the above example, the variables are created on two tasks in the `ps` job,
and the compute-intensive part of the model is created in the `worker`
job. TensorFlow will insert the appropriate data transfers between the jobs
(from `ps` to `worker` for the forward pass, and from `worker` to `ps` for
applying gradients).
## Replicated training
A common training configuration ("data parallel training") involves multiple
tasks in a `worker` job training the same model, using shared parameters hosted
in one or more tasks in a `ps` job. Each task will typically run on a
different machine. There are many ways to specify this structure in TensorFlow,
and we are building libraries that will simplify the work of specifying a
replicated model. Possible approaches include:
* Building a single graph containing one set of parameters (in `tf.Variable`
nodes pinned to `/job:ps`), and multiple copies of the "model" pinned to
different tasks in `/job:worker`. Each copy of the model can have a different
`train_op`, and one or more client threads can call `sess.run(train_ops[i])`
for each worker `i`. This implements *asynchronous* training.
This approach uses a single `tf.Session` whose target is one of the workers in
the cluster (see the sketch after this list).
* As above, but where the gradients from all workers are averaged. See the
[CIFAR-10 multi-GPU trainer](https://www.tensorflow.org/code/tensorflow/models/image/cifar10/cifar10_multi_gpu_train.py)
for an example of this form of replication. This implements *synchronous* training.
* The "distributed trainer" approach uses multiple graphs&mdash;one per
worker&mdash;where each graph contains one set of parameters (pinned to
`/job:ps`) and one copy of the model (pinned to a particular
`/job:worker/task:i`). The "container" mechanism is used to share variables
between different graphs: when each variable is constructed, the optional
`container` argument is specified with the same value in each copy of the
graph. For large models, this can be more efficient, because the overall graph
is smaller.
This approach uses multiple `tf.Session` objects: one per worker process,
where the `target` of each is the address of a different worker. The
`tf.Session` objects can all be created in a single Python client, or you can
use multiple Python clients to better distribute the trainer load.
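The following sketch illustrates the first (single-graph, asynchronous)
approach, using the `worker`/`ps` cluster from the table above and a toy
linear model. The model, shapes, and optimizer are illustrative assumptions,
not part of the runtime:

```python
import tensorflow as tf

num_workers = 2

# One set of parameters, pinned to tasks in the "ps" job.
with tf.device("/job:ps/task:0"):
    weights = tf.Variable(tf.zeros([10, 1]))
with tf.device("/job:ps/task:1"):
    biases = tf.Variable(tf.zeros([1]))

# One copy of the toy model per worker task, each with its own train_op.
inputs, labels, train_ops = [], [], []
for i in range(num_workers):
    with tf.device("/job:worker/task:%d" % i):
        x = tf.placeholder(tf.float32, shape=[None, 10])
        y = tf.placeholder(tf.float32, shape=[None, 1])
        loss = tf.reduce_mean(tf.square(tf.matmul(x, weights) + biases - y))
        train_ops.append(
            tf.train.GradientDescentOptimizer(0.01).minimize(loss))
        inputs.append(x)
        labels.append(y)

# A single session whose target is one of the workers in the cluster;
# one client thread per worker i would call:
#   sess.run(train_ops[i], feed_dict={inputs[i]: ..., labels[i]: ...})
sess = tf.Session("grpc://localhost:2222")
```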
## Glossary
<dl>
<dt>Client</dt>
<dd>
A client is typically a program that builds a TensorFlow graph and
constructs a `tensorflow::Session` to interact with a cluster. Clients are
typically written in Python or C++. A single client process can directly
interact with multiple TensorFlow servers (see "Replicated training" above),
and a single server can serve multiple clients.
</dd>
<dt>Cluster</dt>
<dd>
A TensorFlow cluster comprises one or more TensorFlow servers, divided into
a set of named jobs, which in turn comprise lists of tasks. A cluster is
typically dedicated to a particular high-level objective, such as training a
neural network, using many machines in parallel.
</dd>
<dt>Job</dt>
<dd>
A job comprises a list of "tasks", which typically serve a common
purpose. For example, a job named `ps` (for "parameter server") typically
hosts nodes that store and update variables; while a job named `worker`
typically hosts stateless nodes that perform compute-intensive tasks.
The tasks in a job typically run on different machines.
</dd>
<dt>Master service</dt>
<dd>
An RPC service that provides remote access to a set of distributed
devices. The master service implements the <code>tensorflow::Session</code>
interface, and is responsible for coordinating work across one or more
"worker services".
</dd>
<dt>Task</dt>
<dd>
A task typically corresponds to a single TensorFlow server process,
belonging to a particular "job" and with a particular index within that
job's list of tasks.
</dd>
<dt>TensorFlow server</dt>
<dd>
A process running the <code>grpc_tensorflow_server</code> binary, which is a
member of a cluster, and exports a "master service" and "worker service".
</dd>
<dt>Worker service</dt>
<dd>
An RPC service that executes parts of a TensorFlow graph using its local
devices. A worker service implements <a
href="./worker_service.proto"><code>worker_service.proto</code></a>.
</dd>
</dl>

tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc
@@ -0,0 +1,318 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h"
#include <unordered_set>
#include <vector>
#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/distributed_runtime/process_util.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
BaseRendezvousMgr::BaseRendezvousMgr(const WorkerEnv* env) : worker_env_(env) {}
BaseRendezvousMgr::~BaseRendezvousMgr() {
for (auto& p : table_) {
BaseRemoteRendezvous* rendez = p.second;
rendez->StartAbort(errors::Aborted("Shutdown"));
rendez->Unref();
}
}
Rendezvous* BaseRendezvousMgr::Find(int64 step_id) {
return FindOrCreate(step_id);
}
BaseRemoteRendezvous* BaseRendezvousMgr::FindOrCreate(int64 step_id) {
mutex_lock l(mu_);
Table::iterator iter = table_.find(step_id);
if (iter == table_.end()) {
auto rr = Create(step_id, worker_env_);
iter = table_.insert({step_id, rr}).first;
}
iter->second->Ref();
return iter->second;
}
void BaseRendezvousMgr::RecvLocalAsync(int64 step_id, const string& key,
Rendezvous::DoneCallback done) {
BaseRemoteRendezvous* rendez = FindOrCreate(step_id);
rendez->RecvLocalAsync(
key, [rendez, done](const Status& s, const Rendezvous::Args& send_args,
const Rendezvous::Args& recv_args, const Tensor& v,
bool dead) {
rendez->Unref();
done(s, send_args, recv_args, v, dead);
});
}
Status BaseRendezvousMgr::RecvLocal(int64 step_id, const string& key,
Tensor* val, bool* is_dead) {
Status ret;
Notification n;
RecvLocalAsync(step_id, key,
[val, is_dead, &ret, &n](const Status& s,
const Rendezvous::Args& send_args,
const Rendezvous::Args& recv_args,
const Tensor& v, const bool dead) {
ret = s;
*val = v;
*is_dead = dead;
n.Notify();
});
n.WaitForNotification();
return ret;
}
void BaseRendezvousMgr::Cleanup(int64 step_id) {
Rendezvous* rendez = nullptr;
{
mutex_lock l(mu_);
Table::iterator iter = table_.find(step_id);
if (iter != table_.end()) {
rendez = iter->second;
table_.erase(iter);
}
}
if (!rendez) return;
rendez->StartAbort(errors::Aborted("Cleanup ", step_id));
rendez->Unref();
}
void BaseRendezvousMgr::CleanupAll() {
std::vector<Rendezvous*> rendezs;
{
mutex_lock l(mu_);
for (const auto& entry : table_) {
rendezs.push_back(entry.second);
}
table_.clear();
}
for (auto rendez : rendezs) {
rendez->StartAbort(errors::Aborted("Shutdown"));
rendez->Unref();
}
}
BaseRemoteRendezvous::BaseRemoteRendezvous(const WorkerEnv* env, int64 step_id,
bool tolerate_dup_recv)
: env_(env),
step_id_(step_id),
tolerate_dup_recv_(tolerate_dup_recv),
local_(NewLocalRendezvous(tolerate_dup_recv)) {}
BaseRemoteRendezvous::~BaseRemoteRendezvous() {
CHECK(active_.empty());
local_->Unref();
}
// Returns true if "device_name" is a valid full name of local device
// of the "worker". This helper is purely based on the worker name
// and device name and does no lookups in the worker->device_mgr.
static bool IsLocalDevice(const WorkerEnv& worker,
const StringPiece device_name) {
return device_name.starts_with(worker.worker_name);
}
Status BaseRemoteRendezvous::Send(const string& key,
const Rendezvous::Args& args,
const Tensor& val, const bool is_dead) {
VLOG(1) << "BaseRemoteRendezvous Send " << this << " " << key;
{
mutex_lock l(mu_);
if (!status_.ok()) return status_;
}
Rendezvous::ParsedKey parsed;
TF_RETURN_IF_ERROR(Rendezvous::ParseKey(key, &parsed));
if (!IsLocalDevice(*env_, parsed.src_device)) {
return errors::InvalidArgument("Invalid rendezvous key (src): ", key, " @ ",
env_->worker_name);
}
// Buffers "val" and "device_context" in local_.
return local_->Send(key, args, val, is_dead);
}
Status BaseRemoteRendezvous::ParseKey(const string& key, bool is_src,
Rendezvous::ParsedKey* parsed) {
{
mutex_lock l(mu_);
if (!status_.ok()) return status_;
}
TF_RETURN_IF_ERROR(Rendezvous::ParseKey(key, parsed));
if (is_src && !IsLocalDevice(*env_, parsed->src_device)) {
return errors::InvalidArgument("Invalid rendezvous key (src): ", key, " @ ",
env_->worker_name);
}
if (!is_src && !IsLocalDevice(*env_, parsed->dst_device)) {
return errors::InvalidArgument("Invalid rendezvous key (dst): ", key, " @ ",
env_->worker_name);
}
return Status::OK();
}
void BaseRemoteRendezvous::SameWorkerRecvDone(
const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& send_args,
const Rendezvous::Args& recv_args, const Tensor& in, Tensor* out,
StatusCallback done) {
// Do a quick copy (sharing the underlying buffer) if both tensors
// are on host memory.
const bool src_host =
(send_args.alloc_attrs.on_host() || parsed.src.type == "CPU");
const bool dst_host =
(recv_args.alloc_attrs.on_host() || parsed.dst.type == "CPU");
if (src_host && dst_host) {
*out = in;
done(Status::OK());
return;
}
// This copy must involve a GPU. Hence, "in" must support DMA
// (e.g., string tensors do not work on GPU).
if (!DMAHelper::CanUseDMA(&in)) {
done(errors::InvalidArgument("Non-DMA-safe ", DataTypeString(in.dtype()),
" tensor may not be copied from/to a GPU."));
return;
}
Device* src_device;
Status s = env_->device_mgr->LookupDevice(parsed.src_device, &src_device);
if (!s.ok()) {
done(s);
return;
}
Device* dst_device;
s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_device);
if (!s.ok()) {
done(s);
return;
}
AllocatorAttributes attr = recv_args.alloc_attrs;
attr.set_gpu_compatible(send_args.alloc_attrs.gpu_compatible() ||
recv_args.alloc_attrs.gpu_compatible());
Allocator* out_allocator = dst_device->GetAllocator(attr);
Tensor copy(out_allocator, in.dtype(), in.shape());
*out = copy;
// The following function takes care of cpu->gpu, gpu->cpu, gpu->gpu copies,
// etc.
CopyTensor::ViaDMA(parsed.edge_name, send_args.device_context,
recv_args.device_context, src_device, dst_device,
send_args.alloc_attrs, recv_args.alloc_attrs, &in, out,
done);
}
bool BaseRemoteRendezvous::IsSameWorker(DeviceNameUtils::ParsedName src,
DeviceNameUtils::ParsedName dst) {
return DeviceNameUtils::IsSameAddressSpace(src, dst);
}
void BaseRemoteRendezvous::RecvAsync(const string& key,
const Rendezvous::Args& recv_args,
DoneCallback done) {
VLOG(1) << "RemoteRendezvous Recv " << this << " " << key;
Rendezvous::ParsedKey parsed;
Status s = ParseKey(key, false /*!is_src*/, &parsed);
if (!s.ok()) {
done(s, Args(), recv_args, Tensor(), false);
return;
}
// Are src and dst in the same worker?
if (IsSameWorker(parsed.src, parsed.dst)) {
// Recv the tensor from local_.
local_->RecvAsync(
key, recv_args, [this, parsed, done](const Status& status,
const Rendezvous::Args& send_args,
const Rendezvous::Args& recv_args,
const Tensor& in, bool is_dead) {
Status s = status;
Tensor* out = new Tensor;
StatusCallback final_callback = [done, send_args, recv_args, out,
is_dead](const Status& s) {
done(s, send_args, recv_args, *out, is_dead);
delete out;
};
if (s.ok()) {
SameWorkerRecvDone(parsed, send_args, recv_args, in, out,
final_callback);
} else {
final_callback(s);
}
});
return;
} else {
RecvFromRemoteAsync(key, parsed, recv_args, done);
}
}
void BaseRemoteRendezvous::RecvLocalAsync(const string& key,
DoneCallback done) {
Rendezvous::ParsedKey parsed;
Status s = ParseKey(key, true /* is_src */, &parsed);
if (!s.ok()) {
done(s, Args(), Args(), Tensor(), false);
return;
}
local_->RecvAsync(key, Args(), done);
}
void BaseRemoteRendezvous::StartAbort(const Status& s) {
CHECK(!s.ok());
local_->StartAbort(s);
{
// Aborts all active RecvTensor calls.
mutex_lock l(mu_);
if (status_.ok()) {
status_ = s;
for (BaseRecvTensorCall* call : active_) {
call->StartAbort(s);
}
active_.clear();
}
}
}
void BaseRemoteRendezvous::RegisterCall(BaseRecvTensorCall* call) {
mutex_lock l(mu_);
if (!status_.ok()) {
call->StartAbort(status_);
} else {
CHECK(active_.insert(call).second);
}
}
void BaseRemoteRendezvous::DeregisterCall(BaseRecvTensorCall* call) {
mutex_lock l(mu_);
active_.erase(call);
}
} // end namespace tensorflow

tensorflow/core/distributed_runtime/base_rendezvous_mgr.h
@@ -0,0 +1,212 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_BASE_RENDEZVOUS_MGR_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_BASE_RENDEZVOUS_MGR_H_
#include <string>
#include <unordered_map>
#include <unordered_set>
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/control_flow.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
class BaseRemoteRendezvous;
class BaseRecvTensorCall;
// RendezvousMgr keeps track of a set of local rendezvous instances.
// All tensors sent by this worker are buffered in a RendezvousMgr
// until the tensor is received. Each globally unique "step_id"
// corresponds to one local rendezvous instance managed by a
// RendezvousMgr.
//
// E.g.,
// Rendezvous* rendez = worker_env->rendezvous_mgr->Find(0x8935);
// fork execution of a graph executor using "rendez" on thread 1;
// fork execution of another graph executor using "rendez" on thread 2;
// ...
// join threads 1 and 2;
//
// In the example above, execution in thread 1 and 2 communicates with
// each other by send/recv operations through `rendez`.
//
// Tensors sent and received through a rendezvous managed by this
// RendezvousMgr must have keys generated by Rendezvous::CreateKey().
class BaseRendezvousMgr : public RendezvousMgrInterface {
public:
explicit BaseRendezvousMgr(const WorkerEnv* worker_env);
~BaseRendezvousMgr() override;
// Returns Rendezvous supporting send and recv among workers in the
// "step_id". The caller takes ownership of one reference on the
// returned Rendezvous instance.
Rendezvous* Find(int64 step_id) override;
// Finds the local rendezvous instance for the "step_id". Runs
// "done" when the tensor for "key" is produced or an error occurs.
//
// This method is used by the rpc handler of RecvTensor.
void RecvLocalAsync(int64 step_id, const string& key,
Rendezvous::DoneCallback done) override;
// Synchronous wrapper for RecvLocalAsync.
Status RecvLocal(int64 step_id, const string& key, Tensor* val,
bool* is_dead) override;
// Removes rendezvous for "step_id".
//
// TODO(zhifengc): Have a background thread in worker that
// periodically calls CleanupAll().
void Cleanup(int64 step_id) override;
// Removes all rendezvous.
void CleanupAll() override;
protected:
virtual BaseRemoteRendezvous* Create(int64 step_id,
const WorkerEnv* worker_env) = 0;
private:
// Maps step_id to rendezvous.
typedef std::unordered_map<int64, BaseRemoteRendezvous*> Table;
// Not owned.
const WorkerEnv* const worker_env_;
mutex mu_;
Table table_ GUARDED_BY(mu_);
BaseRemoteRendezvous* FindOrCreate(int64 step_id);
TF_DISALLOW_COPY_AND_ASSIGN(BaseRendezvousMgr);
};
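// A hypothetical sketch (not part of this commit) of a concrete manager:
// a subclass only needs to supply the Create() factory, returning its own
// BaseRemoteRendezvous subclass. "MyRemoteRendezvous" is an assumed name;
// it would override RecvFromRemoteAsync() (see below).
//
//   class MyRendezvousMgr : public BaseRendezvousMgr {
//    public:
//     explicit MyRendezvousMgr(const WorkerEnv* env)
//         : BaseRendezvousMgr(env) {}
//
//    protected:
//     BaseRemoteRendezvous* Create(int64 step_id,
//                                  const WorkerEnv* worker_env) override {
//       return new MyRemoteRendezvous(worker_env, step_id,
//                                     false /* tolerate_dup_recv */);
//     }
//   };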
// RemoteRendezvous is a Rendezvous which can handle either
// the producer or consumer being in a remote process.
//
// Buffering of Tensor values is delegated to a "local" Rendezvous
// obtained from NewLocalRendezvous(). This class just adds
// functionality to coordinate with remote workers.
class BaseRemoteRendezvous : public Rendezvous {
public:
BaseRemoteRendezvous(const WorkerEnv* env, int64 step_id,
bool tolerate_dup_recv);
// Forwards to local_, where the Tensor "val" will be buffered and
// any waiting callback stored.
Status Send(const string& key, const Rendezvous::Args& args,
const Tensor& val, const bool is_dead) override;
// This method is called only by the RecvOp. It tests to see
// whether the value will be produced by a local or remote device
// and handles accordingly. In the local case it forwards to
// local_, in the remote case it initiates an RPC request.
void RecvAsync(const string& key, const Rendezvous::Args& args,
DoneCallback done) override;
void StartAbort(const Status& status) override;
// This method is called only by the local Worker, forwarded through
// the same method on RendezvousMgr. This occurs when the Worker
// has received a RecvTensor request, either locally or over the
// network. In either case it needs to retrieve a locally buffered
// value from local_, and give it to its caller.
//
// Runs "done" as soon as the tensor for "key" is available or an error
// is detected.
//
// REQUIRES: "key" is one that will be Saved into the local rendezvous.
void RecvLocalAsync(const string& key, DoneCallback done);
protected:
virtual void RecvFromRemoteAsync(const string& key,
const Rendezvous::ParsedKey& parsed,
const Rendezvous::Args& args,
DoneCallback done) = 0;
// Returns true if "src" and "dst" are located in the same worker,
// and hence may use a local rendezvous.
virtual bool IsSameWorker(DeviceNameUtils::ParsedName src,
DeviceNameUtils::ParsedName dst);
// If aborted, aborts "call". Otherwise, adds "call" into active_.
void RegisterCall(BaseRecvTensorCall* call);
// Removes "call" from active_ if "call" is in active_.
void DeregisterCall(BaseRecvTensorCall* call);
~BaseRemoteRendezvous() override;
const WorkerEnv* const env_; // Not owned.
const int64 step_id_;
private:
const bool tolerate_dup_recv_;
Rendezvous* local_; // Owns a Ref on this object.
mutable mutex mu_;
// Status given by StartAbort() if any.
Status status_ GUARDED_BY(mu_);
// Active outstanding RecvTensor calls.
std::unordered_set<BaseRecvTensorCall*> active_ GUARDED_BY(mu_);
// Parses "key" into "parsed". If "is_src" is true, checks that the
// rendezvous key's source is in this process. If "is_src" is false,
// checks that the rendezvous key's destination is in this process.
Status ParseKey(const string& key, bool is_src,
Rendezvous::ParsedKey* parsed);
// Callback handling the case when a rendezvous has been
// accomplished in local_ and the consumer is local to this process.
// Tensor "in" will be copied into "out". The key "parsed" encodes
// the src and dst devices.
void SameWorkerRecvDone(const Rendezvous::ParsedKey& parsed,
const Rendezvous::Args& in_args,
const Rendezvous::Args& out_args, const Tensor& in,
Tensor* out, StatusCallback done);
TF_DISALLOW_COPY_AND_ASSIGN(BaseRemoteRendezvous);
};
class BaseRecvTensorCall {
public:
BaseRecvTensorCall() {}
virtual ~BaseRecvTensorCall() {}
virtual void Start(std::function<void()> recv_done) = 0;
virtual void StartAbort(const Status& s) = 0;
virtual Status status() const = 0;
private:
TF_DISALLOW_COPY_AND_ASSIGN(BaseRecvTensorCall);
};
} // end namespace tensorflow
#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_BASE_RENDEZVOUS_MGR_H_

tensorflow/core/distributed_runtime/build_graph_options.cc
@@ -0,0 +1,38 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/build_graph_options.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace tensorflow {
string BuildGraphOptions::DebugString() const {
string rv = "Feed endpoints: ";
for (auto& s : feed_endpoints) {
strings::StrAppend(&rv, s, ", ");
}
strings::StrAppend(&rv, "\nFetch endpoints: ");
for (auto& s : fetch_endpoints) {
strings::StrAppend(&rv, s, ", ");
}
strings::StrAppend(&rv, "\nTarget nodes: ");
for (auto& s : target_nodes) {
strings::StrAppend(&rv, s, ", ");
}
return rv;
}
} // namespace tensorflow

tensorflow/core/distributed_runtime/build_graph_options.h
@@ -0,0 +1,38 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_BUILD_GRAPH_OPTIONS_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_BUILD_GRAPH_OPTIONS_H_
#include <vector>
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
struct BuildGraphOptions {
std::vector<string> feed_endpoints;
std::vector<string> fetch_endpoints;
// TODO(vrv): Remove this when we unify target_nodes and fetch_endpoint,
// the former via "ref" fetch_endpoints.
std::vector<string> target_nodes;
string DebugString() const;
};
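// A hypothetical usage sketch (not part of this commit): a master would
// typically fill in one BuildGraphOptions per client Run() call. The
// endpoint names here are assumptions for illustration:
//
//   BuildGraphOptions opts;
//   opts.feed_endpoints.push_back("x:0");
//   opts.fetch_endpoints.push_back("y:0");
//   opts.target_nodes.push_back("train_op");
//   VLOG(1) << opts.DebugString();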
} // namespace tensorflow
#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_BUILD_GRAPH_OPTIONS_H_

tensorflow/core/distributed_runtime/call_options.cc
@@ -0,0 +1,44 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/call_options.h"
#include "tensorflow/core/platform/mutex.h"
namespace tensorflow {
CallOptions::CallOptions() {}
void CallOptions::StartCancel() {
mutex_lock l(mu_);
if (cancel_func_ != nullptr) {
// NOTE: We must call the cancel_func_ with mu_ held. This ensures
// that ClearCancelCallback() does not race with StartCancel().
cancel_func_();
// NOTE: We can clear cancel_func_ if needed.
}
}
void CallOptions::SetCancelCallback(CancelFunction cancel_func) {
mutex_lock l(mu_);
cancel_func_ = cancel_func;
}
void CallOptions::ClearCancelCallback() {
mutex_lock l(mu_);
cancel_func_ = nullptr;
}
} // end namespace tensorflow

tensorflow/core/distributed_runtime/call_options.h
@@ -0,0 +1,72 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CALL_OPTIONS_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CALL_OPTIONS_H_
#include <functional>
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
// Options passed to interface calls. This class provides portable
// functionality across different RPC systems on top of
// platform-specific mechanisms (for client and server contexts,
// cancellation, etc.).
//
// TODO(zhifengc): Maybe change all RPC methods to take CallOptions.
class CallOptions {
public:
CallOptions();
// Cancellation.
//
// The caller may call StartCancel() anytime as long as this
// CallOptions object is alive. The callee may or may not receive
// the cancellation notification depending on the rpc layer
// implementation.
void StartCancel();
// The callee (the rpc layer implementation) must set a cancellation
// notifier before its blocking operation and clear the notifier
// before the call returns.
//
// "cancel_func" may be called zero, once or more time. Therefore, it
// should _not_ be responsible for memory management of any objects.
//
// "cancel_func" must be very light-weight. It should not block on
// IO or locking. Typically, it just calls the rpc implementation
// layer's specific cancellation mechanism and does nothing else.
//
// NOTE: "cancel_func" itself is pass-by-value. Therefore, we do not
// worry about its ownership here.
typedef std::function<void()> CancelFunction;
void SetCancelCallback(CancelFunction cancel_func);
void ClearCancelCallback();
private:
mutex mu_;
CancelFunction cancel_func_ GUARDED_BY(mu_);
TF_DISALLOW_COPY_AND_ASSIGN(CallOptions);
};
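// A hypothetical sketch (not part of this commit) of how an RPC layer
// might honor the contract above around a blocking call. "RpcState",
// TryCancel(), and Wait() are assumed names, not real APIs:
//
//   Status BlockingCall(CallOptions* opts, RpcState* state) {
//     // Install the cancellation notifier before blocking...
//     opts->SetCancelCallback([state]() { state->TryCancel(); });
//     Status s = state->Wait();
//     // ...and clear it before the call returns.
//     opts->ClearCancelCallback();
//     return s;
//   }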
} // namespace tensorflow
#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CALL_OPTIONS_H_

tensorflow/core/distributed_runtime/call_options_test.cc
@@ -0,0 +1,39 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/call_options.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
TEST(CallOptions, Cancel) {
int num_calls = 0;
CallOptions opts;
opts.StartCancel();
EXPECT_EQ(num_calls, 0);
opts.SetCancelCallback([&num_calls]() { num_calls++; });
EXPECT_EQ(num_calls, 0);
opts.StartCancel();
EXPECT_EQ(num_calls, 1);
opts.StartCancel();
EXPECT_EQ(num_calls, 2);
opts.ClearCancelCallback();
EXPECT_EQ(num_calls, 2);
opts.StartCancel();
EXPECT_EQ(num_calls, 2);
}
} // namespace tensorflow

tensorflow/core/distributed_runtime/executor_test.cc
@@ -0,0 +1,407 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/distributed_runtime/process_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
class ExecutorTest : public ::testing::Test {
protected:
ExecutorTest()
: device_(DeviceFactory::NewDevice("CPU", {},
"/job:localhost/replica:0/task:0")),
step_stats_collector_(&step_stats_) {
SessionOptions options;
thread_pool_ = ComputePool(options);
}
~ExecutorTest() override {
// There should always be exactly one Ref left on the Rendezvous
// when the test completes.
CHECK(rendez_->Unref());
delete exec_;
delete device_;
}
// Resets exec_ with a new executor based on "graph".
void Create(const Graph* graph) {
const int version = graph->versions().producer();
LocalExecutorParams params;
params.device = device_;
params.create_kernel = [this, version](const NodeDef& ndef,
OpKernel** kernel) {
return CreateNonCachedKernel(device_, nullptr, ndef, version, kernel);
};
params.delete_kernel = [](OpKernel* kernel) {
DeleteNonCachedKernel(kernel);
};
delete exec_;
TF_CHECK_OK(NewLocalExecutor(params, graph, &exec_));
runner_ = [this](std::function<void()> fn) { thread_pool_->Schedule(fn); };
rendez_ = NewLocalRendezvous();
}
Status Run(Rendezvous* rendez) {
Executor::Args args;
args.rendezvous = rendez;
args.stats_collector = &step_stats_collector_;
args.runner = runner_;
return exec_->Run(args);
}
thread::ThreadPool* thread_pool_ = nullptr;
Device* device_ = nullptr;
Executor* exec_ = nullptr;
StepStatsCollector step_stats_collector_;
StepStats step_stats_;
Executor::Args::Runner runner_;
Rendezvous* rendez_ = nullptr;
};
// A float val -> Tensor<float>
Tensor V(const float val) {
Tensor tensor(DT_FLOAT, TensorShape({}));
tensor.scalar<float>()() = val;
return tensor;
}
// A int32 val -> Tensor<int32>
Tensor VI(const int32 val) {
Tensor tensor(DT_INT32, TensorShape({}));
tensor.scalar<int32>()() = val;
return tensor;
}
// A bool val -> Tensor<bool>
Tensor VB(const bool val) {
Tensor tensor(DT_BOOL, TensorShape({}));
tensor.scalar<bool>()() = val;
return tensor;
}
// A double val -> Tensor<double>
Tensor VD(const double val) {
Tensor tensor(DT_DOUBLE, TensorShape({}));
tensor.scalar<double>()() = val;
return tensor;
}
// Tensor<float> -> a float val.
float V(const Tensor& tensor) {
CHECK_EQ(tensor.dtype(), DT_FLOAT);
CHECK(TensorShapeUtils::IsScalar(tensor.shape()));
return tensor.scalar<float>()();
}
static uint64 kIncarnation = 1; // Used in the following tests.
string Key(const string& sender, const uint64 incarnation,
const string& receiver, const string& name) {
return Rendezvous::CreateKey(sender, incarnation, receiver, name,
FrameAndIter(0, 0));
}
#define ALICE "/job:j/replica:0/task:0/cpu:0"
#define BOB "/job:j/replica:0/task:0/gpu:0"
TEST_F(ExecutorTest, SimpleAdd) {
// c = a + b
Graph* g = new Graph(OpRegistry::Global());
auto in0 = test::graph::Recv(g, "a", "float", ALICE, 1, BOB);
auto in1 = test::graph::Recv(g, "b", "float", ALICE, 1, BOB);
auto tmp = test::graph::Add(g, in0, in1);
test::graph::Send(g, tmp, "c", BOB, 1, ALICE);
Create(g);
Rendezvous::Args args;
TF_ASSERT_OK(rendez_->Send(Key(ALICE, kIncarnation, BOB, "a"), args, V(1.0),
false)); // in0 = 1.0
TF_ASSERT_OK(rendez_->Send(Key(ALICE, kIncarnation, BOB, "b"), args, V(1.0),
false)); // in1 = 1.0
TF_ASSERT_OK(Run(rendez_));
Tensor out = V(-1);
bool is_dead = false;
TF_ASSERT_OK(
rendez_->Recv(Key(BOB, kIncarnation, ALICE, "c"), args, &out, &is_dead));
EXPECT_EQ(2.0, V(out)); // out = 1.0 + 1.0 = 2.0
}
TEST_F(ExecutorTest, SelfAdd) {
// v0 <- a
// v1 = v0 + v0
// v2 = v1 + v1
// ... ...
// v10 = v9 + v9
//
// b <- v10
// All nodes are executed by one thread.
Graph* g = new Graph(OpRegistry::Global());
auto v = test::graph::Recv(g, "a", "float", ALICE, 1, BOB);
const int N = 10;
for (int i = 1; i <= N; ++i) {
v = test::graph::Add(g, v, v);
}
// out <- v10
test::graph::Send(g, v, "b", BOB, 1, ALICE);
Create(g);
Rendezvous::Args args;
// a = 1.0
TF_ASSERT_OK(
rendez_->Send(Key(ALICE, kIncarnation, BOB, "a"), args, V(1.0), false));
TF_ASSERT_OK(Run(rendez_));
Tensor out = V(-1);
bool is_dead = false;
TF_ASSERT_OK(
rendez_->Recv(Key(BOB, kIncarnation, ALICE, "b"), args, &out, &is_dead));
EXPECT_EQ(1024.0, V(out)); // b=v10=2*v9=4*v8=...=1024*a=1024.0
}
// Builds a graph which adds N copies of one variable "in". I.e.,
// a + a + a + ... + a
// The returned graph is parenthesized randomly. I.e.,
// a + ((a + a) + a)
// (a + a) + (a + a)
// ((a + a) + a) + a
// are all possibly generated.
void BuildTree(int N, Graph* g) {
CHECK_GT(N, 1);
// A single input node "in".
auto in = test::graph::Recv(g, "a", "float", ALICE, 1, BOB);
std::vector<Node*> nodes;
int i = 0;
// Duplicate "in" N times. Each copies is named as l0, l1, l2, ....
for (; i < N; ++i) {
nodes.push_back(test::graph::Identity(g, in, 0));
}
random::PhiloxRandom philox(testing::RandomSeed(), 17);
random::SimplePhilox rnd(&philox);
while (nodes.size() > 1) {
// Randomly pick two from nodes and add them. The resulting node
// is named like n10, n11, ..., and is put back into "nodes".
int x = rnd.Uniform(nodes.size());
auto in0 = nodes[x];
nodes[x] = nodes.back();
nodes.resize(nodes.size() - 1);
x = rnd.Uniform(nodes.size());
auto in1 = nodes[x];
// node = in0 + in1.
nodes[x] = test::graph::Add(g, in0, in1);
}
// The final output node "out".
test::graph::Send(g, nodes.back(), "b", BOB, 1, ALICE);
}
TEST_F(ExecutorTest, RandomTree) {
Graph* g = new Graph(OpRegistry::Global());
BuildTree(4096, g);
Create(g);
Rendezvous::Args args;
TF_ASSERT_OK(
rendez_->Send(Key(ALICE, kIncarnation, BOB, "a"), args, V(1.0), false));
TF_ASSERT_OK(Run(rendez_));
Tensor out = V(-1);
bool is_dead = false;
TF_ASSERT_OK(
rendez_->Recv(Key(BOB, kIncarnation, ALICE, "b"), args, &out, &is_dead));
EXPECT_EQ(4096.0, V(out));
}
void BuildConcurrentAddAssign(Graph* g) {
auto one = test::graph::Constant(g, V(1.0));
// A variable holds one float.
auto var = test::graph::Var(g, DT_FLOAT, TensorShape({}));
// Initialize the variable with 1.0.
auto init = test::graph::Assign(g, var, one);
// Output
auto out = test::graph::Send(g, var, "out", ALICE, kIncarnation, BOB);
// Run many concurrent computations; each does v = v + 1.
for (int i = 0; i < 1024; ++i) {
auto add = test::graph::Add(g, var, one);
g->AddControlEdge(init, add); // Ensures run after init.
auto assign = test::graph::Assign(g, var, add);
g->AddControlEdge(assign, out);
}
}
#ifndef THREAD_SANITIZER
TEST_F(ExecutorTest, ConcurrentAddAssign) {
Graph* g = new Graph(OpRegistry::Global());
BuildConcurrentAddAssign(g);
Create(g);
for (int iters = 0; iters < 16; ++iters) {
Rendezvous* rendez = NewLocalRendezvous();
TF_ASSERT_OK(Run(rendez));
Rendezvous::Args args;
Tensor out;
bool is_dead;
TF_ASSERT_OK(rendez->Recv(Key(ALICE, kIncarnation, BOB, "out"), args, &out,
&is_dead));
VLOG(1) << "Get " << V(out);
EXPECT_LE(V(out), 1025.0);
rendez->Unref();
}
}
#endif
TEST_F(ExecutorTest, SimpleSwitchLive) {
Graph* g = new Graph(OpRegistry::Global());
auto in0 = test::graph::Recv(g, "a", "float", ALICE, 1, BOB);
auto in1 = test::graph::Constant(g, VB(false));
auto tmp = test::graph::Switch(g, in0, in1);
test::graph::Send(g, tmp, "c", BOB, 1, ALICE);
Create(g);
Rendezvous::Args args;
TF_ASSERT_OK(rendez_->Send(Key(ALICE, kIncarnation, BOB, "a"), args, V(1.0),
false)); // in0 = 1.0
TF_ASSERT_OK(Run(rendez_));
Tensor out = V(-1);
bool is_dead = false;
TF_ASSERT_OK(
rendez_->Recv(Key(BOB, kIncarnation, ALICE, "c"), args, &out, &is_dead));
EXPECT_EQ(1.0, V(out)); // out = 1.0
EXPECT_FALSE(is_dead);
}
TEST_F(ExecutorTest, SimpleSwitchDead) {
Graph* g = new Graph(OpRegistry::Global());
auto in0 = test::graph::Recv(g, "a", "float", ALICE, 1, BOB);
auto in1 = test::graph::Constant(g, VB(true));
auto tmp = test::graph::Switch(g, in0, in1);
test::graph::Send(g, tmp, "c", BOB, 1, ALICE);
Create(g);
Rendezvous::Args args;
TF_ASSERT_OK(rendez_->Send(Key(ALICE, kIncarnation, BOB, "a"), args, V(1.0),
false)); // in0 = 1.0
TF_ASSERT_OK(Run(rendez_));
Tensor out = V(-1);
bool is_dead = false;
TF_ASSERT_OK(
rendez_->Recv(Key(BOB, kIncarnation, ALICE, "c"), args, &out, &is_dead));
EXPECT_TRUE(is_dead);
}
TEST_F(ExecutorTest, Abort) {
// e = a + b + c + d
Graph* g = new Graph(OpRegistry::Global());
auto in0 = test::graph::Recv(g, "a", "float", ALICE, 1, BOB);
auto in1 = test::graph::Recv(g, "b", "float", ALICE, 1, BOB);
auto in2 = test::graph::Recv(g, "c", "float", ALICE, 1, BOB);
auto in3 = test::graph::Recv(g, "d", "float", ALICE, 1, BOB);
auto add0 = test::graph::Add(g, in0, in1);
auto add1 = test::graph::Add(g, in2, in3);
auto add2 = test::graph::Add(g, add0, add1);
test::graph::Send(g, add2, "e", BOB, 1, ALICE);
Create(g);
// Needs 4 inputs (recv). One of them is aborted.
rendez_->Ref();
SchedClosure([this]() {
Env::Default()->SleepForMicroseconds(100 * 1000);
Status s = rendez_->Send(Key(ALICE, kIncarnation, BOB, "a"),
Rendezvous::Args(), V(1.0), false);
rendez_->Unref();
});
rendez_->Ref();
SchedClosure([this]() {
Env::Default()->SleepForMicroseconds(100 * 1000);
Status s = rendez_->Send(Key(ALICE, kIncarnation, BOB, "b"),
Rendezvous::Args(), V(1.0), false);
rendez_->Unref();
});
rendez_->Ref();
SchedClosure([this]() {
Env::Default()->SleepForMicroseconds(100 * 1000);
Status s = rendez_->Send(Key(ALICE, kIncarnation, BOB, "c"),
Rendezvous::Args(), V(1.0), false);
rendez_->Unref();
});
rendez_->Ref();
SchedClosure([this]() {
Env::Default()->SleepForMicroseconds(100 * 1000);
rendez_->StartAbort(errors::Aborted(""));
rendez_->Unref();
});
EXPECT_TRUE(errors::IsAborted(Run(rendez_)));
Tensor out = V(-1);
bool is_dead = false;
EXPECT_TRUE(errors::IsAborted(rendez_->Recv(
Key(BOB, kIncarnation, ALICE, "c"), Rendezvous::Args(), &out, &is_dead)));
// At this point there can still be pending (albeit Aborted) Send
// closures holding Refs on rendez_. We need to wait for them, or
// else there can be a memory leak at termination.
while (!rendez_->RefCountIsOne())
;
}
TEST_F(ExecutorTest, RecvInvalidDtype) {
Graph* g = new Graph(OpRegistry::Global());
// An input vector of type float of size 1.
auto one = test::graph::Recv(g, "one", "float", ALICE, 1, BOB);
// A floating point variable vector of size 1.
auto var = test::graph::Var(g, DT_FLOAT, TensorShape({1}));
// Initialize the variable with input.
auto init = test::graph::Assign(g, var, one);
// Output
auto* two = test::graph::Send(g, var, "two", BOB, 1, ALICE);
g->AddControlEdge(init, two); // Ensures run after init.
Create(g);
Rendezvous* rendez = NewLocalRendezvous();
// Send a double instead of float.
TF_ASSERT_OK(rendez->Send(Key(ALICE, 1, BOB, "one"), Rendezvous::Args(),
VD(1.0), false));
// Fails due to invalid dtype.
EXPECT_TRUE(errors::IsInternal(Run(rendez)));
Tensor output;
bool is_dead;
EXPECT_TRUE(errors::IsInternal(rendez->Recv(
Key(BOB, 1, ALICE, "two"), Rendezvous::Args(), &output, &is_dead)));
rendez->Unref();
}
TEST_F(ExecutorTest, RecvInvalidRefDtype) {
Graph* g = new Graph(OpRegistry::Global());
// A var that always produces an invalid dtype.
auto var = test::graph::InvalidRefType(g, DT_FLOAT, DT_DOUBLE);
test::graph::Send(g, var, "out", BOB, 1, ALICE);
Create(g);
Rendezvous* rendez = NewLocalRendezvous();
EXPECT_TRUE(errors::IsInternal(Run(rendez)));
Tensor output;
bool is_dead;
EXPECT_TRUE(errors::IsInternal(rendez->Recv(
Key(BOB, 1, ALICE, "out"), Rendezvous::Args(), &output, &is_dead)));
rendez->Unref();
}
} // namespace tensorflow

tensorflow/core/distributed_runtime/graph_mgr.cc
@@ -0,0 +1,368 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/graph_mgr.h"
#include <vector>
#include "tensorflow/core/common_runtime/constant_folding.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/common_runtime/memory_types.h"
#include "tensorflow/core/distributed_runtime/process_util.h"
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/config.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_partition.h"
#include "tensorflow/core/graph/validate.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/worker.pb.h"
namespace tensorflow {
GraphMgr::GraphMgr(const WorkerEnv* worker_env)
: worker_env_(worker_env), table_(5) {}
GraphMgr::~GraphMgr() {
for (auto p : table_) p.second->Unref();
}
GraphMgr::Item::~Item() {
for (const auto& unit : this->units) {
CHECK_NOTNULL(unit.device);
delete unit.root;
delete unit.lib;
unit.device->op_segment()->RemoveHold(this->session);
}
delete this->lib_def;
}
// NOTE: node->device_name() is not set by GraphConstructor. We
// expect that the NodeDefs in the GraphDef given to workers fully
// specify device names.
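// E.g., a fully specified device name looks like
// "/job:worker/replica:0/task:7/cpu:0".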
static string SplitByDevice(const Node* node) {
return node->assigned_device_name();
}
// Validates "gdef" device specifications.
static Status ValidateGraphDefForDevices(const GraphDef& gdef) {
DeviceNameUtils::ParsedName parsed;
for (const auto& ndef : gdef.node()) {
if (!DeviceNameUtils::ParseFullName(ndef.device(), &parsed)) {
return errors::InvalidArgument("Missing device name in: ",
SummarizeNodeDef(ndef));
}
}
return Status::OK();
}
// Creates executors given a graph definition "gdef" of a "session".
// If a node in "gdef" is shared by other graphs in "session", the
// same op kernel is reused. E.g., typically a params node is shared
// by multiple graphs in a session.
//
// If "gdef" is assigned to multiple devices, extra nodes (e.g.,
// send/recv nodes) may be added. The extra nodes' names are generated
// by calling "new_name(old_name)".
//
// On success, "item->units" is filled with one executor per device,
// and the item takes ownership of those executors.
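// For illustration (a sketch using names from this file): after a
// successful InitItem(), each entry of "item->units" pairs a device
// with a root executor, which ExecuteAsync() later drives via
// unit.root->RunAsync(args, barrier->Get()).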
Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
const GraphOptions& graph_options, Item* item) {
item->session = session;
item->lib_def = new FunctionLibraryDefinition(gdef.library());
TF_RETURN_IF_ERROR(ValidateGraphDefForDevices(gdef));
if (gdef.versions().producer() >= 5) {
// Validate the graph: we assume that merging two valid graphs
// should maintain graph validity.
TF_RETURN_IF_ERROR(graph::ValidateGraphDef(gdef, *item->lib_def));
}
// Constructs the graph out of "gdef".
Graph graph(item->lib_def);
GraphConstructorOptions opts;
opts.allow_internal_ops = true;
opts.expect_device_spec = true;
TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, gdef, &graph));
// Splits "graph" into multiple subgraphs by device names.
std::unordered_map<string, GraphDef> partitions;
PartitionOptions popts;
popts.node_to_loc = SplitByDevice;
popts.new_name = [this](const string& prefix) {
mutex_lock l(mu_);
return strings::StrCat(prefix, "_G", next_id_++);
};
popts.get_incarnation = [this](const string& name) {
Device* device = nullptr;
Status s = worker_env_->device_mgr->LookupDevice(name, &device);
if (s.ok()) {
return device->attributes().incarnation();
} else {
return PartitionOptions::kIllegalIncarnation;
}
};
popts.control_flow_added = true;
popts.scheduling_for_recvs = graph_options.enable_recv_scheduling();
TF_RETURN_IF_ERROR(Partition(popts, &graph, &partitions));
if (popts.scheduling_for_recvs) {
TF_RETURN_IF_ERROR(AddControlEdges(popts, &partitions));
}
thread::ThreadPool* pool = worker_env_->compute_pool;
auto runner = [pool](std::function<void()> fn) { pool->Schedule(fn); };
LocalExecutorParams params;
Status s;
item->units.reserve(partitions.size());
const auto& optimizer_opts = graph_options.optimizer_options();
GraphOptimizer optimizer(optimizer_opts);
for (auto&& p : partitions) {
const string& device_name = p.first;
GraphDef* def = &p.second;
item->units.resize(item->units.size() + 1);
ExecutionUnit* unit = &(item->units.back());
// Find the device.
s = worker_env_->device_mgr->LookupDevice(device_name, &unit->device);
if (!s.ok()) break;
// Construct the subgraph.
Graph* subgraph = new Graph(item->lib_def);
// Give the device an opportunity to rewrite its subgraph.
unit->device->MaybeRewriteGraph(gdef.library(), def);
s = ConvertGraphDefToGraph(opts, *def, subgraph);
if (!s.ok()) {
delete subgraph;
break;
}
// Top-level nodes in the graph use the op segment to cache
// kernels. Therefore, as long as the executor is alive, we need
// to ensure the kernels cached for the session stay alive.
auto opseg = unit->device->op_segment();
opseg->AddHold(session);
// Function library runtime.
unit->lib = NewFunctionLibraryRuntime(
unit->device, runner, def->versions().producer(), item->lib_def,
graph_options.optimizer_options());
// Construct the root executor for the subgraph.
params.device = unit->device;
auto lib = unit->lib;
params.function_library = lib;
params.create_kernel = [session, lib, opseg](const NodeDef& ndef,
OpKernel** kernel) {
// Caches the kernel only if the node is stateful.
if (!lib->IsStateful(ndef.op())) {
return lib->CreateKernel(ndef, kernel);
}
auto create_fn = [lib, &ndef](OpKernel** kernel) {
return lib->CreateKernel(ndef, kernel);
};
// Kernels created for subgraph nodes need to be cached. On
// cache miss, create_fn() is invoked to create a kernel based
// on the function library here + global op registry.
return opseg->FindOrCreate(session, ndef.name(), kernel, create_fn);
};
params.delete_kernel = [lib](OpKernel* kernel) {
// If the node is stateful, opseg owns it. Otherwise, delete it.
if (kernel && !lib->IsStateful(kernel->type_string())) {
delete kernel;
}
};
optimizer.Optimize(lib, &subgraph);
s = ValidateMemoryTypes(DeviceType(unit->device->device_type()), subgraph);
if (!s.ok()) {
delete subgraph;
break;
}
s = NewLocalExecutor(params, subgraph, &unit->root);
if (!s.ok()) {
break;
}
}
return s;
}
Status GraphMgr::Register(const string& session, const GraphDef& gdef,
const GraphOptions& graph_options, string* handle) {
Item* item = new Item;
Status s = InitItem(session, gdef, graph_options, item);
if (!s.ok()) {
item->Unref();
return s;
}
// Inserts one item into table_.
{
mutex_lock l(mu_);
*handle = strings::Printf("%016llx", ++next_id_);
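// E.g., with next_id_ == 0 this produces the handle
// "0000000000000001".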
item->handle = *handle;
CHECK(table_.insert({*handle, item}).second);
}
return Status::OK();
}
Status GraphMgr::Deregister(const string& handle) {
Item* item = nullptr;
// Removes one item from table_.
{
mutex_lock l(mu_);
auto iter = table_.find(handle);
if (iter == table_.end()) {
return errors::Aborted("Graph handle is not found: ", handle,
". Possibly, this worker just restarted.");
}
item = iter->second;
table_.erase(iter);
}
item->Unref();
return Status::OK();
}
Status GraphMgr::DeregisterAll() {
std::vector<Item*> items;
// Removes all items from table_.
{
mutex_lock l(mu_);
for (const auto& entry : table_) {
items.push_back(entry.second);
}
table_.clear();
}
for (auto item : items) {
item->Unref();
}
return Status::OK();
}
Status GraphMgr::Execute(const string& handle, const int64 step_id,
const ExecutorOpts& opts,
StepStatsCollector* collector,
CancellationManager* cancellation_manager,
const NamedTensors& in, NamedTensors* out) {
Notification n;
Status status;
ExecuteAsync(handle, step_id, opts, collector, cancellation_manager, in, out,
[&n, &status](const Status& s) {
status = s;
n.Notify();
});
n.WaitForNotification();
return status;
}
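// NOTE: Execute() above is the usual sync-over-async pattern: it
// blocks on a Notification that the ExecuteAsync() callback fires.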
void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
const ExecutorOpts& opts,
StepStatsCollector* collector,
CancellationManager* cancellation_manager,
const NamedTensors& in, NamedTensors* out,
StatusCallback done) {
// Looks up the item and holds one ref on it while executing.
Item* item = nullptr;
{
mutex_lock l(mu_);
auto iter = table_.find(handle);
if (iter != table_.end()) {
item = iter->second;
item->Ref();
}
}
if (item == nullptr) {
done(errors::Aborted("Graph handle is not found: ", handle));
return;
}
const int num_units = item->units.size();
CHECK_GE(num_units, 1);
Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
// Sends values specified by the caller.
for (const auto& p : in) {
const string& key = p.first;
const Tensor& val = p.second;
const Status s = rendezvous->Send(key, Rendezvous::Args(), val, false);
if (!s.ok()) {
done(s);
item->Unref();
rendezvous->Unref();
return;
}
}
// Starts parallel Executors.
//
// NOTE: Transfer one ref of rendezvous and one ref of item to
// RunAllDone.
ExecutorBarrier* barrier = new ExecutorBarrier(
num_units, rendezvous, std::bind(&ME::RunAllDone, this, item, rendezvous,
out, done, std::placeholders::_1));
Executor::Args args;
{
mutex_lock l(mu_);
args.step_id = ++next_id_;
}
args.rendezvous = rendezvous;
args.cancellation_manager = cancellation_manager;
args.stats_collector = collector;
VLOG(1) << "Step " << args.step_id << " is for handle " << handle
<< ", graph-local step " << step_id;
thread::ThreadPool* pool = worker_env_->compute_pool;
args.runner = [pool](std::function<void()> fn) { pool->Schedule(fn); };
for (const auto& unit : item->units) {
unit.root->RunAsync(args, barrier->Get());
}
}
void GraphMgr::RunAllDone(Item* item, Rendezvous* rendezvous, NamedTensors* out,
StatusCallback done, Status s) {
if (s.ok()) {
// Receives values requested by the caller.
for (auto& p : *out) {
const string& key = p.first;
Tensor* val = &p.second;
bool is_dead = false;
s = rendezvous->Recv(key, Rendezvous::Args(), val, &is_dead);
if (is_dead) {
s = errors::InvalidArgument("The tensor returned for ", key,
" was not valid.");
}
if (!s.ok()) break;
}
}
done(s);
rendezvous->Unref();
item->Unref();
}
} // end namespace tensorflow

View File

@ -0,0 +1,147 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_GRAPH_MGR_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_GRAPH_MGR_H_
#include <unordered_map>
#include <vector>
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/config.pb.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
class ExecutorOpts;
class StepStatsCollector;
// GraphMgr keeps track of a set of graphs that are registered with a
// TensorFlow worker. Each registered graph is identified by a handle
// that is generated by GraphMgr and returned to the caller.
//
// After a successful registration, the caller executes a graph using
// the graph handle. Each execution is distinguished from others by a
// caller generated global unique id "step_id". Multiple executions
// can use the same graph concurrently and independently as long as
// "step_id" used are different.
//
// Multiple threads can call GraphMgr methods concurrently.
//
// E.g.,
// GraphMgr gmgr(worker_env);
// string handle;
// TF_CHECK_OK(gmgr.Register("session", { graph computes c = a + b },
// &handle));
// GraphMgr::NamedTensors in = { { "a", Tensor({1, 2}) },
// { "b", Tensor({3, 4}) } };
// GraphMgr::NamedTensors out = { { "c", Tensor() } };
// TF_CHECK_OK(gmgr.Execute(handle, 0x0001, in, &out));
// EXPECT_EQ(out["c"], Tensor({4, 6}));
class GraphMgr {
public:
explicit GraphMgr(const WorkerEnv* worker_env);
~GraphMgr();
// Registers a graph. Fills in "handle" on success.
Status Register(const string& session, const GraphDef& gdef,
const GraphOptions& graph_options, string* handle);
typedef std::map<string, Tensor> NamedTensors;
typedef std::function<void(const Status&)> StatusCallback;
// Executes one step of a registered graph "handle".
//
// If "out" is not nullptr, "out" specifies all keys the execution
// should receive upon finish.
void ExecuteAsync(const string& handle, const int64 step_id,
const ExecutorOpts& opts, StepStatsCollector* collector,
CancellationManager* cancellation_manager,
const NamedTensors& in, NamedTensors* out,
StatusCallback done);
// Synchronous wrapper for ExecuteAsync(): blocks until the step completes.
Status Execute(const string& handle, const int64 step_id,
const ExecutorOpts& opts,
StepStatsCollector* step_stats_collector,
CancellationManager* cancellation_manager,
const NamedTensors& in, NamedTensors* out);
// Deregisters a graph.
Status Deregister(const string& handle);
// Deregisters all graphs.
Status DeregisterAll();
private:
typedef GraphMgr ME;
struct ExecutionUnit {
Device* device = nullptr;
Executor* root = nullptr;
FunctionLibraryRuntime* lib = nullptr;
};
struct Item : public core::RefCounted {
// TODO(zhifengc): Keep a copy of the original graph if the need arises.
// TODO(zhifengc): Stats, updated by multiple runs potentially.
// TODO(zhifengc): Dup-detection. Ensure each step_id runs only once.
~Item() override;
// Session handle.
string session;
// Graph handle.
string handle;
// The definition of the library is shared by all partitions.
FunctionLibraryDefinition* lib_def = nullptr;
// A graph is partitioned over multiple devices. Each partition
// has a root executor which may call into the runtime library.
std::vector<ExecutionUnit> units;
};
// Not owned.
const WorkerEnv* worker_env_;
// Owned.
mutex mu_;
int64 next_id_ GUARDED_BY(mu_) = 0;
// Table mapping graph handles to registered graphs.
//
// TODO(zhifengc): If the client does not call Deregister, we'll
// lose memory over time. We should implement a timeout-based
// mechanism to gc these graphs.
std::unordered_map<string, Item*> table_;
void RunAllDone(Item* item, Rendezvous* rendezvous, NamedTensors* out,
StatusCallback done, Status run_status);
Status InitItem(const string& session, const GraphDef& gdef,
const GraphOptions& graph_options, Item* item);
TF_DISALLOW_COPY_AND_ASSIGN(GraphMgr);
};
} // end namespace tensorflow
#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_GRAPH_MGR_H_

View File

@ -0,0 +1,413 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Master implements the service MasterService.
//
// A Master maintains the state of live graph computation
// sessions; each session orchestrates both local and remote devices
// to carry out the graph computation.
//
// A Master knows ahead of time which local devices are available
// to serve as client devices.
//
// A Master discovers remote devices on-demand and keeps track of
// statistics of those remote devices.
//
// Each session analyzes the graph, places nodes across available
// devices, and ultimately drives the graph computation by initiating
// RunGraph on the workers.
#include "tensorflow/core/distributed_runtime/master.h"
#include <unordered_set>
#include <vector>
#include "tensorflow/core/distributed_runtime/process_util.h"
#include "tensorflow/core/distributed_runtime/remote_device.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/master.pb.h"
#include "tensorflow/core/protobuf/worker.pb.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
Master::Master(MasterEnv* env, double session_gc_seconds)
: env_(env),
last_1000_steps_(1000),
step_count_(0),
session_gc_seconds_(session_gc_seconds) {
// Right now, a master service must be co-located with a device.
// Otherwise, fetches do not work.
CHECK(!env->local_devices.empty());
if (session_gc_seconds_ > 0.0) {
SchedClosure([this]() { GC(); });
}
}
Master::~Master() {
{
mutex_lock l(mu_);
shutdown_ = true;
shutdown_cv_.notify_all();
}
// The GC closure is only scheduled when session_gc_seconds_ > 0;
// otherwise gc_stopped_ would never be notified.
if (session_gc_seconds_ > 0.0) {
gc_stopped_.WaitForNotification();
}
}
void Master::GC() {
Env* env = Env::Default();
while (true) {
mutex_lock l(mu_);
const int kTimeoutMilliseconds = 10 * 1000; // 10 seconds.
WaitForMilliseconds(&l, &shutdown_cv_, kTimeoutMilliseconds);
if (shutdown_) {
break;
}
std::vector<string> handles;
const int64 num_micros = static_cast<int64>(session_gc_seconds_ * 1000000);
for (const auto& entry : sessions_) {
auto lat = entry.second->last_access_time_usec();
if (env->NowMicros() - lat > num_micros) {
handles.push_back(entry.first);
auto* sess = entry.second;
SchedClosure([this, sess]() {
LOG(WARNING) << "GC session " << sess->handle() << " after "
<< session_gc_seconds_ << " seconds. "
<< "Note that if you are starting multiple replicas "
<< "on a staggered delay, session_gc_seconds may need "
<< "to be raised.";
sess->Close();
});
}
}
for (const auto& handle : handles) sessions_.erase(handle);
}
gc_stopped_.Notify();
}
class DeviceFinder {
public:
explicit DeviceFinder(
const protobuf::RepeatedPtrField<string>& device_filters, MasterEnv* env)
: env_(env) {
auto process_filter = [this](const string& filter) {
DeviceNameUtils::ParsedName parsed;
if (DeviceNameUtils::ParseFullName(filter, &parsed)) {
filters_.push_back(parsed);
} else {
LOG(FATAL) << "Skipping invalid filter: " << filter;
}
};
for (const string& filter : device_filters) {
process_filter(filter);
}
}
~DeviceFinder() {
for (Device* dev : found_) delete dev;
}
void Start() {
// Enumerates all known workers' targets. A target name is a
// prefix of a device name. E.g., /job:mnist/replica:0/task:10.
std::vector<string> workers;
env_->worker_cache->ListWorkers(&workers);
std::vector<string> targets;
if (filters_.empty()) {
swap(workers, targets);
} else {
for (const string& name : workers) {
if (MatchFilters(name)) {
targets.push_back(name);
}
}
}
{
mutex_lock l(mu_);
num_pending_ = targets.size();
if (num_pending_ == 0) {
pending_zero_.notify_all();
}
}
// Talk to all workers to get the list of available devices.
using std::placeholders::_1;
using std::placeholders::_2;
for (size_t i = 0; i < targets.size(); ++i) {
NewRemoteDevices(env_->env, env_->worker_cache, targets[i],
std::bind(&ME::WhenFound, this, _1, _2));
}
}
void Wait() {
mutex_lock l(mu_);
while (num_pending_ != 0) {
pending_zero_.wait(l);
}
}
// The caller takes ownership of the returned remote devices.
void GetRemoteDevices(const std::vector<Device*>& local,
std::vector<Device*>* remote) {
std::unordered_set<string> names(local.size());
for (Device* dev : local) names.insert(dev->name());
mutex_lock l(mu_);
for (Device* dev : found_) {
const string& name = dev->name();
if (names.insert(name).second && MatchFilters(name)) {
remote->push_back(dev);
} else {
delete dev;
}
}
found_.clear();
}
private:
typedef DeviceFinder ME;
const MasterEnv* env_;
std::vector<DeviceNameUtils::ParsedName> filters_;
mutex mu_;
int num_pending_ GUARDED_BY(mu_);
condition_variable pending_zero_;
std::vector<Device*> found_ GUARDED_BY(mu_);
void WhenFound(const Status& s, std::vector<Device*>* devices) {
mutex_lock l(mu_);
if (!s.ok()) {
LOG(ERROR) << "Master init: " << s;
} else {
found_.insert(found_.end(), devices->begin(), devices->end());
devices->clear();
}
--num_pending_;
if (num_pending_ == 0) {
pending_zero_.notify_all();
}
}
// Returns true iff the set of devices allowed by 'x' intersects
// with the set of devices allowed by 'y'.
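// E.g., "/job:worker" intersects "/job:worker/replica:0/task:3/gpu:0",
// while "/job:worker" and "/job:ps" do not intersect.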
bool Intersects(const DeviceNameUtils::ParsedName& x,
const DeviceNameUtils::ParsedName& y) {
return (!x.has_job || !y.has_job || x.job == y.job) &&
(!x.has_replica || !y.has_replica || x.replica == y.replica) &&
(!x.has_task || !y.has_task || x.task == y.task) &&
(!x.has_type || !y.has_type || x.type == y.type) &&
(!x.has_id || !y.has_id || x.id == y.id);
}
// Returns true iff 'name' matches one of the filters_.
bool MatchFilters(const string& name) {
if (filters_.empty()) return true;
DeviceNameUtils::ParsedName x;
if (DeviceNameUtils::ParseFullName(name, &x)) {
for (const auto& filter : filters_) {
if (Intersects(x, filter)) return true;
}
}
return false;
}
TF_DISALLOW_COPY_AND_ASSIGN(DeviceFinder);
};
void Master::CreateSession(const CreateSessionRequest* req,
CreateSessionResponse* resp, MyClosure done) {
SchedClosure([this, req, resp, done]() {
Status status = ValidateExternalGraphDefSyntax(req->graph_def());
if (status.ok()) {
// Ping all the workers and build the list of devices that the
// session will use.
DeviceFinder finder(req->config().device_filters(), env_);
finder.Start();
finder.Wait();
std::vector<Device*> remote_devices;
finder.GetRemoteDevices(env_->local_devices, &remote_devices);
SessionOptions options;
options.config = req->config();
MasterSessionInterface* session =
env_->master_session_factory(options, env_, &remote_devices);
GraphDef* gdef =
const_cast<CreateSessionRequest*>(req)->mutable_graph_def();
Status create_status = session->Create(gdef);
if (!create_status.ok()) {
done(create_status);
return;
}
resp->set_session_handle(session->handle());
// Insert into the session map.
{
mutex_lock l(mu_);
CHECK(sessions_.insert({session->handle(), session}).second);
}
}
done(status);
});
}
void Master::ExtendSession(const ExtendSessionRequest* req,
ExtendSessionResponse* resp, MyClosure done) {
mu_.lock();
MasterSessionInterface* session = nullptr;
session = gtl::FindPtrOrNull(sessions_, req->session_handle());
if (session == nullptr) {
mu_.unlock();
done(errors::Aborted("Session ", req->session_handle(), " is not found."));
return;
}
SchedClosure([session, req, resp, done]() {
Status status = ValidateExternalGraphDefSyntax(req->graph_def());
if (status.ok()) {
status = session->Extend(req, resp);
}
done(status);
});
mu_.unlock();
}
void Master::RunStep(CallOptions* opts, const RunStepRequest* req,
RunStepResponse* resp, MyClosure done) {
mu_.lock();
uint64 start_time = env_->env->NowMicros();
MasterSessionInterface* session =
gtl::FindPtrOrNull(sessions_, req->session_handle());
if (session == nullptr) {
mu_.unlock();
done(errors::Aborted("Session ", req->session_handle(), " is not found."));
return;
}
SchedClosure([this, start_time, session, opts, req, resp, done]() {
Status status = session->Run(opts, req, resp);
uint64 done_time = env_->env->NowMicros();
done(status);
mutex_lock l(mu_);
last_1000_steps_.AddValue((done_time - start_time) / 1e9);
++step_count_;
});
mu_.unlock();
}
void Master::CloseSession(const CloseSessionRequest* req,
CloseSessionResponse* resp, MyClosure done) {
MasterSessionInterface* session = nullptr;
{
mu_.lock();
auto iter = sessions_.find(req->session_handle());
if (iter == sessions_.end()) {
mu_.unlock();
done(errors::Aborted(
"Session ", req->session_handle(),
" is not found. Possibly, this master has restarted."));
return;
}
session = iter->second;
sessions_.erase(iter);
mu_.unlock();
}
// Session Close() blocks on thread shutdown. Therefore, we schedule
// the close in a separate closure so that this handler thread is
// not blocked.
SchedClosure([session, done]() {
Status s = session->Close();
done(s);
});
}
void Master::ListDevices(const ListDevicesRequest* req,
ListDevicesResponse* resp, MyClosure done) {
SchedClosure([this, req, resp, done]() {
DeviceFinder finder({}, env_);
finder.Start();
finder.Wait();
std::vector<Device*> remote_devices;
finder.GetRemoteDevices(env_->local_devices, &remote_devices);
for (Device* dev : env_->local_devices) {
*(resp->add_local_device()) = dev->attributes();
}
for (Device* dev : remote_devices) {
*(resp->add_remote_device()) = dev->attributes();
delete dev;
}
done(Status::OK());
});
}
void Master::CleanupWorkers(const ResetRequest& reset) {
std::vector<string> worker_names;
env_->worker_cache->ListWorkers(&worker_names);
if (!worker_names.empty()) {
const int num_workers = worker_names.size();
std::vector<Notification> n(num_workers);
CleanupAllRequest req;
(*req.mutable_container()) = reset.container();
std::vector<CleanupAllResponse> resp(num_workers);
int c = 0;
for (int i = 0; i < num_workers; ++i) {
auto worker = env_->worker_cache->CreateWorker(worker_names[i]);
if (worker) {
worker->CleanupAllAsync(&req, &resp[i], [&n, worker, c](Status s) {
TF_CHECK_OK(s);
delete worker;
n[c].Notify();
});
} else {
n[c].Notify();
}
++c;
}
for (int i = 0; i < n.size(); ++i) {
n[i].WaitForNotification();
}
}
}
void Master::Reset(const ResetRequest* req, ResetResponse* resp,
MyClosure done) {
// Vector to hold the session pointers present in the sessions_
// (string->Session*) map.
std::vector<MasterSessionInterface*> sessions;
{
mutex_lock l(mu_);
for (const auto& entry : sessions_) {
sessions.push_back(entry.second);
}
sessions_.clear();
}
CleanupWorkers(*req);
SchedClosure([sessions, done]() {
Status s;
for (MasterSessionInterface* session : sessions) {
s.Update(session->Close());
}
done(s);
});
}
} // end namespace tensorflow

View File

@ -0,0 +1,98 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_H_
#include <unordered_map>
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/distributed_runtime/call_options.h"
#include "tensorflow/core/distributed_runtime/master_env.h"
#include "tensorflow/core/distributed_runtime/master_session_interface.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/master.pb.h"
#include "tensorflow/core/util/util.h"
namespace tensorflow {
class Master {
public:
explicit Master(MasterEnv* env, double session_gc_seconds);
virtual ~Master();
// Convenient typedef for a closure passing a Status.
typedef std::function<void(const Status&)> MyClosure;
void CreateSession(const CreateSessionRequest* req,
CreateSessionResponse* resp, MyClosure done);
void ExtendSession(const ExtendSessionRequest* req,
ExtendSessionResponse* resp, MyClosure done);
void RunStep(CallOptions* opts, const RunStepRequest* req,
RunStepResponse* resp, MyClosure done);
void CloseSession(const CloseSessionRequest* req, CloseSessionResponse* resp,
MyClosure done);
void ListDevices(const ListDevicesRequest* req, ListDevicesResponse* resp,
MyClosure done);
void Reset(const ResetRequest* req, ResetResponse* resp, MyClosure done);
private:
typedef Master ME;
// Not owned.
MasterEnv* env_ = nullptr;
// Owned.
mutex mu_;
// shutdown_ is set to true by the dtor.
condition_variable shutdown_cv_;
bool shutdown_ GUARDED_BY(mu_) = false;
Notification gc_stopped_;
// Maps session handles to sessions.
std::unordered_map<string, MasterSessionInterface*> sessions_ GUARDED_BY(mu_);
// Moving average of step times.
MovingAverage last_1000_steps_ GUARDED_BY(mu_);
// Cumulative number of steps executed.
int64 step_count_ GUARDED_BY(mu_);
// If a session is not active for this many seconds, it will be
// closed automatically.
const double session_gc_seconds_;
// Call CleanupAll on all workers.
void CleanupWorkers(const ResetRequest& reset);
// Cleans up unused sessions.
void GC();
TF_DISALLOW_COPY_AND_ASSIGN(Master);
};
} // namespace tensorflow
#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_H_

View File

@ -0,0 +1,66 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_ENV_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_ENV_H_
#include <functional>
#include <vector>
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
class Device;
class Env;
class MasterSessionInterface;
class OpRegistryInterface;
class WorkerCacheInterface;
// The master environment class, which holds a bag of pointers to
// per-master state.
//
// MasterEnv does not own its member pointers.
struct MasterEnv {
Env* env = nullptr;
// Object from which WorkerInterface instances can be obtained.
WorkerCacheInterface* worker_cache = nullptr;
// The operation definitions to use. Must be filled before use.
const OpRegistryInterface* ops = nullptr;
// Local devices co-located with this master. Devices are not owned
// by the master service.
//
// REQUIRES: !local_devices.empty().
std::vector<Device*> local_devices;
// Factory for creating master sessions, given session options and a
// vector of devices.
//
// The caller of the function takes ownership of the returned
// `MasterSessionInterface`, which may not be null. Ownership of the
// `MasterEnv*` is retained by the caller. The callee takes
// ownership of the `std::vector<Device*>*` argument, but does not
// take ownership of the `Device*` objects in the vector.
std::function<MasterSessionInterface*(const SessionOptions&, MasterEnv*,
std::vector<Device*>*)>
master_session_factory;
};
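// For illustration, a minimal sketch of how a server binary might
// populate a MasterEnv before constructing a Master. "my_worker_cache"
// and "my_devices" are hypothetical objects the server already owns;
// internal::NewMasterSession() is the factory from master_session.cc:
//
//   MasterEnv master_env;
//   master_env.env = Env::Default();
//   master_env.worker_cache = my_worker_cache;
//   master_env.ops = OpRegistry::Global();
//   master_env.local_devices = my_devices;
//   master_env.master_session_factory =
//       [](const SessionOptions& options, MasterEnv* env,
//          std::vector<Device*>* remote_devs) {
//         return internal::NewMasterSession(options, env, remote_devs);
//       };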
} // end namespace tensorflow
#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_ENV_H_

View File

@ -0,0 +1,52 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_INTERFACE_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_INTERFACE_H_
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/master.pb.h"
namespace tensorflow {
// Pure virtual interface for communicating with the TensorFlow Master service.
//
// This interface is intended to support in-process master
// implementations that do not require an RPC roundtrip.
class MasterInterface {
public:
virtual ~MasterInterface() {}
virtual Status CreateSession(const CreateSessionRequest* request,
CreateSessionResponse* response) = 0;
virtual Status ExtendSession(const ExtendSessionRequest* request,
ExtendSessionResponse* response) = 0;
virtual Status RunStep(const RunStepRequest* request,
RunStepResponse* response) = 0;
virtual Status CloseSession(const CloseSessionRequest* request,
CloseSessionResponse* response) = 0;
virtual Status ListDevices(const ListDevicesRequest* request,
ListDevicesResponse* response) = 0;
virtual Status Reset(const ResetRequest* request,
ResetResponse* response) = 0;
};
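// For illustration, a client holding a MasterInterface* ("master" and
// "graph_def" below are assumed to be supplied by the caller) can
// drive a session synchronously through this interface:
//
//   CreateSessionRequest req;
//   *req.mutable_graph_def() = graph_def;
//   CreateSessionResponse resp;
//   TF_RETURN_IF_ERROR(master->CreateSession(&req, &resp));
//   const string& handle = resp.session_handle();
//   // ... use "handle" in subsequent ExtendSession/RunStep calls.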
} // namespace tensorflow
#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_INTERFACE_H_

View File

@ -0,0 +1,942 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/master_session.h"
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/distributed_runtime/master_env.h"
#include "tensorflow/core/distributed_runtime/master_session_interface.h"
#include "tensorflow/core/distributed_runtime/process_util.h"
#include "tensorflow/core/distributed_runtime/simple_graph_execution_state.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/graph_partition.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/master.pb.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
namespace {
// A little bit of per-step state.
struct PerStepState {
Microseconds start_micros = Microseconds(0);
Microseconds end_micros = Microseconds(0);
std::vector<StepStats> step_stats; // per partition
};
// A session encapsulates a graph computation (resource allocation,
// placement, execution, etc.).
class MasterSession : public MasterSessionInterface {
public:
// This session encapsulates the computation of a single graph.
//
// The session places nodes on devices in "remote_devs" and executes
// operations on these devices.
//
// The caller takes ownership of all remote devices.
MasterSession(const SessionOptions& options, const MasterEnv* env,
std::vector<Device*>* remote_devs);
// Initialize the Session for "def". Must be called before Extend(),
// Run(), or Close().
//
// The callee may clear "def".
Status Create(GraphDef* def) override;
// Returns the session handle.
const string& handle() const override { return handle_; }
// Returns the last access time (the number of micro-seconds since
// some fixed point in time) of this session.
uint64 last_access_time_usec() const override {
return last_access_time_usec_.load();
}
// Attempt to extend the graph according to the given "req".
// (See master.proto for details of valid extensions.)
//
// PRECONDITION: The current version of this session's graph
// is "req->current_graph_version".
//
// POSTCONDITION: The current version of this session's graph
// is "resp->new_graph_version".
//
// Extend() may block the caller thread for a long time.
Status Extend(const ExtendSessionRequest* req,
ExtendSessionResponse* resp) override;
// Run one step.
Status Run(CallOptions* opts, const RunStepRequest* req,
RunStepResponse* resp) override;
// Close this session and delete "*this". Returns OK if all known
// states are cleaned up successfully.
//
// Close() may block the caller thread for a long time.
Status Close() override;
private:
SessionOptions session_opts_;
// Not owned.
const MasterEnv* env_;
// The opaque session handle.
const string handle_;
// Owned.
std::vector<Device*> remote_devs_;
// The device set used by this session.
DeviceSet devices_;
// TODO(zhifengc): Support Extend().
//
// 'func_def_lib_' is a copy of the initial graph def's library.
// 'flib_def_' is an index structure over 'func_def_lib_', keyed by
// function names.
FunctionDefLibrary func_def_lib_;
FunctionLibraryDefinition* flib_def_ = nullptr;
std::atomic_ulong last_access_time_usec_;
mutex mu_;
std::unique_ptr<SimpleGraphExecutionState> execution_state_;
int64 graph_version_;
int32 steps_since_last_scheduling_ GUARDED_BY(mu_) = 0;
int32 scheduling_period_steps_ GUARDED_BY(mu_) = 10;
// We keep a map from a signature of a run request to the
// ReffedClientGraph that can execute it. We keep up to one old copy
// of each ReffedClientGraph around because if it gets deallocated
// before a new substitute has been created, Variables can go out of
// scope and lose their state.
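// (E.g., two RunStep requests with identical sorted feed, fetch, and
// target name lists hash to the same signature and thus share one
// ReffedClientGraph; see HashBuildGraphOptions().)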
class ReffedClientGraph;
typedef std::unordered_map<uint64, ReffedClientGraph*> RCGMap;
RCGMap runs_ GUARDED_BY(mu_);
RCGMap obsolete_ GUARDED_BY(mu_);
// Active RunStep calls.
condition_variable num_running_is_zero_;
int32 num_running_ GUARDED_BY(mu_) = 0;
std::unordered_map<uint64, int64> subgraph_execution_counts_ GUARDED_BY(mu_);
// We need to ensure that certain nodes added (e.g., send and recv
// nodes) are unique across all sub-graphs within this session.
int64 next_node_id_ GUARDED_BY(mu_) = 0;
// Private dtor. The client must call Close().
virtual ~MasterSession();
Status StartStep(const RunStepRequest& req, BuildGraphOptions* opts,
int64* count, ReffedClientGraph** graph);
void ClearRunsTable(std::vector<ReffedClientGraph*>* to_unref,
RCGMap* rcg_map) EXCLUSIVE_LOCKS_REQUIRED(mu_);
Status DoRunWithLocalExecution(CallOptions* opts, const RunStepRequest* req,
RunStepResponse* resp);
void UpdateLastAccessTime();
TF_DISALLOW_COPY_AND_ASSIGN(MasterSession);
};
// ReffedClientGraph wraps a ClientGraph in a reference-counted object.
// This way, the session can clear the cache mapping Run requests to
// compiled graphs while a compiled graph is still being used.
//
// TODO(zhifengc): Cleanup this class. It's becoming messy.
class MasterSession::ReffedClientGraph : public core::RefCounted {
public:
ReffedClientGraph(const string& handle, const BuildGraphOptions& bopts,
ClientGraph* cg, const GraphOptions& graph_opts)
: session_handle_(handle),
client_graph_(cg),
bopts_(bopts),
graph_opts_(graph_opts) {
VLOG(1) << "Created ReffedClientGraph for node with "
<< client_graph_->graph.num_node_ids();
const string key =
strings::StrCat("{", str_util::Join(bopts.feed_endpoints, ","), "},{",
str_util::Join(bopts.target_nodes, ","), "},{",
str_util::Join(bopts.fetch_endpoints, ","), "}");
// TODO(mrry): Publish information about the graph (such as
// timelines, the pruned graph, statistics, etc.).
}
~ReffedClientGraph() override {
delete client_graph_;
DeregisterPartitions();
}
const ClientGraph* client_graph() { return client_graph_; }
// Local execution methods.
// Partitions the graph into subgraphs and registers them on
// workers.
Status RegisterPartitions(const MasterEnv* env, const PartitionOptions& popts,
const FunctionDefLibrary& func_def_lib);
// Runs one step of all partitions.
Status RunPartitions(const MasterEnv* env, int64 step_id,
int64 execution_count,
SimpleGraphExecutionState* execution_state,
PerStepState* pss, CallOptions* opts,
const RunStepRequest& req, RunStepResponse* resp);
// Calls workers to cleanup states for the step "step_id". Waits
// till all cleanup rpcs complete.
Status CleanupPartitions(int64 step_id);
// TODO(mrry): Runtime statistics collection.
private:
const string session_handle_;
ClientGraph* const client_graph_ = nullptr;
std::unordered_set<const Node*> nodes_needing_input_mapping_;
BuildGraphOptions bopts_;
const GraphOptions graph_opts_;
// Graph partitioned into per-location subgraphs.
struct Part {
// Worker name.
string name;
// Graph definition.
GraphDef gdef;
// Maps feed names to rendezvous keys. Empty most of the time.
std::unordered_map<string, string> feed_key;
// Maps rendezvous keys to fetch names. Empty most of the time.
std::unordered_map<string, string> key_fetch;
// The interface to the worker. Owned.
WorkerInterface* worker = nullptr;
// After registration with the worker, graph_handle identifies
// this partition on the worker.
Part() : feed_key(3), key_fetch(3) {}
};
// partitions_ is immutable after the RegisterPartitions() call
// finishes. RunPartitions() can access partitions_ safely without
// acquiring locks.
std::vector<Part> partitions_;
mutable mutex mu_;
// Partition initialization and registration only needs to happen
// once. init_started_ && !init_done_ indicates that initialization
// is ongoing.
bool init_started_ GUARDED_BY(mu_) = false;
Notification init_done_;
// init_result_ remembers the initialization error if any.
Status init_result_ GUARDED_BY(mu_);
// Send/Recv nodes that are the result of client-added
// feeds and fetches must be tracked so that the tensors
// can be added to the local rendezvous.
static void TrackFeedsAndFetches(Part* part, const PartitionOptions& popts);
// The actual graph partitioning and registration implementation.
Status DoRegisterPartitions(const MasterEnv* env,
const PartitionOptions& popts,
const FunctionDefLibrary& func_def_lib);
// Deregisters the partitions on the workers. Called in the
// destructor and does not wait for the rpc completion.
void DeregisterPartitions();
TF_DISALLOW_COPY_AND_ASSIGN(ReffedClientGraph);
};
Status MasterSession::ReffedClientGraph::RegisterPartitions(
const MasterEnv* env, const PartitionOptions& popts,
const FunctionDefLibrary& func_def_lib) {
{ // Ensure we register at most once.
mu_.lock();
if (!init_started_) {
init_started_ = true;
mu_.unlock();
Status s = DoRegisterPartitions(env, popts, func_def_lib);
mu_.lock();
init_result_ = s;
init_done_.Notify();
} else {
mu_.unlock();
init_done_.WaitForNotification();
mu_.lock();
}
Status result = init_result_;
mu_.unlock();
return result;
}
}
static string SplitByWorker(const Node* node) {
string task;
string device;
CHECK(DeviceNameUtils::SplitDeviceName(node->assigned_device_name(), &task,
&device))
<< "node: " << node->name() << " dev: " << node->assigned_device_name();
return task;
}
void MasterSession::ReffedClientGraph::TrackFeedsAndFetches(
Part* part, const PartitionOptions& popts) {
for (int i = 0; i < part->gdef.node_size(); ++i) {
NodeDef* ndef = part->gdef.mutable_node(i);
const bool is_recv = ndef->op() == "_Recv";
const bool is_send = ndef->op() == "_Send";
if (is_recv || is_send) {
string name;
TF_CHECK_OK(GetNodeAttr(*ndef, "tensor_name", &name));
string send_device;
TF_CHECK_OK(GetNodeAttr(*ndef, "send_device", &send_device));
string recv_device;
TF_CHECK_OK(GetNodeAttr(*ndef, "recv_device", &recv_device));
uint64 send_device_incarnation;
TF_CHECK_OK(
GetNodeAttr(*ndef, "send_device_incarnation",
reinterpret_cast<int64*>(&send_device_incarnation)));
const string& key =
Rendezvous::CreateKey(send_device, send_device_incarnation,
recv_device, name, FrameAndIter(0, 0));
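// The key string encodes the sender device, its incarnation, the
// receiver device, the tensor name, and the frame/iteration, so
// producer and consumer compute the same key independently.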
// Only send/recv nodes that were added as feeds and fetches
// (client-terminated) should be tracked. Other send/recv nodes
// are for transferring data between partitions / memory spaces.
bool client_terminated;
TF_CHECK_OK(GetNodeAttr(*ndef, "client_terminated", &client_terminated));
if (client_terminated) {
if (is_recv) {
part->feed_key.insert({name, key});
} else {
part->key_fetch.insert({key, name});
}
}
}
}
}
Status MasterSession::ReffedClientGraph::DoRegisterPartitions(
const MasterEnv* env, const PartitionOptions& popts,
const FunctionDefLibrary& func_def_lib) {
// Partition the graph.
Status s;
std::unordered_map<string, GraphDef> graph_partitions;
s = Partition(popts, &client_graph_->graph, &graph_partitions);
if (!s.ok()) return s;
partitions_.reserve(graph_partitions.size());
for (auto& name_def : graph_partitions) {
partitions_.resize(partitions_.size() + 1);
Part* part = &partitions_.back();
part->name = name_def.first;
part->gdef.Swap(&name_def.second);
// For simplicity, we ship the library completely to every worker.
*(part->gdef.mutable_library()) = func_def_lib;
TrackFeedsAndFetches(part, popts);
part->worker = env->worker_cache->CreateWorker(part->name);
if (part->worker == nullptr) {
s = errors::NotFound("worker ", part->name);
break;
}
}
if (!s.ok()) {
for (Part& part : partitions_) {
delete part.worker;
}
return s;
}
struct Call {
RegisterGraphRequest req;
RegisterGraphResponse resp;
Status status;
Notification done;
};
const int num = partitions_.size();
gtl::InlinedVector<Call, 4> calls(num);
for (int i = 0; i < num; ++i) {
const Part& part = partitions_[i];
Call* c = &calls[i];
c->req.set_session_handle(session_handle_);
*c->req.mutable_graph_def() = part.gdef;
*c->req.mutable_graph_options() = graph_opts_;
VLOG(2) << "Register " << part.gdef.DebugString();
auto cb = [c](const Status& s) {
c->status = s;
c->done.Notify();
};
part.worker->RegisterGraphAsync(&c->req, &c->resp, cb);
}
for (int i = num - 1; i >= 0; --i) {
Call* c = &calls[i];
c->done.WaitForNotification();
s.Update(c->status);
partitions_[i].graph_handle = c->resp.graph_handle();
}
return s;
}
static bool CopyIfNeeded(TensorProto* in, TensorProto* out) {
if (in->tensor_content().empty()) {
// If the tensor is not encoded in tensor_content or contains 0
// elements, we can return it to the client directly.
out->Swap(in);
} else {
Tensor t(in->dtype());
if (!t.FromProto(cpu_allocator(), *in)) return false;
t.AsProtoField(out);
}
return true;
}
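// NOTE: Tensor::AsProtoField() stores the values in the type-specific
// repeated fields (e.g., float_val) rather than in tensor_content.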
// Helper class to manage "num" parallel RunGraph calls.
class RunManyGraphs {
public:
explicit RunManyGraphs(int num) : calls_(num), num_pending_(num) {}
~RunManyGraphs() {}
// Returns the index-th call.
struct Call {
CallOptions opts;
RunGraphRequest req;
RunGraphResponse resp;
};
Call* get(int index) { return &calls_[index]; }
// When the index-th call is done, updates the overall status.
void WhenDone(int index, const Status& s) {
TRACEPRINTF("Partition %d %s", index, s.ToString().c_str());
{
mutex_lock l(mu_);
if (!s.ok()) {
UpdateStatusLocked(s);
}
--num_pending_;
cv_pending_.notify_all();
}
}
void StartCancel() {
mutex_lock l(mu_);
UpdateStatusLocked(errors::Cancelled("RunManyGraphs"));
}
void Wait() {
mutex_lock l(mu_);
while (num_pending_ > 0) {
cv_pending_.wait(l);
}
}
Status status() const {
mutex_lock l(mu_);
return status_;
}
private:
gtl::InlinedVector<Call, 4> calls_;
// TODO(jeff,sanjay): Replace bookkeeping state here with a
// BlockingCounter abstraction that we define in
// tensorflow/core/lib/core.
mutable mutex mu_;
condition_variable cv_pending_;
int num_pending_;
Status status_ GUARDED_BY(mu_);
void UpdateStatusLocked(const Status& s) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (status_.ok()) {
status_ = s;
for (Call& call : calls_) {
call.opts.StartCancel();
}
}
}
TF_DISALLOW_COPY_AND_ASSIGN(RunManyGraphs);
};
Status MasterSession::ReffedClientGraph::RunPartitions(
const MasterEnv* env, int64 step_id, int64 execution_count,
SimpleGraphExecutionState* execution_state, PerStepState* pss,
CallOptions* call_opts, const RunStepRequest& req, RunStepResponse* resp) {
VLOG(2) << "RunPartitions step_id " << step_id << " execution_count "
<< execution_count;
// Builds an index for feeds provided by the client.
std::unordered_map<StringPiece, const TensorProto*, StringPiece::Hasher>
feeds(3);
for (const auto& feed : req.feed()) {
if (!feeds.insert({feed.name(), &feed.tensor()}).second) {
return errors::InvalidArgument("Duplicated feeds: ", feed.name());
}
}
// Prepares a number of calls to workers. One call per partition.
ExecutorOpts exec_opts;
const int num = partitions_.size();
RunManyGraphs calls(num);
for (int i = 0; i < num; ++i) {
const Part& part = partitions_[i];
RunManyGraphs::Call* c = calls.get(i);
c->req.set_graph_handle(part.graph_handle);
c->req.set_step_id(step_id);
*c->req.mutable_exec_opts() = exec_opts;
// If any feeds are provided, send the feed values together
// in the RunGraph request.
for (const auto& feed_key : part.feed_key) {
const string& feed = feed_key.first;
const string& key = feed_key.second;
const TensorProto* val = feeds[feed];
if (val == nullptr) {
return errors::InvalidArgument("No feed is provided for feed=", feed,
", key=", key);
}
auto* send = c->req.add_send();
send->set_key(key);
*(send->mutable_val()) = *val; // TODO(mrry): make it faster if needed.
}
for (const auto& key_fetch : part.key_fetch) {
const string& key = key_fetch.first;
c->req.add_recv_key(key);
}
}
// Issues RunGraph calls.
for (int i = 0; i < num; ++i) {
const Part& part = partitions_[i];
RunManyGraphs::Call* call = calls.get(i);
TRACEPRINTF("Partition %d %s", i, part.name.c_str());
part.worker->RunGraphAsync(
&call->opts, &call->req, &call->resp,
std::bind(&RunManyGraphs::WhenDone, &calls, i, std::placeholders::_1));
}
// Waits for the RunGraph calls.
call_opts->SetCancelCallback([&calls]() { calls.StartCancel(); });
calls.Wait();
call_opts->ClearCancelCallback();
// Collects fetches.
Status status = calls.status();
if (status.ok()) {
for (int i = 0; i < num; ++i) {
const Part& part = partitions_[i];
for (auto& recv : *(calls.get(i)->resp.mutable_recv())) {
auto* ret = resp->add_tensor();
auto iter = part.key_fetch.find(recv.key());
if (iter == part.key_fetch.end()) {
status.Update(errors::Internal("Unexpected fetch key: ", recv.key()));
break;
}
const string& fetch = iter->second;
ret->set_name(fetch);
if (!CopyIfNeeded(recv.mutable_val(), ret->mutable_tensor())) {
status.Update(
errors::Internal("Unexpected unparseable tensor: ", recv.key()));
break;
}
}
if (calls.get(i)->resp.has_step_stats()) {
pss->step_stats[i].Swap(calls.get(i)->resp.mutable_step_stats());
}
}
}
return status;
}
Status MasterSession::ReffedClientGraph::CleanupPartitions(int64 step_id) {
struct Call {
CleanupGraphRequest req;
CleanupGraphResponse resp;
Notification done;
Status status;
};
const int num = partitions_.size();
gtl::InlinedVector<Call, 4> calls(num);
for (int i = 0; i < num; ++i) {
const Part& part = partitions_[i];
Call* c = &calls[i];
c->req.set_step_id(step_id);
part.worker->CleanupGraphAsync(&c->req, &c->resp, [c](const Status& s) {
c->status = s;
c->done.Notify();
});
}
Status s;
for (int i = num - 1; i >= 0; --i) {
Call* c = &calls[i];
c->done.WaitForNotification();
s.Update(c->status);
}
return s;
}
// Makes async calls to workers to deregister subgraphs, without
// waiting for the calls to complete.
void MasterSession::ReffedClientGraph::DeregisterPartitions() {
struct Call {
DeregisterGraphRequest req;
DeregisterGraphResponse resp;
};
for (Part& part : partitions_) {
Call* c = new Call;
c->req.set_graph_handle(part.graph_handle);
WorkerInterface* w = part.worker;
auto cb = [c, w](const Status& s) {
if (!s.ok()) {
// This error is potentially benign, so we don't log at the
// error level.
LOG(INFO) << "DeregisterGraph error: " << s;
}
delete c;
delete w;
};
w->DeregisterGraphAsync(&c->req, &c->resp, cb);
}
}
void BuildBuildGraphOptions(const RunStepRequest& req,
BuildGraphOptions* opts) {
for (const auto& feed : req.feed()) {
opts->feed_endpoints.push_back(feed.name());
}
for (const auto& fetch : req.fetch()) {
// TODO(touts): handle ref:
opts->fetch_endpoints.push_back(fetch);
}
for (const auto& target : req.target()) {
opts->target_nodes.push_back(target);
}
std::sort(opts->feed_endpoints.begin(), opts->feed_endpoints.end());
std::sort(opts->target_nodes.begin(), opts->target_nodes.end());
std::sort(opts->fetch_endpoints.begin(), opts->fetch_endpoints.end());
}
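// NOTE: The sorts in BuildBuildGraphOptions() make the hash below
// insensitive to the order in which the client listed feeds, fetches,
// and targets.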
uint64 HashBuildGraphOptions(const BuildGraphOptions& opts) {
uint64 h = 0x2b992ddfa23249d6ull;
for (const string& name : opts.feed_endpoints) {
h = Hash64(name.c_str(), name.size(), h);
}
for (const string& name : opts.target_nodes) {
h = Hash64(name.c_str(), name.size(), h);
}
for (const string& name : opts.fetch_endpoints) {
h = Hash64(name.c_str(), name.size(), h);
}
return h;
}
string BuildGraphOptionsString(const BuildGraphOptions& opts) {
string buf;
for (const string& name : opts.feed_endpoints) {
strings::StrAppend(&buf, " FdE: ", name);
}
strings::StrAppend(&buf, "\n");
for (const string& name : opts.target_nodes) {
strings::StrAppend(&buf, " TN: ", name);
}
strings::StrAppend(&buf, "\n");
for (const string& name : opts.fetch_endpoints) {
strings::StrAppend(&buf, " FeE: ", name);
}
strings::StrAppend(&buf, "\n");
return buf;
}
MasterSession::MasterSession(const SessionOptions& opt, const MasterEnv* env,
std::vector<Device*>* remote_devs)
: session_opts_(opt),
env_(env),
handle_(strings::FpToString(random::New64())),
graph_version_(0),
runs_(5) {
UpdateLastAccessTime();
swap(remote_devs_, *remote_devs);
VLOG(1) << "Session " << handle_ << " #local " << env->local_devices.size()
<< " #remote " << remote_devs_.size();
for (Device* d : remote_devs_) {
devices_.AddDevice(d);
}
int num_local_devices = 0;
for (Device* d : env->local_devices) {
devices_.AddDevice(d);
if (num_local_devices == 0) {
// Uses the first local device as the client device.
devices_.set_client_device(d);
}
num_local_devices++;
}
}
MasterSession::~MasterSession() {
for (const auto& iter : runs_) iter.second->Unref();
for (const auto& iter : obsolete_) iter.second->Unref();
delete flib_def_;
for (Device* dev : remote_devs_) delete dev;
}
void MasterSession::UpdateLastAccessTime() {
last_access_time_usec_.store(Env::Default()->NowMicros());
}
Status MasterSession::Create(GraphDef* graph_def) {
// Keeps a copy of graph_def->library(); flib_def_ then serves as the
// OpRegistryInterface used by the SimpleGraphExecutionState to construct
// the pre-partitioned graphs during DoRunWithLocalExecution().
func_def_lib_.Swap(graph_def->mutable_library());
flib_def_ = new FunctionLibraryDefinition(func_def_lib_);
SimpleGraphExecutionStateOptions options;
options.device_set = &devices_;
options.session_options = &session_opts_;
execution_state_.reset(new SimpleGraphExecutionState(flib_def_, options));
TF_RETURN_IF_ERROR(execution_state_->Create(graph_def));
return Status::OK();
}
Status MasterSession::Extend(const ExtendSessionRequest* req,
ExtendSessionResponse* resp) {
UpdateLastAccessTime();
std::unique_ptr<SimpleGraphExecutionState> old_execution_state;
{
mutex_lock l(mu_);
// TODO(mrry): Redesign the locking with reader/writer locks to prevent
// starvation due to concurrent steps being issued. This is not
// immediately important because we expect Extend to be used in
// development/interactive exploration, and not during high-throughput
// training.
while (num_running_ != 0) {
num_running_is_zero_.wait(l);
}
if (graph_version_ != req->current_graph_version()) {
return errors::Aborted("Current version is ", graph_version_,
" but caller expected ",
req->current_graph_version(), ".");
}
CHECK(execution_state_);
SimpleGraphExecutionState* extended_execution_state = nullptr;
Status s =
execution_state_->Extend(req->graph_def(), &extended_execution_state);
if (s.ok()) {
CHECK(extended_execution_state);
old_execution_state =
std::move(execution_state_); // Will be released outside the lock
execution_state_.reset(extended_execution_state);
++graph_version_;
resp->set_new_graph_version(graph_version_);
}
return s;
}
}
Status MasterSession::StartStep(const RunStepRequest& req,
BuildGraphOptions* opts, int64* count,
ReffedClientGraph** rcg) {
BuildBuildGraphOptions(req, opts);
const uint64 hash = HashBuildGraphOptions(*opts);
ReffedClientGraph* to_unref = nullptr;
{
mutex_lock l(mu_);
// Keep track of how many times this subgraph has been executed in
// this session.
int64* c = &subgraph_execution_counts_[hash];
*count = (*c)++;
auto iter = runs_.find(hash);
if (iter == runs_.end()) {
// We have not seen this subgraph before. Build the subgraph and
// cache it.
VLOG(1) << "Unseen hash " << hash << " for "
<< BuildGraphOptionsString(*opts);
ClientGraph* client_graph = nullptr;
TF_RETURN_IF_ERROR(execution_state_->BuildGraph(*opts, &client_graph));
auto entry = new ReffedClientGraph(handle_, *opts, client_graph,
session_opts_.config.graph_options());
iter = runs_.insert({hash, entry}).first;
auto obs_iter = obsolete_.find(hash);
if (obs_iter != obsolete_.end()) {
to_unref = obs_iter->second;
obsolete_.erase(obs_iter);
}
VLOG(1) << "Preparing to execute new graph";
}
*rcg = iter->second;
(*rcg)->Ref();
}
if (to_unref) to_unref->Unref();
return Status::OK();
}
void MasterSession::ClearRunsTable(std::vector<ReffedClientGraph*>* to_unref,
RCGMap* rcg_map) {
VLOG(1) << "Discarding all reffed graphs";
for (auto p : *rcg_map) {
ReffedClientGraph* rcg = p.second;
if (to_unref) {
to_unref->push_back(rcg);
} else {
rcg->Unref();
}
}
rcg_map->clear();
}
Status MasterSession::Run(CallOptions* opts, const RunStepRequest* req,
RunStepResponse* resp) {
UpdateLastAccessTime();
{
mutex_lock l(mu_);
++num_running_;
}
Status status = DoRunWithLocalExecution(opts, req, resp);
{
mutex_lock l(mu_);
--num_running_;
if (num_running_ == 0) {
num_running_is_zero_.notify_all();
}
}
return status;
}
Status MasterSession::DoRunWithLocalExecution(CallOptions* opts,
const RunStepRequest* req,
RunStepResponse* resp) {
VLOG(2) << "DoRunWithLocalExecution "
<< "req: " << req->DebugString();
PerStepState pss;
pss.start_micros = Env::Default()->NowMicros();
// Prepare.
BuildGraphOptions bgopts;
ReffedClientGraph* rcg = nullptr;
int64 count = 0;
TF_RETURN_IF_ERROR(StartStep(*req, &bgopts, &count, &rcg));
// Unref "rcg" when out of scope.
core::ScopedUnref unref(rcg);
// Registers subgraphs if we haven't done so already.
PartitionOptions popts;
popts.node_to_loc = SplitByWorker;
popts.new_name = [this](const string& prefix) {
mutex_lock l(mu_);
return strings::StrCat(prefix, "_S", next_node_id_++);
};
popts.get_incarnation = [this](const string& name) {
Device* d = devices_.FindDeviceByName(name);
if (d == nullptr) {
return PartitionOptions::kIllegalIncarnation;
} else {
return d->attributes().incarnation();
}
};
popts.control_flow_added = false;
// TODO(mrry): Enable DT_BFLOAT16 casting.
// TODO(mrry): Enable recv scheduling.
TF_RETURN_IF_ERROR(rcg->RegisterPartitions(env_, popts, func_def_lib_));
// Keeps the highest 8 bits 0x01: we reserve some bits of the
// step_id for future use.
const uint64 step_id = (random::New64() & ((1uLL << 56) - 1)) | (1uLL << 56);
TRACEPRINTF("stepid %llu", step_id);
TF_RETURN_IF_ERROR(rcg->RunPartitions(
env_, step_id, count, execution_state_.get(), &pss, opts, *req, resp));
pss.end_micros = Env::Default()->NowMicros();
// Schedule post-processing and cleanup to be done async.
rcg->Ref();
// TODO(tucker): We're doing the stats processing prior to returning
// the response to the client. Ensure it's safe to do so, then schedule
// in a closure.
SchedClosure([this, rcg, step_id]() {
Status s = rcg->CleanupPartitions(step_id);
if (!s.ok()) {
LOG(ERROR) << "Cleanup partition error: " << s;
}
rcg->Unref();
});
return Status::OK();
}
Status MasterSession::Close() {
std::vector<ReffedClientGraph*> to_unref;
{
mutex_lock l(mu_);
while (num_running_ != 0) {
num_running_is_zero_.wait(l);
}
ClearRunsTable(&to_unref, &runs_);
ClearRunsTable(&to_unref, &obsolete_);
}
for (ReffedClientGraph* rcg : to_unref) rcg->Unref();
delete this;
return Status::OK();
}
} // end namespace
namespace internal {
MasterSessionInterface* NewMasterSession(const SessionOptions& options,
const MasterEnv* env,
std::vector<Device*>* remote_devs) {
return new MasterSession(options, env, remote_devs);
}
} // end namespace internal
} // end namespace tensorflow
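
A brief aside on the step_id layout chosen in DoRunWithLocalExecution() above: the low 56 bits come from random::New64() and the top byte is forced to 0x01. A standalone sketch (MakeStepId is a hypothetical name, not part of this change):

// Sketch of the step_id construction:
//   (random & ((1uLL << 56) - 1)) | (1uLL << 56)
// keeps 56 random bits and pins the top byte to 0x01, reserving the
// other top-byte values for future use.
#include <cassert>
#include <cstdint>

uint64_t MakeStepId(uint64_t random64) {  // hypothetical helper
  const uint64_t kLow56Mask = (1ULL << 56) - 1;
  return (random64 & kLow56Mask) | (1ULL << 56);
}

int main() {
  // The top byte is 0x01 regardless of the random input.
  assert((MakeStepId(0xFFFFFFFFFFFFFFFFULL) >> 56) == 0x01);
  assert((MakeStepId(0) >> 56) == 0x01);
  return 0;
}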

View File

@ -0,0 +1,38 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_H_
#include <vector>
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
class Device;
class MasterEnv;
class MasterSessionInterface;
namespace internal {
MasterSessionInterface* NewMasterSession(const SessionOptions& options,
const MasterEnv* env,
std::vector<Device*>* remote_devs);
} // namespace internal
} // end namespace tensorflow
#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_H_

View File

@ -0,0 +1,76 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_INTERFACE_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_INTERFACE_H_
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
class ThreadPool;
namespace tensorflow {
class CallOptions;
class GraphDef;
class RunStepRequest;
class RunStepResponse;
class ExtendSessionRequest;
class ExtendSessionResponse;
// A "master session" encapsulates a distributed graph computation
// (resource allocation, placement, execution, etc.).
class MasterSessionInterface {
public:
// Initializes the Session with "def". Must be called before Extend(),
// Run(), or Close().
//
// The callee may clear "def".
virtual Status Create(GraphDef* def) = 0;
// Returns the session handle.
virtual const string& handle() const = 0;
// Returns the last access time (in microseconds since some fixed
// point in time) of this session.
virtual uint64 last_access_time_usec() const = 0;
// Attempt to extend the graph according to the given "req".
// (See master.proto for details of valid extensions.)
//
// PRECONDITION: The current version of this session's graph
// is "req->current_version".
//
// POSTCONDITION: The current version of this session's graph
// is "req->new_version".
//
// Extend() may block the caller thread for a long time.
virtual Status Extend(const ExtendSessionRequest* req,
ExtendSessionResponse* resp) = 0;
// Run one step.
virtual Status Run(CallOptions* opts, const RunStepRequest* req,
RunStepResponse* resp) = 0;
// Close this session and delete "*this". Returns OK if all known
// state is cleaned up successfully.
//
// Close() may block the caller thread for a long time.
virtual Status Close() = 0;
};
} // end namespace tensorflow
#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_INTERFACE_H_
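
For orientation, a sketch of how a caller might drive the interface above end to end, under the contract documented in the comments (DriveSession is a hypothetical helper, not part of this change):

// Create a session from a GraphDef, run one step, then close.
Status DriveSession(MasterSessionInterface* session, GraphDef* def,
                    CallOptions* opts, const RunStepRequest* req,
                    RunStepResponse* resp) {
  Status s = session->Create(def);  // must be the first call; may clear *def
  if (!s.ok()) return s;
  s = session->Run(opts, req, resp);  // one step
  if (!s.ok()) return s;
  return session->Close();  // deletes the session; do not reuse it
}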

View File

@ -0,0 +1,423 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/master.h"
#include <map>
#include <memory>
#include "external/grpc/include/grpc++/grpc++.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_testlib.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/graph/testlib.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/master.pb.h"
#include "tensorflow/core/protobuf/master_service.grpc.pb.h"
namespace tensorflow {
class MasterTest : public ::testing::Test {
protected:
MasterTest() {
std::vector<string> targets;
SessionOptions options;
(*options.config.mutable_device_count())["CPU"] = 1;
(*options.config.mutable_device_count())["GPU"] = 0;
TF_CHECK_OK(test::TestCluster::MakeTestCluster(options, 2, &cluster_));
master_ = grpc::MasterService::NewStub(
NewHostPortGrpcChannel(cluster_->targets()[0]));
}
std::unique_ptr<test::TestCluster> cluster_;
std::unique_ptr<grpc::MasterService::Stub> master_;
// Helpers for MasterService.{CreateSession,RunStep,CloseSession}
// rpc calls.
Status CreateSession(const GraphDef& def, string* handle,
int64* initial_version) {
::grpc::ClientContext ctx;
CreateSessionRequest req;
*(req.mutable_graph_def()) = def;
// Invokes placement frequently.
req.mutable_config()->set_placement_period(1);
CreateSessionResponse resp;
const Status s = FromGrpcStatus(master_->CreateSession(&ctx, req, &resp));
if (s.ok()) {
*handle = resp.session_handle();
*initial_version = resp.graph_version();
}
return s;
}
Status ExtendSession(const string& handle, const GraphDef& def,
int64 current_version, int64* new_version) {
::grpc::ClientContext ctx;
ExtendSessionRequest req;
req.set_session_handle(handle);
*(req.mutable_graph_def()) = def;
req.set_current_graph_version(current_version);
ExtendSessionResponse resp;
const Status s = FromGrpcStatus(master_->ExtendSession(&ctx, req, &resp));
if (s.ok()) {
*new_version = resp.new_graph_version();
}
return s;
}
Status RunStep(const string& handle,
const std::vector<std::pair<string, const Tensor*> >& feed,
const std::map<string, Tensor*>& fetch) {
::grpc::ClientContext ctx;
RunStepRequest req;
req.set_session_handle(handle);
for (const auto& p : feed) {
const string& feed_name = p.first;
const Tensor* feed_tensor = p.second;
auto f = req.add_feed();
f->set_name(feed_name);
feed_tensor->AsProtoTensorContent(f->mutable_tensor());
}
for (const auto& p : fetch) {
const string& fetch_name = p.first;
req.add_fetch(fetch_name);
}
RunStepResponse resp;
const Status s = FromGrpcStatus(master_->RunStep(&ctx, req, &resp));
if (s.ok()) {
for (const auto& fetch_resp : resp.tensor()) {
auto it = fetch.find(fetch_resp.name());
CHECK(it != fetch.end());
CHECK(it->second->FromProto(fetch_resp.tensor()));
}
}
return s;
}
Status CloseSession(const string& handle) {
::grpc::ClientContext ctx;
CloseSessionRequest req;
req.set_session_handle(handle);
CloseSessionResponse resp;
return FromGrpcStatus(master_->CloseSession(&ctx, req, &resp));
}
Status Reset() {
::grpc::ClientContext ctx;
ResetRequest req;
ResetResponse resp;
return FromGrpcStatus(master_->Reset(&ctx, req, &resp));
}
};
TEST_F(MasterTest, CreateClose) {
GraphDef def; // Empty.
string handle;
int64 initial_version;
TF_ASSERT_OK(CreateSession(def, &handle, &initial_version));
EXPECT_TRUE(errors::IsAborted(CloseSession("randombits")));
EXPECT_TRUE(CloseSession(handle).ok());
}
TEST_F(MasterTest, ListDevices) {
::grpc::ClientContext ctx;
ListDevicesRequest req;
ListDevicesResponse resp;
const Status s = FromGrpcStatus(master_->ListDevices(&ctx, req, &resp));
TF_EXPECT_OK(s);
EXPECT_EQ(1, resp.local_device_size());
EXPECT_EQ("CPU", resp.local_device(0).device_type());
}
TEST_F(MasterTest, Reset) {
GraphDef def; // Empty.
string s1, s2;
int64 initial_version1, initial_version2;
TF_ASSERT_OK(CreateSession(def, &s1, &initial_version1));
TF_ASSERT_OK(CreateSession(def, &s2, &initial_version2));
EXPECT_TRUE(Reset().ok());
EXPECT_TRUE(errors::IsAborted(CloseSession(s1)));
EXPECT_TRUE(errors::IsAborted(CloseSession(s2)));
}
TEST_F(MasterTest, Extend) {
GraphDef def_0; // Empty.
string handle;
int64 initial_version;
TF_ASSERT_OK(CreateSession(def_0, &handle, &initial_version));
Tensor A_expected(DT_FLOAT, TensorShape({2, 2}));
test::FillValues<float>(&A_expected, {3.0, 2.0, -1.0, 0.0});
Tensor x_expected(DT_FLOAT, TensorShape({2, 1}));
test::FillValues<float>(&x_expected, {2.0, 2.0});
Graph graph_1(OpRegistry::Global());
test::graph::Constant(&graph_1, A_expected, "A");
GraphDef def_1;
test::graph::ToGraphDef(&graph_1, &def_1);
int64 version_1;
TF_ASSERT_OK(ExtendSession(handle, def_1, initial_version, &version_1));
EXPECT_GT(version_1, initial_version);
Tensor A(DT_FLOAT, TensorShape({2, 2}));
TF_ASSERT_OK(RunStep(handle, {}, {{"A:0", &A}}));
test::ExpectTensorEqual<float>(A, A_expected);
Graph graph_2(OpRegistry::Global());
test::graph::Constant(&graph_2, x_expected, "x");
GraphDef def_2;
test::graph::ToGraphDef(&graph_2, &def_2);
int64 version_2;
EXPECT_TRUE(errors::IsAborted(
ExtendSession("randombits", def_2, version_1, &version_2)));
TF_ASSERT_OK(ExtendSession(handle, def_2, version_1, &version_2));
EXPECT_GT(version_2, version_1);
Tensor x(DT_FLOAT, TensorShape({2, 1}));
TF_ASSERT_OK(RunStep(handle, {}, {{"A:0", &A}, {"x:0", &x}}));
test::ExpectTensorEqual<float>(A, A_expected);
test::ExpectTensorEqual<float>(x, x_expected);
TF_ASSERT_OK(CloseSession(handle));
}
TEST_F(MasterTest, ExtendUpdateStatefulFails) {
GraphDef def_0; // Empty.
string handle;
int64 initial_version;
TF_ASSERT_OK(CreateSession(def_0, &handle, &initial_version));
Graph graph_1(OpRegistry::Global());
test::graph::Var(&graph_1, DT_FLOAT, TensorShape({512}));
GraphDef def_1;
test::graph::ToGraphDef(&graph_1, &def_1);
int64 version_1, version_2;
TF_ASSERT_OK(ExtendSession(handle, def_1, initial_version, &version_1));
EXPECT_GT(version_1, initial_version);
EXPECT_TRUE(errors::IsInvalidArgument(
ExtendSession(handle, def_1, version_1, &version_2)));
TF_ASSERT_OK(CloseSession(handle));
}
TEST_F(MasterTest, ExtendTwiceFails) {
GraphDef def_0; // Empty.
string handle;
int64 initial_version;
TF_ASSERT_OK(CreateSession(def_0, &handle, &initial_version));
Graph graph_1(OpRegistry::Global());
test::graph::Var(&graph_1, DT_FLOAT, TensorShape({512}));
GraphDef def_1;
test::graph::ToGraphDef(&graph_1, &def_1);
int64 version_1;
TF_ASSERT_OK(ExtendSession(handle, def_1, initial_version, &version_1));
EXPECT_GT(version_1, initial_version);
EXPECT_TRUE(errors::IsAborted(
ExtendSession(handle, def_1, initial_version, &version_1)));
TF_ASSERT_OK(CloseSession(handle));
}
TEST_F(MasterTest, ConcurrentExtendOnlyOneSucceeds) {
GraphDef def_0; // Empty.
string handle;
int64 initial_version;
TF_ASSERT_OK(CreateSession(def_0, &handle, &initial_version));
Graph graph_1(OpRegistry::Global());
test::graph::Var(&graph_1, DT_FLOAT, TensorShape({512}));
GraphDef def_1;
test::graph::ToGraphDef(&graph_1, &def_1);
Notification n;
mutex mu;
int succeeded = 0;
int failed = 0;
auto extend_fn = [this, handle, def_1, initial_version, &n, &mu, &succeeded,
&failed]() {
n.WaitForNotification();
int64 new_version;
Status s = ExtendSession(handle, def_1, initial_version, &new_version);
EXPECT_TRUE(s.ok() || errors::IsAborted(s));
{
mutex_lock l(mu);
if (s.ok()) {
++succeeded;
} else {
++failed;
}
}
};
// Run 100 concurrent Extend calls and expect only one to succeed.
{
thread::ThreadPool thread_pool(Env::Default(), "extend_pool", 100);
for (int i = 0; i < 100; ++i) {
thread_pool.Schedule(extend_fn);
}
n.Notify();
}
EXPECT_EQ(failed, 99);
EXPECT_EQ(succeeded, 1);
TF_ASSERT_OK(CloseSession(handle));
}
TEST_F(MasterTest, ConcurrentExtendAndRun) {
Graph graph_0(OpRegistry::Global());
Tensor a_tensor(DT_FLOAT, TensorShape({2, 2}));
test::FillValues<float>(&a_tensor, {3, 2, -1, 0});
test::graph::Constant(&graph_0, a_tensor, "A");
GraphDef def_0;
test::graph::ToGraphDef(&graph_0, &def_0);
string handle;
int64 initial_version;
TF_ASSERT_OK(CreateSession(def_0, &handle, &initial_version));
Graph graph_1(OpRegistry::Global());
Tensor b_tensor(DT_FLOAT, TensorShape({2, 2}));
test::FillValues<float>(&b_tensor, {1, 0, 0, 1});
test::graph::Constant(&graph_1, b_tensor, "B");
GraphDef def_1;
test::graph::ToGraphDef(&graph_1, &def_1);
Notification extend_done;
Notification extend_can_start;
auto get_a_fn = [this, handle, &extend_done]() {
Tensor A(DT_FLOAT, TensorShape({2, 2}));
while (!extend_done.HasBeenNotified()) {
TF_ASSERT_OK(RunStep(handle, {}, {{"A:0", &A}}));
}
// Run at least once after the Extend has completed.
TF_ASSERT_OK(RunStep(handle, {}, {{"A:0", &A}}));
};
auto get_a_and_b_fn = [this, handle, &extend_done, &extend_can_start]() {
Tensor A(DT_FLOAT, TensorShape({2, 2}));
Tensor B(DT_FLOAT, TensorShape({2, 2}));
// Run at least once before the Extend has completed.
EXPECT_TRUE(
errors::IsNotFound(RunStep(handle, {}, {{"A:0", &A}, {"B:0", &B}})));
extend_can_start.Notify();
// Concurrent with the Extend, we will either fail (as above), or
// succeed (as below).
while (!extend_done.HasBeenNotified()) {
Status s = RunStep(handle, {}, {{"A:0", &A}, {"B:0", &B}});
EXPECT_TRUE(errors::IsNotFound(s) || s.ok());
}
// Run at least once after the Extend has completed.
TF_ASSERT_OK(RunStep(handle, {}, {{"A:0", &A}, {"B:0", &B}}));
};
auto extend_fn = [this, handle, def_1, initial_version, &extend_done,
&extend_can_start]() {
extend_can_start.WaitForNotification();
int64 version_1;
TF_ASSERT_OK(ExtendSession(handle, def_1, initial_version, &version_1));
extend_done.Notify();
};
{
thread::ThreadPool thread_pool(Env::Default(), "extend_pool", 3);
thread_pool.Schedule(get_a_fn);
thread_pool.Schedule(get_a_and_b_fn);
thread_pool.Schedule(extend_fn);
}
TF_ASSERT_OK(CloseSession(handle));
}
TEST_F(MasterTest, EigenProblem) {
// A = [3 2; -1 0]; x = rand(2, 1);
// for i=1:100; x = A * x; end
// We'll try to compute the largest eigenvalue for A.
Graph graph(OpRegistry::Global());
Tensor a_tensor(DT_FLOAT, TensorShape({2, 2}));
// Store rows [3, 2] and [-1, 0] in row major format.
test::FillValues<float>(&a_tensor, {3, 2, -1, 0});
Node* a_node = test::graph::Constant(&graph, a_tensor);
// x is from the feed.
Tensor x_tensor(DT_FLOAT, TensorShape({2, 1}));
test::FillValues<float>(&x_tensor, {0, 0});
Node* x_node = test::graph::Constant(&graph, x_tensor);
// y = A * x
Node* y_node = test::graph::Matmul(&graph, a_node, x_node, false, false);
GraphDef def;
test::graph::ToGraphDef(&graph, &def);
string handle;
int64 initial_version;
TF_CHECK_OK(CreateSession(def, &handle, &initial_version));
// Temps supporting the computation of the convergence condition.
const Eigen::array<Eigen::DenseIndex, 1> sum_along_dim(0);
const Eigen::array<Eigen::DenseIndex, 2> matrix_transpose({1, 0});
Tensor x(DT_FLOAT, TensorShape({2, 1}));
Tensor y(DT_FLOAT, TensorShape({2, 1}));
Eigen::Tensor<float, 1, Eigen::RowMajor> y_square_sum;
Eigen::Tensor<float, 2, Eigen::RowMajor> y_normalized(2, 1);
y_normalized.setRandom();
Eigen::Tensor<float, 1, Eigen::RowMajor> error_square_sum;
float lambda;
// The computation loop.
bool converged = false;
while (!converged) {
// Run one step of the graph.
auto x_matrix = x.matrix<float>();
x_matrix = y_normalized;
TF_EXPECT_OK(
RunStep(handle, {{x_node->name(), &x}}, {{y_node->name() + ":0", &y}}));
auto y_matrix = y.matrix<float>();
// Client code computes the convergence condition.
{
lambda = y_matrix(0, 0) / x_matrix(0, 0);
y_square_sum = y.matrix<float>().square().sum(sum_along_dim);
const float norm = static_cast<float>(sqrt(y_square_sum(0)));
y_normalized = y_matrix * (1 / norm);
error_square_sum = (x_matrix - y_normalized).square().sum(sum_along_dim);
VLOG(1) << "x = [" << x_matrix.shuffle(matrix_transpose) << "] y = ["
<< y_matrix.shuffle(matrix_transpose) << "] lambda = " << lambda;
converged = sqrt(error_square_sum(0)) < 1e-10;
}
}
EXPECT_NEAR(lambda, 2.0, 0.01);
TF_EXPECT_OK(CloseSession(handle));
}
} // namespace tensorflow
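
As a sanity check on the expected value in EigenProblem: the matrix A = [3 2; -1 0] has trace 3 and determinant 2, so its characteristic polynomial is lambda^2 - 3*lambda + 2 = (lambda - 1)(lambda - 2). The eigenvalues are 1 and 2, and the power iteration above converges to the dominant eigenvalue lambda = 2, which is why the test asserts EXPECT_NEAR(lambda, 2.0, 0.01).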

View File

@ -0,0 +1,69 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/process_util.h"
#include <string.h>
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/host_info.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace {
static thread::ThreadPool* InitComputePool(const SessionOptions& options) {
int32 inter_op_parallelism_threads =
options.config.inter_op_parallelism_threads();
if (inter_op_parallelism_threads == 0) {
// Default to using the number of cores available in the process.
inter_op_parallelism_threads = port::NumSchedulableCPUs();
}
return new thread::ThreadPool(Env::Default(), "Compute",
inter_op_parallelism_threads);
}
} // namespace
thread::ThreadPool* ComputePool(const SessionOptions& options) {
static thread::ThreadPool* compute_pool = InitComputePool(options);
return compute_pool;
}
void SchedClosure(std::function<void()> closure) {
if (port::Tracing::IsActive()) {
const uint64 id = port::Tracing::UniqueId();
port::Tracing::RecordEvent(port::Tracing::EventCategory::kScheduleClosure,
id);
std::function<void()> wrapper = [closure, id]() {
port::Tracing::ScopedActivity region(
port::Tracing::EventCategory::kRunClosure, id);
closure();
};
Env::Default()->SchedClosure(wrapper);
} else {
Env::Default()->SchedClosure(closure);
}
}
void SchedNonBlockingClosureAfter(int micros, std::function<void()> closure) {
Env::Default()->SchedClosureAfter(micros, closure);
}
} // namespace tensorflow

View File

@ -0,0 +1,39 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_PROCESS_UTIL_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_PROCESS_UTIL_H_
#include <functional>
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
// Returns a process-wide ThreadPool for scheduling compute operations
// using 'options'. The caller does not take ownership of the threadpool.
thread::ThreadPool* ComputePool(const SessionOptions& options);
// Schedule "closure" in the default thread queue.
void SchedClosure(std::function<void()> closure);
// Schedule "closure" after the given number of microseconds in the
// fixed-size ThreadPool used for non-blocking compute tasks.
void SchedNonBlockingClosureAfter(int micros, std::function<void()> closure);
} // namespace tensorflow
#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_PROCESS_UTIL_H_
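
A small usage sketch for the two entry points above (Example() and the closures are hypothetical; the pool is shared process-wide, so the caller must not delete it):

#include "tensorflow/core/distributed_runtime/process_util.h"
#include "tensorflow/core/public/session_options.h"

void Example() {
  tensorflow::SessionOptions options;
  // Process-wide compute pool; ownership stays with the runtime.
  tensorflow::thread::ThreadPool* pool = tensorflow::ComputePool(options);
  pool->Schedule([]() { /* compute-bound work */ });
  // Fire-and-forget closure on the default thread queue; traced when
  // tracing is active, as implemented in process_util.cc above.
  tensorflow::SchedClosure([]() { /* background work */ });
}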

View File

@ -0,0 +1,91 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/remote_device.h"
#include <vector>
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/distributed_runtime/process_util.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/protobuf/worker.pb.h"
namespace tensorflow {
using std::placeholders::_1;
// TODO(zhifengc): We need to consolidate (full/partial) device name
// parsing into one place.
//
// Parses and returns the local device part (e.g., cpu:0, gpu:4).
string GetLocalDeviceName(StringPiece fullname) {
auto pos = fullname.rfind('/');
CHECK_NE(pos, StringPiece::npos);
fullname.remove_prefix(pos + 1);
return fullname.ToString();
}
class RemoteDevice : public Device {
public:
RemoteDevice(Env* env, const DeviceAttributes& da, WorkerInterface* wi)
: Device(env, da, nullptr),
local_dev_name_(GetLocalDeviceName(da.name())),
wi_(wi) {}
~RemoteDevice() override { delete wi_; }
Status Sync() override { return Status::OK(); }
Allocator* GetAllocator(AllocatorAttributes attr) override { return nullptr; }
private:
const string local_dev_name_;
WorkerInterface* wi_;
TF_DISALLOW_COPY_AND_ASSIGN(RemoteDevice);
};
void NewRemoteDevices(Env* env, WorkerCacheInterface* worker_cache,
const string& worker_name, NewRemoteDevicesDone done) {
WorkerInterface* wi = worker_cache->CreateWorker(worker_name);
if (wi == nullptr) {
std::vector<Device*> empty;
done(errors::NotFound("Device ", worker_name, " is not found."), &empty);
return;
}
struct Call {
GetStatusRequest req;
GetStatusResponse resp;
};
Call* call = new Call;
auto cb = [env, worker_cache, worker_name, done, wi, call](const Status& s) {
std::vector<Device*> remote_devices;
if (s.ok()) {
remote_devices.reserve(call->resp.device_attributes_size());
for (const DeviceAttributes& da : call->resp.device_attributes()) {
auto d =
new RemoteDevice(env, da, worker_cache->CreateWorker(worker_name));
remote_devices.push_back(d);
}
}
done(s, &remote_devices);
delete wi;
delete call;
};
wi->GetStatusAsync(&call->req, &call->resp, cb);
}
} // namespace tensorflow

View File

@ -0,0 +1,48 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_REMOTE_DEVICE_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_REMOTE_DEVICE_H_
#include <functional>
#include <string>
#include <vector>
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
class Device;
class Env;
class WorkerCacheInterface;
// NewRemoteDevices discovers available devices on the
// 'remote_worker'. The implementation uses 'worker_cache' to
// discover how to communicate with the 'remote_worker' (via gRPC, for
// example).
//
// NewRemoteDevices does not block.
//
// On success, the 'done' callback is given the OK status and a vector
// of Device*. The caller should take ownership of these devices.
//
// Otherwise, the 'done' callback is given an error status and the
// vector is empty.
typedef std::function<void(const Status&, std::vector<Device*>*)>
NewRemoteDevicesDone;
void NewRemoteDevices(Env* env, WorkerCacheInterface* worker_cache,
const string& remote_worker, NewRemoteDevicesDone done);
} // namespace tensorflow
#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_REMOTE_DEVICE_H_
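
Because NewRemoteDevices() does not block, callers that want a synchronous variant typically pair it with a Notification, as the test that follows does. A sketch (GetRemoteDevicesSync is a hypothetical wrapper, not part of this change):

#include "tensorflow/core/distributed_runtime/remote_device.h"
#include "tensorflow/core/lib/core/notification.h"

namespace tensorflow {

Status GetRemoteDevicesSync(Env* env, WorkerCacheInterface* worker_cache,
                            const string& worker_name,
                            std::vector<Device*>* devices) {
  Status status;
  Notification n;
  NewRemoteDevices(env, worker_cache, worker_name,
                   [&](const Status& s, std::vector<Device*>* found) {
                     status = s;
                     if (s.ok()) *devices = *found;  // caller takes ownership
                     n.Notify();
                   });
  n.WaitForNotification();
  return status;
}

}  // namespace tensorflow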

View File

@ -0,0 +1,89 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/remote_device.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_testlib.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
const char* const kSession = "remote_session";
class RemoteDeviceTest : public ::testing::Test {
protected:
string remote_name_;
std::unique_ptr<WorkerCacheInterface> worker_cache_;
std::unique_ptr<WorkerInterface> wi_;
std::vector<Device*> devices_;
std::unique_ptr<test::TestCluster> cluster_;
RemoteDeviceTest() {
SessionOptions options;
(*options.config.mutable_device_count())["CPU"] = 2;
TF_CHECK_OK(test::TestCluster::MakeTestCluster(options, 1, &cluster_));
const string& hostport = cluster_->targets()[0];
string host;
int port;
CHECK(RE2::FullMatch(hostport, "(.+):(\\d+)", &host, &port));
GrpcChannelSpec spec;
spec.AddHostPortsJob("localhost", {hostport}, 1);
worker_cache_.reset(NewGrpcWorkerCache(NewGrpcChannelCache(spec)));
remote_name_ = strings::StrCat("/job:", host, "/replica:0/task:0");
wi_.reset(worker_cache_->CreateWorker(remote_name_));
}
void SetUp() override {
Notification n;
NewRemoteDevices(Env::Default(), worker_cache_.get(), remote_name_,
[&n, this](const Status& s, std::vector<Device*>* found) {
TF_CHECK_OK(s);
devices_ = *found;
n.Notify();
});
n.WaitForNotification();
EXPECT_EQ(devices_.size(), 2);
std::sort(devices_.begin(), devices_.end(), [](Device* a, Device* b) {
return a->name().compare(b->name()) < 0;
});
}
void TearDown() override {
for (auto d : devices_) delete d;
}
};
TEST_F(RemoteDeviceTest, GetStatus) {
// We know what the testlib's fake server does.
EXPECT_EQ(devices_[0]->name(), strings::StrCat(remote_name_, "/cpu:0"));
EXPECT_EQ(devices_[0]->attributes().device_type(),
DeviceType(DEVICE_CPU).type());
EXPECT_EQ(devices_[0]->attributes().memory_limit(), 256 << 20);
EXPECT_EQ(devices_[1]->name(), strings::StrCat(remote_name_, "/cpu:1"));
EXPECT_EQ(devices_[1]->attributes().memory_limit(), 256 << 20);
}
} // namespace tensorflow

View File

@ -0,0 +1,79 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RENDEZVOUS_MGR_INTERFACE_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RENDEZVOUS_MGR_INTERFACE_H_
#include <string>
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
// RendezvousMgr keeps track of a set of local rendezvous instances.
// All tensors sent by this worker are buffered in a RendezvousMgr
// until the tensor is received. Each globally unique "step_id"
// corresponds to one local rendezvous instance managed by a
// RendezvousMgr.
//
// E.g.,
// Rendezvous* rendez = worker_env->rendezvous_mgr->Find(0x8935);
// fork execution of a graph executor using "rendez" on thread 1;
// fork execution of another graph executor using "rendez" on thread 2;
// ...
// join threads 1 and 2;
//
// In the example above, the executions in threads 1 and 2 communicate
// with each other by send/recv operations through "rendez".
//
// Tensors sent and recved through rendezvous managed by this
// RendezvousMgr must have keys generated by Rendezvous::CreateKey.
class RendezvousMgrInterface {
public:
RendezvousMgrInterface() {}
virtual ~RendezvousMgrInterface() {}
// Returns Rendezvous supporting send and recv among workers in the
// "step_id". The caller takes ownership of one reference on the
// returned Rendezvous instance.
virtual Rendezvous* Find(int64 step_id) = 0;
// Finds the local rendezvous instance for the "step_id". Runs
// "done" when the tensor for "key" is produced or an error occurs.
//
// This method is used by the rpc handler of RecvTensor.
virtual void RecvLocalAsync(int64 step_id, const string& key,
Rendezvous::DoneCallback done) = 0;
// Synchronous wrapper for RecvLocalAsync.
virtual Status RecvLocal(int64 step_id, const string& key, Tensor* val,
bool* is_dead) = 0;
// Removes rendezvous for "step_id".
//
// TODO(zhifengc): Have a background thread in worker that
// periodically calls CleanupAll().
virtual void Cleanup(int64 step_id) = 0;
// Removes all rendezvous.
virtual void CleanupAll() = 0;
};
} // end namespace tensorflow
#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RENDEZVOUS_MGR_INTERFACE_H_
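
A sketch of the recv-side pattern described in the comment above, written against this interface (Example() and its arguments are hypothetical; error handling is elided):

namespace tensorflow {

Status Example(WorkerEnv* worker_env, int64 step_id, const string& key,
               Tensor* val) {
  // Find() returns one owned reference; graph executors would Send/Recv
  // through "rendez" using keys made by Rendezvous::CreateKey.
  Rendezvous* rendez = worker_env->rendezvous_mgr->Find(step_id);
  bool is_dead = false;
  Status s = worker_env->rendezvous_mgr->RecvLocal(step_id, key, val, &is_dead);
  rendez->Unref();                               // release Find()'s reference
  worker_env->rendezvous_mgr->Cleanup(step_id);  // drop per-step state
  return s;
}

}  // namespace tensorflow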

View File

@ -0,0 +1,341 @@
# Description:
# RPC communication interfaces and implementations for TensorFlow.
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
visibility = ["//tensorflow:__subpackages__"],
)
filegroup(
name = "c_srcs",
data = glob([
"**/*.cc",
"**/*.h",
]),
)
load(
"//tensorflow:tensorflow.bzl",
"tf_cuda_library",
"tf_cc_tests",
)
# For platform specific build config
load(
"//tensorflow/core:platform/default/build_config.bzl",
"tf_kernel_tests_linkstatic",
)
load(
"//tensorflow/core:platform/default/build_config_root.bzl",
"tf_cuda_tests_tags",
)
package(default_visibility = [
"//tensorflow:internal",
])
cc_library(
name = "grpc_util",
srcs = [],
hdrs = ["grpc_util.h"],
deps = [
"@grpc//:grpc++_unsecure",
"//tensorflow/core:lib",
],
)
cc_library(
name = "grpc_client_cq_tag",
srcs = [],
hdrs = ["grpc_client_cq_tag.h"],
deps = [
"@grpc//:grpc++_unsecure",
":grpc_util",
"//tensorflow/core:lib",
],
)
cc_library(
name = "grpc_remote_worker",
srcs = ["grpc_remote_worker.cc"],
hdrs = ["grpc_remote_worker.h"],
deps = [
"@grpc//:grpc++_unsecure",
":grpc_client_cq_tag",
":grpc_util",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:worker_proto_cc",
"//tensorflow/core:worker_service_proto_cc",
"//tensorflow/core/distributed_runtime:process_util",
"//tensorflow/core/distributed_runtime:worker_cache_logger",
"//tensorflow/core/distributed_runtime:worker_interface",
],
)
cc_library(
name = "grpc_channel",
srcs = ["grpc_channel.cc"],
hdrs = ["grpc_channel.h"],
deps = [
"@grpc//:grpc++_unsecure",
":grpc_util",
"//tensorflow/core:lib",
],
)
cc_library(
name = "grpc_call",
srcs = [],
hdrs = ["grpc_call.h"],
deps = [
"@grpc//:grpc++_unsecure",
"//tensorflow/core:lib",
],
)
cc_library(
name = "async_service_interface",
srcs = [],
hdrs = ["async_service_interface.h"],
deps = [],
)
cc_library(
name = "grpc_worker_cache",
srcs = ["grpc_worker_cache.cc"],
hdrs = ["grpc_worker_cache.h"],
deps = [
":grpc_channel",
":grpc_client_cq_tag",
":grpc_remote_worker",
"//tensorflow/core:lib",
"//tensorflow/core/distributed_runtime:worker_cache",
"//tensorflow/core/distributed_runtime:worker_cache_logger",
"//tensorflow/core/distributed_runtime:worker_cache_partial",
"//tensorflow/core/distributed_runtime:worker_interface",
],
)
cc_library(
name = "grpc_worker_service",
srcs = ["grpc_worker_service.cc"],
hdrs = ["grpc_worker_service.h"],
deps = [
"@grpc//:grpc++_unsecure",
":async_service_interface",
":grpc_call",
":grpc_util",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:gpu_runtime",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:tensorflow_opensource",
"//tensorflow/core:worker_proto_cc",
"//tensorflow/core:worker_service_proto_cc",
"//tensorflow/core/distributed_runtime:graph_mgr",
"//tensorflow/core/distributed_runtime:process_util",
"//tensorflow/core/distributed_runtime:rendezvous_mgr_interface",
"//tensorflow/core/distributed_runtime:worker_cache",
"//tensorflow/core/distributed_runtime:worker_env",
"//tensorflow/core/distributed_runtime:worker_interface",
],
)
cc_library(
name = "grpc_remote_master",
srcs = ["grpc_remote_master.cc"],
hdrs = ["grpc_remote_master.h"],
deps = [
"@grpc//:grpc++_unsecure",
":grpc_util",
"//tensorflow/core:lib",
"//tensorflow/core:master_proto_cc",
"//tensorflow/core:master_service_proto_cc",
"//tensorflow/core/distributed_runtime:master_interface",
],
alwayslink = 1,
)
cc_library(
name = "grpc_master_service",
srcs = ["grpc_master_service.cc"],
hdrs = ["grpc_master_service.h"],
deps = [
"@grpc//:grpc++_unsecure",
":async_service_interface",
":grpc_call",
":grpc_util",
"//tensorflow/core:lib",
"//tensorflow/core:master_proto_cc",
"//tensorflow/core:master_service_proto_cc",
"//tensorflow/core/distributed_runtime:master",
"//tensorflow/core/distributed_runtime:master_interface",
],
alwayslink = 1,
)
cc_library(
name = "rpc_rendezvous_mgr",
srcs = ["rpc_rendezvous_mgr.cc"],
hdrs = ["rpc_rendezvous_mgr.h"],
deps = [
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:tensorflow_opensource",
"//tensorflow/core/distributed_runtime:base_rendezvous_mgr",
"//tensorflow/core/distributed_runtime:process_util",
"//tensorflow/core/distributed_runtime:worker_cache",
"//tensorflow/core/distributed_runtime:worker_env",
"//tensorflow/core/distributed_runtime:worker_interface",
],
)
cc_library(
name = "grpc_server_lib",
srcs = [
"grpc_server_lib.cc",
],
hdrs = ["grpc_server_lib.h"],
deps = [
"@grpc//:grpc++_unsecure",
":async_service_interface",
":grpc_channel",
":grpc_master_service",
":grpc_worker_cache",
":grpc_worker_service",
":rpc_rendezvous_mgr",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core/distributed_runtime:graph_mgr",
"//tensorflow/core/distributed_runtime:master_env",
"//tensorflow/core/distributed_runtime:master_session",
"//tensorflow/core/distributed_runtime:process_util",
"//tensorflow/core/distributed_runtime:worker_env",
],
)
cc_binary(
name = "grpc_tensorflow_server",
srcs = [
"grpc_tensorflow_server.cc",
],
deps = [
"@grpc//:grpc++_unsecure",
":grpc_server_lib",
"//tensorflow/core:all_kernels",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
],
)
tf_cuda_library(
name = "grpc_testlib_ops",
testonly = 1,
srcs = ["grpc_testlib_ops.cc"],
linkstatic = 1, # Seems to be needed since alwayslink is broken in bazel
deps = [
"//tensorflow/core:all_kernels",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
],
alwayslink = 1,
)
cc_binary(
name = "grpc_testlib_server",
testonly = 1,
srcs = [
"grpc_testlib_server.cc",
],
deps = [
"@grpc//:grpc++_unsecure",
":grpc_server_lib",
":grpc_testlib_ops",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
],
)
tf_cuda_library(
name = "grpc_testlib",
testonly = 1,
srcs = ["grpc_testlib.cc"],
hdrs = ["grpc_testlib.h"],
data = [
":grpc_testlib_server",
],
deps = [
":grpc_session",
":grpc_testlib_ops",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow_opensource",
"//tensorflow/core:test",
],
alwayslink = 1,
)
cc_library(
name = "grpc_session",
srcs = ["grpc_session.cc"],
hdrs = ["grpc_session.h"],
deps = [
":grpc_channel",
":grpc_remote_master",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:master_proto_cc",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow",
"//tensorflow/core/distributed_runtime:master_interface",
],
alwayslink = 1,
)
tf_cc_tests(
linkstatic = tf_kernel_tests_linkstatic(),
tags = tf_cuda_tests_tags(),
tests = [
"grpc_channel_test.cc",
"grpc_session_test.cc",
"rpc_rendezvous_mgr_test.cc",
],
deps = [
":grpc_channel",
":grpc_session",
":grpc_testlib",
":rpc_rendezvous_mgr",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:master_proto_cc",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/distributed_runtime:process_util",
],
)

View File

@ -0,0 +1,37 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_ASYNC_SERVICE_INTERFACE_H_
#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_ASYNC_SERVICE_INTERFACE_H_
namespace tensorflow {
// Represents an abstract asynchronous service that handles incoming
// RPCs with a polling loop.
class AsyncServiceInterface {
public:
virtual ~AsyncServiceInterface() {}
// A blocking method that should be called to handle incoming RPCs.
// This method will block until the service is shutdown, which
// depends on the implementation of the service.
virtual void HandleRPCsLoop() = 0;
// TODO(mrry): Add a clean shutdown method?
};
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_ASYNC_SERVICE_INTERFACE_H_
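
For context, a sketch of the polling loop an implementation of this interface typically runs (MyService and cq_ are placeholders; the Tag type is the one defined in grpc_call.h, which appears later in this change):

// Inside a concrete service class:
void HandleRPCsLoop() /* override */ {
  void* tag;
  bool ok;
  while (cq_->Next(&tag, &ok)) {
    // Each tag bundles a call object with the callback to invoke.
    auto* callback_tag = static_cast<UntypedCall<MyService>::Tag*>(tag);
    callback_tag->OnCompleted(this, ok);
    delete callback_tag;
  }
}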

View File

@ -0,0 +1,227 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CALL_H_
#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CALL_H_
#include "tensorflow/core/platform/macros.h"
#include "external/grpc/include/grpc++/grpc++.h"
#include "external/grpc/include/grpc++/server_builder.h"
namespace tensorflow {
// CALL STRUCTURES
// ===============
//
// Each pending (incoming) request corresponds to a call object that
// encapsulates the state of the call. Templates and
// pointers-to-member functions are used to avoid boilerplate and
// redundant closure creation. The class hierarchy is as follows:
//
// * `UntypedCall<Service>`: The base class represents a call that
// could be associated with any of the methods on a service of type
// `Service`. Also defines a `Tag` nested class that can be used as
// the tag in a `grpc::CompletionQueue`. Each class that
// instantiates `Service` should have a completion queue polling
// loop that knows about `UntypedCall<Service>::Tag` objects, and
// invokes their `OnCompleted()` method to continue processing.
//
// * `Call<Service, GrpcService, Req, Resp>`: This class extends
// `UntypedCall<Service>` and is additionally parameterized by the
// gRPC-generated asynchronous service class, and the request and
// response message types. It defines the state associated with a
// call (whose type depends on the message types), and stores a
// pointer to a `Service::HandleFoo()` handler method. Each
// `Service::HandleFoo()` method knows about the corresponding
// `Call` type, in order to access its state, and invoke its
// `SendResponse()` method.
//
// The lifecycle of a call object is as follows.
//
// 1. A `Service` creates a `Call` for a particular method and
// enqueues it in its completion queue (via an
// `UntypedCall<Service>::Tag`).
//
// 2. When the tag is returned from `cq_->Next()`, the
// `UntypedCall::RequestReceived()` method is invoked and takes
// ownership of the call object. This indirectly invokes the
// appropriate handler method on `Service`.
//
// 3. After the response has been written (perhaps in another thread),
// the `Call::SendResponse()` method is invoked. It transfers
// ownership of the call object back to the completion queue (via
// an `UntypedCall::Tag`).
//
// 4. When the response has been sent, the tag is returned from
// `cq_->Next()`, and the call object is deleted.
// Represents a pending request with unknown message types.
template <class Service>
class UntypedCall : public core::RefCounted {
public:
virtual ~UntypedCall() {}
// The implementation of this method should use `service` to handle
// an incoming request, and (perhaps asynchronously) send the
// response.
//
// One reference on `this` is transferred to the callee, and the
// callee is responsible for releasing it (typically via
// `Call::SendResponse()`).
//
// `ok` is true if the request was received in a "regular event",
// otherwise false.
virtual void RequestReceived(Service* service, bool ok) = 0;
// This method will be called when the response has been sent by
// `service` and the call is no longer used.
//
// `ok` is true if the response sending completed as a "regular
// event", otherwise it is false.
void ResponseSent(Service* service, bool ok) {}
// This method will be called either (i) when the server is notified
// that the request has been cancelled, or (ii) when the request completes
// normally. The implementation should distinguish these cases by querying
// the `grpc::ServerContext` associated with the request.
virtual void RequestCancelled(Service* service, bool ok) = 0;
// Associates a tag in a `::grpc::CompletionQueue` with a callback
// for an incoming RPC. A Tag owns a reference on the corresponding
// Call object.
class Tag {
public:
using Callback = void (UntypedCall::*)(Service*, bool);
// Creates a new `Tag` for the given `UntypedCall`. When the
// request associated with this tag is complete, `callback` will
// be called.
Tag(UntypedCall* call, Callback callback)
: call_(call), callback_(callback) {
call_->Ref();
}
~Tag() { call_->Unref(); }
// Calls the callback associated with this tag.
//
// The callback takes ownership of `this->call_`.
void OnCompleted(Service* service, bool ok) {
(call_->*callback_)(service, ok);
}
private:
UntypedCall* call_; // `this` owns one reference.
Callback callback_;
};
};
// Represents a pending call with known request and response message
// types, and a known request-handling method.
template <class Service, class GrpcService, class RequestMessage,
class ResponseMessage>
class Call : public UntypedCall<Service> {
public:
// Represents the generic signature of a generated
// `GrpcService::RequestFoo()` method, where `Foo` is the name of an
// RPC method.
using EnqueueFunction = void (GrpcService::*)(
::grpc::ServerContext*, RequestMessage*,
::grpc::ServerAsyncResponseWriter<ResponseMessage>*,
::grpc::CompletionQueue*, ::grpc::ServerCompletionQueue*, void*);
// Represents the generic signature of a `Service::HandleFoo()`
// method, where `Foo` is the name of an RPC method.
using HandleRequestFunction = void (Service::*)(
Call<Service, GrpcService, RequestMessage, ResponseMessage>*);
Call(HandleRequestFunction handle_request_function)
: handle_request_function_(handle_request_function), responder_(&ctx_) {}
virtual ~Call() {}
void RequestReceived(Service* service, bool ok) override {
if (ok) {
this->Ref();
(service->*handle_request_function_)(this);
}
}
void SendResponse(::grpc::Status status) {
responder_.Finish(response, status,
new typename UntypedCall<Service>::Tag(
this, &UntypedCall<Service>::ResponseSent));
this->Unref();
}
void RequestCancelled(Service* service, bool ok) override {
if (ctx_.IsCancelled()) {
mutex_lock l(mu_);
if (cancel_callback_) {
cancel_callback_();
}
}
}
// Registers `callback` as the function that should be called if and when this
// call is cancelled by the client.
void SetCancelCallback(std::function<void()> callback) {
mutex_lock l(mu_);
cancel_callback_ = callback;
}
// Clears any cancellation callback that has been registered for this call.
void ClearCancelCallback() {
mutex_lock l(mu_);
cancel_callback_ = nullptr;
}
// Enqueues a new request for the given service on the given
// completion queue, using the given `enqueue_function`.
//
// The request will be handled with the given
// `handle_request_function`.
static void EnqueueRequest(GrpcService* grpc_service,
::grpc::ServerCompletionQueue* cq,
EnqueueFunction enqueue_function,
HandleRequestFunction handle_request_function) {
auto call = new Call<Service, GrpcService, RequestMessage, ResponseMessage>(
handle_request_function);
call->ctx_.AsyncNotifyWhenDone(new typename UntypedCall<Service>::Tag(
call, &UntypedCall<Service>::RequestCancelled));
(grpc_service->*enqueue_function)(
&call->ctx_, &call->request, &call->responder_, cq, cq,
new typename UntypedCall<Service>::Tag(
call, &UntypedCall<Service>::RequestReceived));
call->Unref();
}
RequestMessage request;
ResponseMessage response;
private:
HandleRequestFunction handle_request_function_;
::grpc::ServerContext ctx_;
::grpc::ServerAsyncResponseWriter<ResponseMessage> responder_;
mutex mu_;
std::function<void()> cancel_callback_ GUARDED_BY(mu_);
};
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CALL_H_
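
To tie the template machinery above together, a sketch of how a service wires a single RPC method through EnqueueRequest() (every My* name is a placeholder, not part of this change; the concrete master and worker services in this change follow the same pattern):

class MyServiceImpl {
 public:
  using MyCall = Call<MyServiceImpl, grpc::MyGrpcService::AsyncService,
                      MyRequest, MyResponse>;

  // Arms the completion queue with one pending call for this method.
  void EnqueueMyMethod() {
    MyCall::EnqueueRequest(&async_service_, cq_.get(),
                           &grpc::MyGrpcService::AsyncService::RequestMyMethod,
                           &MyServiceImpl::HandleMyMethod);
  }

  // Invoked (via RequestReceived) when a request arrives.
  void HandleMyMethod(MyCall* call) {
    // Fill call->response from call->request, possibly on another thread,
    // then send the response and re-arm the queue.
    call->SendResponse(::grpc::Status::OK);
    EnqueueMyMethod();
  }

 private:
  grpc::MyGrpcService::AsyncService async_service_;
  std::unique_ptr<::grpc::ServerCompletionQueue> cq_;
};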

View File

@ -0,0 +1,314 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
#include <unordered_map>
#include "external/grpc/include/grpc++/create_channel.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace {
RE2* kTargetRE = new RE2("^/job:([^/]+)/replica:([0-9]+)/task:([0-9]+)$");
RE2* kHostPortRE = new RE2("([^:/]+):(\\d+)");
RE2* kSparseHostPortRE = new RE2("(\\d+):([^:/]+):(\\d+)");
string MakeAddress(const string& job, int replica, int task) {
return strings::StrCat("/job:", job, "/replica:", replica, "/task:", task);
}
} // namespace
SharedGrpcChannelPtr NewHostPortGrpcChannel(const string& target) {
// TODO(mrry): Implement secure channels.
return ::grpc::CreateChannel(target, ::grpc::InsecureChannelCredentials());
}
Status GrpcChannelSpec::AddHostPortsJob(const string& job_id,
const std::vector<string>& host_ports,
int tasks_per_replica) {
if (!job_ids_.insert(job_id).second) {
return errors::InvalidArgument(
"Duplicate job ID in cluster specification: ", job_id);
}
HostPortsJob job;
job.job_id = job_id;
for (const string& host_port : host_ports) {
string host;
int port;
if (!RE2::FullMatch(host_port, *kHostPortRE, &host, &port)) {
return errors::InvalidArgument("Could not interpret \"", host_port,
"\" as a host-port pair.");
}
}
job.host_ports = host_ports;
job.tasks_per_replica = tasks_per_replica;
host_ports_jobs_.push_back(job);
return Status::OK();
}
GrpcChannelCache* NewGrpcChannelCache(const GrpcChannelSpec& spec) {
const int num_jobs = spec.host_ports_jobs().size();
if (!num_jobs) {
LOG(ERROR) << "Empty channel spec.";
return nullptr;
}
std::vector<GrpcChannelCache*> caches;
caches.reserve(num_jobs);
for (const GrpcChannelSpec::HostPortsJob& job : spec.host_ports_jobs()) {
caches.push_back(NewHostPortsGrpcChannelCache(job.job_id, job.host_ports,
job.tasks_per_replica));
}
return caches.size() == 1 ? caches[0] : NewMultiGrpcChannelCache(caches);
}
// A GrpcChannelCache that caches the results of FindWorkerChannel() calls.
class CachingGrpcChannelCache : public GrpcChannelCache {
public:
CachingGrpcChannelCache() {}
~CachingGrpcChannelCache() override {}
SharedGrpcChannelPtr FindWorkerChannel(const string& target) override {
SharedGrpcChannelPtr ch = nullptr;
{
mutex_lock l(mu_); // could use reader lock
ch = gtl::FindPtrOrNull(channels_, target);
if (ch) {
return ch;
}
}
ch = FindChannelOnce(target);
if (ch) {
mutex_lock l(mu_);
channels_.insert({target, ch});
}
return ch;
}
protected:
// Finds the ClientChannel for "target". Only called when no channel was
// found in the channels_ cache for "target". A non-nullptr result will be
// cached in channels_.
virtual SharedGrpcChannelPtr FindChannelOnce(const string& target) = 0;
private:
// TODO(zhifengc): Eviction when the map becomes too big.
mutex mu_;
std::unordered_map<string, SharedGrpcChannelPtr> channels_ GUARDED_BY(mu_);
};
// A ChannelCache that is the union of multiple ChannelCaches.
// Takes ownership of the caches passed to the constructor.
class MultiGrpcChannelCache : public CachingGrpcChannelCache {
public:
explicit MultiGrpcChannelCache(const std::vector<GrpcChannelCache*>& caches)
: CachingGrpcChannelCache(), caches_(caches) {}
~MultiGrpcChannelCache() override {
for (GrpcChannelCache* cache : caches_) {
delete cache;
}
}
void ListWorkers(std::vector<string>* workers) override {
for (GrpcChannelCache* cache : caches_) {
cache->ListWorkers(workers);
}
}
string TranslateTask(const string& target) override {
mutex_lock l(mu_); // could use reader lock
GrpcChannelCache* cache = gtl::FindPtrOrNull(target_caches_, target);
if (cache == nullptr) {
for (GrpcChannelCache* c : caches_) {
string r = c->TranslateTask(target);
if (!r.empty()) {
target_caches_.insert({target, c});
cache = c;
break;
}
}
}
CHECK(cache) << "Could not find GrpcChannelCache holding channel for "
<< target;
return cache->TranslateTask(target);
}
protected:
SharedGrpcChannelPtr FindChannelOnce(const string& target) override {
for (GrpcChannelCache* cache : caches_) {
SharedGrpcChannelPtr ch(cache->FindWorkerChannel(target));
if (ch) {
mutex_lock l(mu_);
target_caches_.insert({target, cache});
return ch;
}
}
return nullptr;
}
private:
// The underlying caches that this MultiGrpcChannelCache delegates to.
const std::vector<GrpcChannelCache*> caches_;
mutex mu_;
// Maps each target to the underlying cache that handles it.
// The same GrpcChannelCache can appear multiple times in the map.
std::unordered_map<string, GrpcChannelCache*> target_caches_ GUARDED_BY(mu_);
};
GrpcChannelCache* NewMultiGrpcChannelCache(
const std::vector<GrpcChannelCache*>& caches) {
return new MultiGrpcChannelCache(caches);
}
class HostPortsGrpcChannelCache : public CachingGrpcChannelCache {
public:
HostPortsGrpcChannelCache(const string& job_id,
const std::vector<string>& host_ports,
int tasks_per_replica)
: job_id_(job_id),
host_ports_(BuildDenseHostPortsList(host_ports, tasks_per_replica)),
tasks_per_replica_(tasks_per_replica) {
LOG(INFO) << "Initialize HostPortsGrpcChannelCache for job " << job_id
<< " -> {" << str_util::Join(host_ports, ", ") << "}";
}
~HostPortsGrpcChannelCache() override {}
void ListWorkers(std::vector<string>* workers) override {
int num_host_ports = 0;
for (size_t i = 0; i < host_ports_.size(); ++i) {
if (!host_ports_[i].empty()) {
++num_host_ports;
}
}
workers->reserve(workers->size() + num_host_ports);
for (size_t i = 0; i < host_ports_.size(); ++i) {
if (!host_ports_[i].empty()) {
workers->emplace_back(MakeAddress(job_id_, i / tasks_per_replica_,
i % tasks_per_replica_));
}
}
}
string TranslateTask(const string& target) override {
RegexpStringPiece job;
int32 replica;
int32 task;
if (!RE2::FullMatch(target, *kTargetRE, &job, &replica, &task)) {
LOG(WARNING) << "Invalid target: " << target;
return "";
}
if (job != job_id_) {
return "";
}
if (task >= tasks_per_replica_) {
LOG(WARNING) << "Task out of bounds for job " << job_id_ << ": " << task;
return "";
}
const size_t i = replica * tasks_per_replica_ + task;
if (i >= host_ports_.size()) {
LOG(WARNING) << "Replica/task out of bounds for job " << job_id_ << ": "
<< target;
return "";
}
if (host_ports_[i].empty()) {
LOG(WARNING) << "Replica/task not in sparse index:host:port list for job "
<< job_id_ << ": " << target;
return "";
}
return host_ports_[i];
}
protected:
static std::vector<string> BuildDenseHostPortsList(
const std::vector<string>& host_ports, int tasks_per_replica) {
std::map<int, string> sparse_host_ports;
for (const string& host_port : host_ports) {
int i = -1;
string host;
int port = -1;
if (RE2::FullMatch(host_port, *kSparseHostPortRE, &i, &host, &port)) {
CHECK_LE(0, i);
CHECK_LE(0, port);
CHECK(sparse_host_ports.find(i) == sparse_host_ports.end())
<< "Duplicate index " << i << ": {"
<< str_util::Join(host_ports, ", ") << "}";
sparse_host_ports[i] = strings::StrCat(host, ":", port);
} else {
CHECK(RE2::FullMatch(host_port, *kHostPortRE, &host, &port))
<< host_port
<< " does not look like a host:port or an index:host:port";
}
}
if (sparse_host_ports.empty()) {
// The input is a dense list; return it directly.
return host_ports;
}
// The input is a sparse list. Convert it to a dense list.
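// For example, {"0:a:1", "3:d:4", "4:e:5"} with tasks_per_replica == 2
// becomes {"a:1", "", "", "d:4", "e:5", ""}: indices with no entry are
// left empty, and the list is padded to a multiple of tasks_per_replica.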
CHECK_EQ(host_ports.size(), sparse_host_ports.size())
<< "Mix of host:port and index:host:port: {"
<< str_util::Join(host_ports, ", ") << "}";
int num_tasks = sparse_host_ports.rbegin()->first + 1;
if (num_tasks % tasks_per_replica != 0) {
num_tasks = (num_tasks / tasks_per_replica + 1) * tasks_per_replica;
}
std::vector<string> dense_host_ports;
dense_host_ports.resize(num_tasks);
for (const auto& p : sparse_host_ports) {
dense_host_ports[p.first] = p.second;
}
return dense_host_ports;
}
SharedGrpcChannelPtr FindChannelOnce(const string& target) override {
const string host_port = TranslateTask(target);
if (host_port.empty()) {
LOG(WARNING) << "Could not find channel for target: " << target;
return nullptr;
}
return NewHostPortGrpcChannel(host_port);
}
private:
const string job_id_;
const std::vector<string> host_ports_;
const int tasks_per_replica_;
TF_DISALLOW_COPY_AND_ASSIGN(HostPortsGrpcChannelCache);
};
GrpcChannelCache* NewHostPortsGrpcChannelCache(
const string& job_id, const std::vector<string>& host_ports,
int tasks_per_replica) {
return new HostPortsGrpcChannelCache(job_id, host_ports, tasks_per_replica);
}
} // end namespace tensorflow

View File

@ -0,0 +1,98 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CHANNEL_H_
#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CHANNEL_H_
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "external/grpc/include/grpc++/grpc++.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
namespace tensorflow {
// Consolidated parameter structure to ease use of generic interfaces.
//
// Each job_id requires:
// - a list of host:port (or sparse list of index:host:port)
// - the number of tasks per replica
class GrpcChannelSpec {
public:
struct HostPortsJob {
string job_id;
std::vector<string> host_ports;
int tasks_per_replica;
};
Status AddHostPortsJob(const string& job_id,
const std::vector<string>& host_ports,
int tasks_per_replica);
const std::vector<HostPortsJob>& host_ports_jobs() const {
return host_ports_jobs_;
}
private:
std::vector<HostPortsJob> host_ports_jobs_;
std::set<string> job_ids_;
};
class GrpcChannelCache {
public:
virtual ~GrpcChannelCache() {}
// Populates *workers with names of all workers which this object
// was created to handle. Worker names are in the format
// /job:<job identifier>/replica:<replica id>/task:<task id>
// e.g. /job:mnist/replica:0/task:2
virtual void ListWorkers(std::vector<string>* workers) = 0;
// If found, returns a gRPC channel that is connected to the remote
// worker named by 'target'. 'target' is of the following
// format: /job:<job identifier>/replica:<replica id>/task:<task id>
// E.g., /job:mnist/replica:0/task:2
virtual SharedGrpcChannelPtr FindWorkerChannel(const string& target) = 0;
// Translates a string in the form `/job:X/replica:Y/task:Z` into a
// host:port string.
virtual string TranslateTask(const string& task) = 0;
};
// Returns nullptr if the spec is empty; otherwise the caller takes
// ownership of the returned object.
GrpcChannelCache* NewGrpcChannelCache(const GrpcChannelSpec& spec);
// Below here are internal-only functions.
SharedGrpcChannelPtr NewHostPortGrpcChannel(const string& target);
// Returns a ChannelCache that uses a set of known host:port pairs. E.g., say,
// job_id = 'mnist', 'host_ports' = {"h0:0", "h1:1", ..., "h11:11", "h12:12"},
// tasks_per_replica = 8, /job:mnist/replica:1/task:3 is mapped to host:port
// "h11:11" (11 = 8 * 1 + 3).
//
// The caller takes ownership of the returned object.
GrpcChannelCache* NewHostPortsGrpcChannelCache(
const string& job_id, const std::vector<string>& host_ports,
int tasks_per_replica);
// Returns a ChannelCache that is the union of a number of other ChannelCaches.
GrpcChannelCache* NewMultiGrpcChannelCache(
const std::vector<GrpcChannelCache*>& caches);
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CHANNEL_H_
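
A minimal usage sketch of the API above. This is illustrative only: the
host:port values and the main() scaffolding are not part of this change,
and error handling is reduced to TF_CHECK_OK.

#include <iostream>
#include <memory>
#include <vector>

#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
#include "tensorflow/core/lib/core/status.h"

int main() {
  tensorflow::GrpcChannelSpec spec;
  // Dense form: replica r, task t maps to host_ports[r * tasks_per_replica + t].
  TF_CHECK_OK(spec.AddHostPortsJob(
      "mnist", {"h0:0", "h1:1", "h2:2", "h3:3"}, /*tasks_per_replica=*/2));
  std::unique_ptr<tensorflow::GrpcChannelCache> cache(
      tensorflow::NewGrpcChannelCache(spec));
  // Index 1 * 2 + 0 == 2, so this prints "h2:2".
  std::cout << cache->TranslateTask("/job:mnist/replica:1/task:0") << "\n";
  std::vector<tensorflow::string> workers;
  cache->ListWorkers(&workers);  // Four names: replicas 0-1, tasks 0-1.
  return 0;
}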

View File

@ -0,0 +1,137 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
#include <string>
#include <vector>
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
#define IsSameAddrSp DeviceNameUtils::IsSameAddressSpace
TEST(GrpcChannelTest, IsSameAddressSpace) {
// Same.
EXPECT_TRUE(IsSameAddrSp("/job:mnist/replica:10/task:10/cpu:0",
"/job:mnist/replica:10/task:10/cpu:1"));
EXPECT_TRUE(IsSameAddrSp("/job:mnist/replica:10/task:10/cpu:0",
"/job:mnist/replica:10/task:10/gpu:2"));
EXPECT_TRUE(IsSameAddrSp("/job:mnist/replica:10/task:10",
"/job:mnist/replica:10/task:10/gpu:2"));
EXPECT_TRUE(IsSameAddrSp("/job:mnist/replica:10/task:10/cpu:1",
"/job:mnist/replica:10/task:10"));
// Different.
EXPECT_FALSE(IsSameAddrSp("/job:mnist/replica:10/task:9/cpu:0",
"/job:mnist/replica:10/task:10/cpu:0"));
EXPECT_FALSE(IsSameAddrSp("/job:mnist/replica:9/task:10/cpu:0",
"/job:mnist/replica:10/task:10/cpu:0"));
EXPECT_FALSE(IsSameAddrSp("/job:MNIST/replica:10/task:10/cpu:0",
"/job:mnist/replica:10/task:10/cpu:0"));
// Invalid names.
EXPECT_FALSE(IsSameAddrSp("random_invalid_target", "random_invalid_target"));
EXPECT_FALSE(IsSameAddrSp("/job:/replica:10/task:10/cpu:0",
"/job:/replica:10/task:10/cpu:1"));
EXPECT_FALSE(IsSameAddrSp("/job:mnist/replica:xx/task:10/cpu:0",
"/job:mnist/replica:xx/task:10/cpu:1"));
EXPECT_FALSE(IsSameAddrSp("/job:mnist/replica:10/task:yy/cpu:0",
"/job:mnist/replica:10/task:yy/cpu:1"));
}
TEST(GrpcChannelTest, HostPorts) {
std::unique_ptr<GrpcChannelCache> cc(NewHostPortsGrpcChannelCache(
"mnist", {"a:1", "b:2", "c:3", "d:4", "e:5", "f:6"}, 2));
EXPECT_EQ(nullptr, cc->FindWorkerChannel("invalid_target"));
EXPECT_EQ(nullptr, cc->FindWorkerChannel("/job:other/replica:0/task:0"));
EXPECT_EQ(nullptr, cc->FindWorkerChannel("/job:mnist/replica:0/task:2"));
EXPECT_EQ(nullptr, cc->FindWorkerChannel("/job:mnist/replica:3/task:0"));
{
// NOTE(mrry): The gRPC channel doesn't expose the target, so we
// can't compare it for equality.
auto a_1_1 = cc->FindWorkerChannel("/job:mnist/replica:0/task:0");
auto a_1_2 = cc->FindWorkerChannel("/job:mnist/replica:0/task:0");
auto d_4_1 = cc->FindWorkerChannel("/job:mnist/replica:1/task:1");
auto d_4_2 = cc->FindWorkerChannel("/job:mnist/replica:1/task:1");
auto e_5_1 = cc->FindWorkerChannel("/job:mnist/replica:2/task:0");
auto e_5_2 = cc->FindWorkerChannel("/job:mnist/replica:2/task:0");
EXPECT_EQ(a_1_1.get(), a_1_2.get());
EXPECT_EQ(d_4_1.get(), d_4_2.get());
EXPECT_EQ(e_5_1.get(), e_5_2.get());
EXPECT_NE(a_1_1.get(), d_4_2.get());
EXPECT_NE(a_1_1.get(), e_5_2.get());
EXPECT_NE(d_4_1.get(), e_5_2.get());
}
std::vector<string> workers;
cc->ListWorkers(&workers);
EXPECT_EQ(std::vector<string>({"/job:mnist/replica:0/task:0",
"/job:mnist/replica:0/task:1",
"/job:mnist/replica:1/task:0",
"/job:mnist/replica:1/task:1",
"/job:mnist/replica:2/task:0",
"/job:mnist/replica:2/task:1"}),
workers);
}
TEST(GrpcChannelTest, SparseHostPorts) {
std::unique_ptr<GrpcChannelCache> cc(
NewHostPortsGrpcChannelCache("mnist", {"0:a:1", "3:d:4", "4:e:5"}, 2));
EXPECT_EQ(nullptr, cc->FindWorkerChannel("invalid_target"));
EXPECT_EQ(nullptr, cc->FindWorkerChannel("/job:other/replica:0/task:0"));
EXPECT_EQ(nullptr, cc->FindWorkerChannel("/job:mnist/replica:0/task:1"));
EXPECT_EQ(nullptr, cc->FindWorkerChannel("/job:mnist/replica:0/task:2"));
EXPECT_EQ(nullptr, cc->FindWorkerChannel("/job:mnist/replica:1/task:0"));
EXPECT_EQ(nullptr, cc->FindWorkerChannel("/job:mnist/replica:2/task:1"));
EXPECT_EQ(nullptr, cc->FindWorkerChannel("/job:mnist/replica:3/task:0"));
{
// NOTE(mrry): The gRPC channel doesn't expose the target, so we
// can't compare it for equality.
auto a_1_1 = cc->FindWorkerChannel("/job:mnist/replica:0/task:0");
auto a_1_2 = cc->FindWorkerChannel("/job:mnist/replica:0/task:0");
auto d_4_1 = cc->FindWorkerChannel("/job:mnist/replica:1/task:1");
auto d_4_2 = cc->FindWorkerChannel("/job:mnist/replica:1/task:1");
auto e_5_1 = cc->FindWorkerChannel("/job:mnist/replica:2/task:0");
auto e_5_2 = cc->FindWorkerChannel("/job:mnist/replica:2/task:0");
EXPECT_EQ(a_1_1.get(), a_1_2.get());
EXPECT_EQ(d_4_1.get(), d_4_2.get());
EXPECT_EQ(e_5_1.get(), e_5_2.get());
EXPECT_NE(a_1_1.get(), d_4_2.get());
EXPECT_NE(a_1_1.get(), e_5_2.get());
EXPECT_NE(d_4_1.get(), e_5_2.get());
}
std::vector<string> workers;
cc->ListWorkers(&workers);
EXPECT_EQ(std::vector<string>({"/job:mnist/replica:0/task:0",
"/job:mnist/replica:1/task:1",
"/job:mnist/replica:2/task:0"}),
workers);
}
} // namespace tensorflow

View File

@ -0,0 +1,56 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CLIENT_CQ_TAG_H_
#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CLIENT_CQ_TAG_H_
#include "external/grpc/include/grpc++/grpc++.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
namespace tensorflow {
// Represents a pending asynchronous client call as a tag that can be
// stored in a `grpc::CompletionQueue`.
class GrpcClientCQTag {
public:
GrpcClientCQTag(::grpc::ClientContext* context, StatusCallback cb)
: context_(context), cb_(cb) {}
~GrpcClientCQTag() { delete context_; }
void OnCompleted(bool ok) {
if (!ok) {
VLOG(2) << "Call returned with non-ok status: "
<< status_.error_message();
}
cb_(FromGrpcStatus(status_));
}
::grpc::ClientContext* context() { return context_; }
::grpc::Status* status() { return &status_; }
private:
::grpc::ClientContext* context_;
::grpc::Status status_;
StatusCallback cb_;
TF_DISALLOW_COPY_AND_ASSIGN(GrpcClientCQTag);
};
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CLIENT_CQ_TAG_H_
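
For orientation, a sketch of the client-side loop that consumes these tags.
The real driver thread lives in the worker cache elsewhere in this change;
the function name here is illustrative.

#include "external/grpc/include/grpc++/grpc++.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h"

namespace tensorflow {
// Completes one pending call per queue event until the queue shuts down.
void PollClientCompletionQueue(::grpc::CompletionQueue* cq) {
  void* tag;
  bool ok;
  while (cq->Next(&tag, &ok)) {
    GrpcClientCQTag* cq_tag = static_cast<GrpcClientCQTag*>(tag);
    cq_tag->OnCompleted(ok);  // Invokes the stored StatusCallback.
    delete cq_tag;            // Also deletes the owned ClientContext.
  }
}
}  // namespace tensorflow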

View File

@ -0,0 +1,181 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// GrpcMasterService implements the RPC service MasterService.
//
// A GrpcMasterService maintains the state of live graph computation
// sessions, each session orchestrates both local and remote devices
// to carry out the graph computation.
//
// A GrpcMasterService knows ahead of time local devices available as
// client devices.
//
// A GrpcMasterService discovers remote devices in the background and
// keeps track of statistics of those remote devices.
//
// Each session analyses the graph, places nodes across available
// devices, and ultimately drives the graph computation by initiating
// RunGraph on workers.
#include "tensorflow/core/distributed_runtime/rpc/grpc_master_service.h"
#include "external/grpc/include/grpc++/server_builder.h"
#include "tensorflow/core/distributed_runtime/master.h"
#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_call.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/protobuf/master.pb.h"
#include "tensorflow/core/protobuf/master_service.grpc.pb.h"
namespace tensorflow {
class GrpcMasterService : public AsyncServiceInterface {
public:
GrpcMasterService(MasterEnv* env, ::grpc::ServerBuilder* builder)
: master_impl_(new Master(env, 0.0)) {
builder->RegisterService(&master_service_);
cq_ = builder->AddCompletionQueue().release();
}
~GrpcMasterService() {
delete cq_;
delete master_impl_;
}
// This macro creates a new request for the given RPC method name
// (e.g., `ENQUEUE_REQUEST(RunStep);`), and enqueues it on
// `this->cq_`.
//
// This macro is invoked one or more times for each RPC method to
// ensure that there are sufficient completion queue entries to
// handle incoming requests without blocking.
//
// The implementation of the request handler for each RPC method
// must ensure that it calls ENQUEUE_REQUEST() for that RPC method,
// to keep accepting new requests.
#define ENQUEUE_REQUEST(method) \
do { \
Call<GrpcMasterService, grpc::MasterService::AsyncService, \
method##Request, method##Response>:: \
EnqueueRequest(&master_service_, cq_, \
&grpc::MasterService::AsyncService::Request##method, \
&GrpcMasterService::method##Handler); \
} while (0)
void HandleRPCsLoop() {
ENQUEUE_REQUEST(CreateSession);
ENQUEUE_REQUEST(ExtendSession);
for (int i = 0; i < 100; ++i) {
ENQUEUE_REQUEST(RunStep);
}
ENQUEUE_REQUEST(CloseSession);
ENQUEUE_REQUEST(ListDevices);
ENQUEUE_REQUEST(Reset);
void* tag;
bool ok;
while (cq_->Next(&tag, &ok)) {
CHECK(ok);
UntypedCall<GrpcMasterService>::Tag* callback_tag =
static_cast<UntypedCall<GrpcMasterService>::Tag*>(tag);
callback_tag->OnCompleted(this, ok);
delete callback_tag;
}
}
private:
Master* master_impl_; // Owned.
::grpc::ServerCompletionQueue* cq_; // Owned.
grpc::MasterService::AsyncService master_service_;
template <class RequestMessage, class ResponseMessage>
using MasterCall = Call<GrpcMasterService, grpc::MasterService::AsyncService,
RequestMessage, ResponseMessage>;
// RPC handler for creating a session.
void CreateSessionHandler(
MasterCall<CreateSessionRequest, CreateSessionResponse>* call) {
master_impl_->CreateSession(&call->request, &call->response,
[call](const Status& status) {
call->SendResponse(ToGrpcStatus(status));
});
ENQUEUE_REQUEST(CreateSession);
}
// RPC handler for extending a session.
void ExtendSessionHandler(
MasterCall<ExtendSessionRequest, ExtendSessionResponse>* call) {
master_impl_->ExtendSession(&call->request, &call->response,
[call](const Status& status) {
call->SendResponse(ToGrpcStatus(status));
});
ENQUEUE_REQUEST(ExtendSession);
}
// RPC handler for running one step in a session.
void RunStepHandler(MasterCall<RunStepRequest, RunStepResponse>* call) {
CallOptions* call_opts = new CallOptions;
call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
master_impl_->RunStep(call_opts, &call->request, &call->response,
[call, call_opts](const Status& status) {
call->ClearCancelCallback();
delete call_opts;
call->SendResponse(ToGrpcStatus(status));
});
ENQUEUE_REQUEST(RunStep);
}
// RPC handler for deleting a session.
void CloseSessionHandler(
MasterCall<CloseSessionRequest, CloseSessionResponse>* call) {
master_impl_->CloseSession(&call->request, &call->response,
[call](const Status& status) {
call->SendResponse(ToGrpcStatus(status));
});
ENQUEUE_REQUEST(CloseSession);
}
// RPC handler for listing devices.
void ListDevicesHandler(
MasterCall<ListDevicesRequest, ListDevicesResponse>* call) {
master_impl_->ListDevices(&call->request, &call->response,
[call](const Status& status) {
call->SendResponse(ToGrpcStatus(status));
});
ENQUEUE_REQUEST(ListDevices);
}
// RPC handler for resetting all sessions.
void ResetHandler(MasterCall<ResetRequest, ResetResponse>* call) {
master_impl_->Reset(&call->request, &call->response,
[call](const Status& status) {
call->SendResponse(ToGrpcStatus(status));
});
ENQUEUE_REQUEST(Reset);
}
#undef ENQUEUE_REQUEST
TF_DISALLOW_COPY_AND_ASSIGN(GrpcMasterService);
};
AsyncServiceInterface* NewGrpcMasterService(MasterEnv* env,
::grpc::ServerBuilder* builder) {
CHECK(!env->local_devices.empty());
return new GrpcMasterService(env, builder);
}
} // end namespace tensorflow

View File

@ -0,0 +1,33 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_MASTER_SERVICE_H_
#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_MASTER_SERVICE_H_
namespace grpc {
class ServerBuilder;
} // namespace grpc
namespace tensorflow {
class AsyncServiceInterface;
class MasterEnv;
AsyncServiceInterface* NewGrpcMasterService(MasterEnv* env,
::grpc::ServerBuilder* builder);
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_MASTER_SERVICE_H_

View File

@ -0,0 +1,79 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/rpc/grpc_remote_master.h"
#include "tensorflow/core/distributed_runtime/master_interface.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/master.pb.h"
#include "tensorflow/core/protobuf/master_service.grpc.pb.h"
namespace tensorflow {
// GrpcRemoteMaster is an implementation of the MasterInterface
// that uses gRPC to talk to the Master service.
class GrpcRemoteMaster : public MasterInterface {
public:
explicit GrpcRemoteMaster(SharedGrpcChannelPtr client_channel)
: stub_(grpc::MasterService::NewStub(client_channel)) {}
~GrpcRemoteMaster() override {}
Status CreateSession(const CreateSessionRequest* request,
CreateSessionResponse* response) override {
::grpc::ClientContext ctx;
return FromGrpcStatus(stub_->CreateSession(&ctx, *request, response));
}
Status ExtendSession(const ExtendSessionRequest* request,
ExtendSessionResponse* response) override {
::grpc::ClientContext ctx;
return FromGrpcStatus(stub_->ExtendSession(&ctx, *request, response));
}
Status RunStep(const RunStepRequest* request,
RunStepResponse* response) override {
::grpc::ClientContext ctx;
return FromGrpcStatus(stub_->RunStep(&ctx, *request, response));
}
Status CloseSession(const CloseSessionRequest* request,
CloseSessionResponse* response) override {
::grpc::ClientContext ctx;
return FromGrpcStatus(stub_->CloseSession(&ctx, *request, response));
}
Status ListDevices(const ListDevicesRequest* request,
ListDevicesResponse* response) override {
::grpc::ClientContext ctx;
return FromGrpcStatus(stub_->ListDevices(&ctx, *request, response));
}
Status Reset(const ResetRequest* request, ResetResponse* response) override {
::grpc::ClientContext ctx;
return FromGrpcStatus(stub_->Reset(&ctx, *request, response));
}
private:
std::unique_ptr<grpc::MasterService::Stub> stub_;
};
MasterInterface* NewGrpcMaster(SharedGrpcChannelPtr channel) {
return new GrpcRemoteMaster(channel);
}
} // namespace tensorflow

View File

@ -0,0 +1,27 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_MASTER_H_
#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_MASTER_H_
#include "tensorflow/core/distributed_runtime/master_interface.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
namespace tensorflow {
// Returns a MasterInterface wrapped around the gRPC channel `channel`.
MasterInterface* NewGrpcMaster(SharedGrpcChannelPtr channel);
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_MASTER_H_
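
A hedged sketch of driving the master interface directly; the target
address is illustrative, and the channel helper comes from grpc_channel.h.

#include <memory>

#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_remote_master.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/protobuf/master.pb.h"

void ListMasterDevices() {
  std::unique_ptr<tensorflow::MasterInterface> master(
      tensorflow::NewGrpcMaster(
          tensorflow::NewHostPortGrpcChannel("localhost:2222")));
  tensorflow::ListDevicesRequest req;
  tensorflow::ListDevicesResponse resp;
  tensorflow::Status s = master->ListDevices(&req, &resp);
  if (!s.ok()) {
    LOG(ERROR) << "ListDevices failed: " << s;
    return;
  }
  for (const auto& dev : resp.local_device()) {
    LOG(INFO) << "Local device: " << dev.name();
  }
}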

View File

@ -0,0 +1,203 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h"
#include "external/grpc/include/grpc++/grpc++.h"
#include "tensorflow/core/distributed_runtime/process_util.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h"
#include "tensorflow/core/distributed_runtime/worker_cache_logger.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/protobuf/worker.pb.h"
#include "tensorflow/core/protobuf/worker_service.grpc.pb.h"
namespace tensorflow {
class GrpcRemoteWorker : public WorkerInterface {
public:
explicit GrpcRemoteWorker(SharedGrpcChannelPtr channel,
::grpc::CompletionQueue* completion_queue,
WorkerCacheLogger* logger)
: stub_(grpc::WorkerService::NewStub(channel)),
cq_(completion_queue),
logger_(logger) {}
~GrpcRemoteWorker() override {}
void GetStatusAsync(const GetStatusRequest* request,
GetStatusResponse* response,
StatusCallback done) override {
IssueRequest(request, response, &grpc::WorkerService::Stub::AsyncGetStatus,
done);
}
void RegisterGraphAsync(const RegisterGraphRequest* request,
RegisterGraphResponse* response,
StatusCallback done) override {
IssueRequest(request, response,
&grpc::WorkerService::Stub::AsyncRegisterGraph, done);
}
void DeregisterGraphAsync(const DeregisterGraphRequest* request,
DeregisterGraphResponse* response,
StatusCallback done) override {
IssueRequest(request, response,
&grpc::WorkerService::Stub::AsyncDeregisterGraph, done);
}
void RunGraphAsync(CallOptions* call_opts, const RunGraphRequest* request,
RunGraphResponse* response, StatusCallback done) override {
IssueRequest(request, response, &grpc::WorkerService::Stub::AsyncRunGraph,
done, call_opts);
}
void CleanupGraphAsync(const CleanupGraphRequest* request,
CleanupGraphResponse* response,
StatusCallback done) override {
IssueRequest(request, response,
&grpc::WorkerService::Stub::AsyncCleanupGraph, done);
}
void CleanupAllAsync(const CleanupAllRequest* request,
CleanupAllResponse* response,
StatusCallback done) override {
IssueRequest(request, response, &grpc::WorkerService::Stub::AsyncCleanupAll,
done);
}
void RecvTensorAsync(CallOptions* call_opts, const RecvTensorRequest* request,
RecvTensorResponse* response,
TensorBufAllocator allocator,
StatusCallback done) override {
VLOG(1) << "RecvTensorAsync req: " << request->DebugString();
int64 start_usec = Env::Default()->NowMicros();
// Don't propagate dma_ok over gRPC.
RecvTensorRequest* req_copy = nullptr;
if (request->dma_ok()) {
req_copy = new RecvTensorRequest;
*req_copy = *request;
req_copy->set_dma_ok(false);
}
// Type-specialized logging for this method.
StatusCallback logging_callback = [this, request, req_copy, response, done,
start_usec](Status s) {
if (logger_->LoggingActive()) {
int64 end_usec = Env::Default()->NowMicros();
int64 step_id = request->step_id();
int64 bytes = response->tensor().ByteSize();
int64 send_start_usec = start_usec;
// If a send start time was reported by the other side, use
// that instead. Maybe we should mark the display if we're using
// our local time instead of the remote start time?
if (response->send_start_micros()) {
// send_start_micros is the timestamp taken when the remote
// machine began to send the RecvTensor response.
// Due to clock skew between source and dest machines, it is
// possible that send_start_micros can be larger than end_usec or
// less than start_usec.
// To respect causality, we enforce the invariants that the RecvTensor
// response can not have been sent before the RecvTensor request, and
// must have been sent before it was received.
send_start_usec = std::max(start_usec, response->send_start_micros());
send_start_usec = std::min(send_start_usec, end_usec - 1);
}
const string& key = request->rendezvous_key();
std::vector<string> key_parts = str_util::Split(key, ';');
if (key_parts.size() != 5) {
LOG(WARNING) << "Bad key: " << key;
} else {
logger_->RecordRecvTensor(step_id, send_start_usec, end_usec,
key_parts[3], // tensor name
key_parts[0], // src_device
key_parts[2], // dst_device
bytes);
}
}
VLOG(2) << "done callback, req: " << request->DebugString()
<< " response " << response->DebugString();
delete req_copy;
done(s);
};
IssueRequest(req_copy ? req_copy : request, response,
&grpc::WorkerService::Stub::AsyncRecvTensor, logging_callback,
call_opts);
}
void LoggingAsync(const LoggingRequest* request, LoggingResponse* response,
StatusCallback done) override {
IssueRequest(request, response, &grpc::WorkerService::Stub::AsyncLogging,
done);
}
void TracingAsync(const TracingRequest* request, TracingResponse* response,
StatusCallback done) override {
IssueRequest(request, response, &grpc::WorkerService::Stub::AsyncTracing,
done);
}
private:
template <class RequestMessage, class ResponseMessage>
using AsyncMethod =
std::unique_ptr<::grpc::ClientAsyncResponseReader<ResponseMessage>> (
grpc::WorkerService::Stub::*)(::grpc::ClientContext*,
const RequestMessage&,
::grpc::CompletionQueue*);
// Utility method for issuing a generic asynchronous request. The
// given callback, `done`, will be called when the RPC completes.
template <class RequestMessage, class ResponseMessage>
void IssueRequest(const RequestMessage* request, ResponseMessage* response,
AsyncMethod<RequestMessage, ResponseMessage> async_method,
StatusCallback done, CallOptions* call_opts = nullptr) {
::grpc::ClientContext* context = new ::grpc::ClientContext;
if (call_opts) {
call_opts->SetCancelCallback([context]() { context->TryCancel(); });
}
auto rpc = (stub_.get()->*async_method)(context, *request, cq_).release();
GrpcClientCQTag* tag =
new GrpcClientCQTag(context, [rpc, done, call_opts](Status s) {
if (call_opts) {
call_opts->ClearCancelCallback();
}
delete rpc;
done(s);
});
rpc->Finish(response, tag->status(), tag);
}
std::unique_ptr<grpc::WorkerService::Stub> stub_;
::grpc::CompletionQueue* cq_;
// Support for logging.
WorkerCacheLogger* logger_;
bool retry_unavailable_;
TF_DISALLOW_COPY_AND_ASSIGN(GrpcRemoteWorker);
};
WorkerInterface* NewGrpcRemoteWorker(SharedGrpcChannelPtr channel,
::grpc::CompletionQueue* completion_queue,
WorkerCacheLogger* logger) {
return new GrpcRemoteWorker(channel, completion_queue, logger);
}
} // namespace tensorflow

View File

@ -0,0 +1,38 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_WORKER_H_
#define THIRD_PARTY_TENSORFLOW_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_WORKER_H_
#include <memory>
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
namespace grpc {
class CompletionQueue;
}
namespace tensorflow {
class WorkerCacheLogger;
class WorkerInterface;
WorkerInterface* NewGrpcRemoteWorker(SharedGrpcChannelPtr channel,
::grpc::CompletionQueue* completion_queue,
WorkerCacheLogger* logger);
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_WORKER_H_
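
A hedged sketch of issuing one asynchronous worker call. It assumes a
thread is already servicing `cq` (as in the tag-polling sketch after
grpc_client_cq_tag.h); the request and response must outlive the RPC,
hence the heap allocations, and ownership of `worker` is elided.

#include "external/grpc/include/grpc++/grpc++.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/protobuf/worker.pb.h"

void QueryWorkerStatus(tensorflow::SharedGrpcChannelPtr channel,
                       ::grpc::CompletionQueue* cq,
                       tensorflow::WorkerCacheLogger* logger) {
  tensorflow::WorkerInterface* worker =
      tensorflow::NewGrpcRemoteWorker(channel, cq, logger);
  auto* req = new tensorflow::GetStatusRequest;
  auto* resp = new tensorflow::GetStatusResponse;
  worker->GetStatusAsync(req, resp, [req, resp](tensorflow::Status s) {
    LOG(INFO) << "GetStatus completed: " << s;
    delete req;
    delete resp;
  });
}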

View File

@ -0,0 +1,116 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include <memory>
#include "external/grpc/include/grpc++/grpc++.h"
#include "external/grpc/include/grpc++/security/credentials.h"
#include "external/grpc/include/grpc++/server_builder.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/distributed_runtime/graph_mgr.h"
#include "tensorflow/core/distributed_runtime/master_env.h"
#include "tensorflow/core/distributed_runtime/master_session.h"
#include "tensorflow/core/distributed_runtime/process_util.h"
#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_master_service.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/util/command_line_flags.h"
namespace tensorflow {
void StartTensorFlowServer(const GrpcServerOptions& options) {
// The Thread destructor waits until the thread terminates
// (i.e., forever in this case).
std::unique_ptr<Thread> launcher_thread(Env::Default()->StartThread(
ThreadOptions(), "TF_service_launcher", [options]() {
// Configure the MasterEnv and WorkerEnv, which provide service-global
// context for the master and worker services, respectively.
// The master and worker share the same worker cache (for RPC
// connections to other workers) and devices (so that the master
// may run some ops locally as a "client" device). The master
// requires a device to serve as a "client device", so that remote
// devices can copy the feeds from the master.
MasterEnv master_env;
WorkerEnv worker_env;
master_env.env = Env::Default();
worker_env.env = Env::Default();
// Configure shared devices between master and worker.
string name_prefix =
strings::StrCat("/job:", options.job_name, "/replica:0", "/task:",
options.task_index);
DeviceFactory::AddDevices(options.default_session_options, name_prefix,
&master_env.local_devices);
worker_env.device_mgr = new DeviceMgr(master_env.local_devices);
string unused;
CHECK(DeviceNameUtils::SplitDeviceName(
master_env.local_devices[0]->name(), &worker_env.worker_name,
&unused));
GrpcChannelCache* channel_cache =
NewGrpcChannelCache(options.channel_spec);
int port;
const std::vector<string> host_port =
str_util::Split(channel_cache->TranslateTask(name_prefix), ':');
CHECK(str_util::NumericParse32(host_port[1], &port));
worker_env.worker_cache = NewGrpcWorkerCache(channel_cache);
// Finish setting up master environment.
master_env.ops = OpRegistry::Global();
master_env.worker_cache = worker_env.worker_cache;
master_env.master_session_factory = internal::NewMasterSession;
// Finish setting up worker environment.
worker_env.graph_mgr = new GraphMgr(&worker_env);
worker_env.rendezvous_mgr = new RpcRendezvousMgr(&worker_env);
worker_env.compute_pool = ComputePool(options.default_session_options);
// Build the gRPC server that will host both the master and the
// worker services.
::grpc::ServerBuilder builder;
builder.AddListeningPort(strings::StrCat("0.0.0.0:", port),
::grpc::InsecureServerCredentials());
auto master_service = NewGrpcMasterService(&master_env, &builder);
auto worker_service = NewGrpcWorkerService(&worker_env, &builder);
auto server = builder.BuildAndStart();
// Start threads to handle the incoming RPCs for the master and worker.
// NOTE(mrry): The Thread destructor waits until the thread terminates
// (i.e. forever in this case).
std::unique_ptr<Thread> master_thread(Env::Default()->StartThread(
ThreadOptions(), "TF_master_service",
[master_service]() { master_service->HandleRPCsLoop(); }));
std::unique_ptr<Thread> worker_thread(Env::Default()->StartThread(
ThreadOptions(), "TF_worker_service",
[worker_service]() { worker_service->HandleRPCsLoop(); }));
}));
}
} // end namespace tensorflow

View File

@ -0,0 +1,53 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_
#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
// Defines the configuration for a single task (typically a process)
// that is part of a gRPC-based TensorFlow cluster.
struct GrpcServerOptions {
// The identity of the job to which this task belongs. The names
// of the devices in this task will be prefixed with
// "/job:<job_name>/replica:0/task:<task_index>".
string job_name;
int32 task_index = 0;
// A channel specification, which defines (i) the set of jobs that
// comprise the cluster, and (ii) within each job, the endpoints
// exposed by each task. NOTE: This spec also defines the endpoint
// on which this task will listen.
GrpcChannelSpec channel_spec;
// SessionOptions that will be used as defaults when configuring
// sessions in this task. `default_session_options.target` is
// ignored.
SessionOptions default_session_options;
};
// Starts a gRPC-based TensorFlow server with the given options.
// This function will not return.
void StartTensorFlowServer(const GrpcServerOptions& options);
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_
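
A hedged sketch of configuring and starting a single-task server with the
struct above; the endpoints are illustrative, and this mirrors what the
grpc_tensorflow_server binary is expected to do.

#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/lib/core/status.h"

int main(int argc, char* argv[]) {
  tensorflow::GrpcServerOptions opts;
  opts.job_name = "worker";
  opts.task_index = 0;
  // The spec must include this task's own endpoint, because the server
  // derives its listening port from it.
  TF_CHECK_OK(opts.channel_spec.AddHostPortsJob(
      "worker", {"localhost:2222", "localhost:2223"},
      /*tasks_per_replica=*/1));
  tensorflow::StartTensorFlowServer(opts);  // Does not return.
}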

View File

@ -0,0 +1,233 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/rpc/grpc_session.h"
#include <unordered_map>
#include "tensorflow/core/common_runtime/session_factory.h"
#include "tensorflow/core/distributed_runtime/master_interface.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_remote_master.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/protobuf/master.pb.h"
namespace tensorflow {
const size_t kSchemePrefix = sizeof("grpc://") - 1;
GrpcSession::GrpcSession(const SessionOptions& options)
: options_(options),
master_(NewGrpcMaster(
NewHostPortGrpcChannel(options.target.substr(kSchemePrefix)))),
current_graph_version_(-1) {}
GrpcSession::~GrpcSession() {}
namespace {
// Re-encodes constants represented in tensor protos into
// tensor_content, which is slightly better (fewer copies and lower
// peak memory usage) when used with RPC subsystems.
void ReEncodeConsts(GraphDef* gdef) {
for (NodeDef& ndef : *(gdef->mutable_node())) {
if (ndef.op() == "Const") {
TensorProto* proto = nullptr;
for (auto& attr : *ndef.mutable_attr()) {
if (attr.first == "value") {
proto = attr.second.mutable_tensor();
}
}
if (proto != nullptr && proto->tensor_content().empty() &&
proto->ByteSize() > 64) {
// If the constant is encoded with repeated proto fields and
it is moderately large, we re-encode it in tensor_content as
// a Cord. This is mildly helpful for reducing the peak memory
// usage on the server side where GraphDef/NodeDef are copied
// quite often.
Tensor parsed(proto->dtype());
if (parsed.FromProto(*proto)) {
parsed.AsProtoTensorContent(proto);
}
}
}
}
}
} // namespace
Status GrpcSession::Create(const GraphDef& graph) {
if (!handle_.empty()) {
return errors::InvalidArgument("A session is alive.");
}
CreateSessionRequest req;
*req.mutable_config() = options_.config;
*req.mutable_graph_def() = graph;
ReEncodeConsts(req.mutable_graph_def());
CreateSessionResponse resp;
Status s = master_->CreateSession(&req, &resp);
if (s.ok()) {
mutex_lock l(mu_);
swap(handle_, *(resp.mutable_session_handle()));
current_graph_version_ = resp.graph_version();
}
return s;
}
Status GrpcSession::Extend(const GraphDef& graph) {
if (handle_.empty()) {
// Session was uninitialized, so simply initialize the session with 'graph'.
return Create(graph);
}
mutex_lock l(mu_);
ExtendSessionRequest req;
req.set_session_handle(handle_);
*req.mutable_graph_def() = graph;
req.set_current_graph_version(current_graph_version_);
ExtendSessionResponse resp;
Status s = master_->ExtendSession(&req, &resp);
if (s.ok()) {
current_graph_version_ = resp.new_graph_version();
}
return s;
}
Status GrpcSession::Run(const std::vector<std::pair<string, Tensor>>& inputs,
const std::vector<string>& output_names,
const std::vector<string>& target_nodes,
std::vector<Tensor>* outputs) {
// Convert to proto
RunStepRequest req;
RunStepResponse resp;
for (const auto& it : inputs) {
Tensor input_tensor = it.second;
auto feed = req.add_feed();
feed->set_name(it.first);
TensorProto* proto = feed->mutable_tensor();
input_tensor.AsProtoTensorContent(proto);
}
// Build an index from fetch tensor name to offset.
std::unordered_map<string, int> output_name_to_offset;
for (const string& output_name : output_names) {
req.add_fetch(output_name);
output_name_to_offset.insert(
std::make_pair(output_name, output_name_to_offset.size()));
}
for (const string& target : target_nodes) {
req.add_target(target);
}
TF_RETURN_IF_ERROR(RunProto(&req, &resp));
if (!output_names.empty()) {
outputs->resize(output_names.size());
}
// Convert response back to Tensors in the correct order.
for (const NamedTensorProto& tensor : resp.tensor()) {
auto fetch_it = output_name_to_offset.find(tensor.name());
if (fetch_it == output_name_to_offset.end()) {
return errors::Internal("Received response for unrequested fetch: ",
tensor.name());
}
Tensor output;
if (!output.FromProto(tensor.tensor())) {
return errors::InvalidArgument("Could not parse returned proto for ",
tensor.name());
}
(*outputs)[fetch_it->second] = output;
}
return Status::OK();
}
Status GrpcSession::RunProto(RunStepRequest* req, RunStepResponse* resp) {
if (handle_.empty()) {
return errors::InvalidArgument("A session is not created yet....");
}
req->set_session_handle(handle_);
return master_->RunStep(req, resp);
}
Status GrpcSession::PRunSetup(const std::vector<string>& input_names,
const std::vector<string>& output_names,
const std::vector<string>& target_nodes,
string* handle) {
return errors::Internal("Partial run is not supported for remote session.");
}
Status GrpcSession::PRun(const string& handle,
const std::vector<std::pair<string, Tensor>>& inputs,
const std::vector<string>& output_names,
std::vector<Tensor>* outputs) {
return errors::Internal("Partial run is not supported for remote session.");
}
Status GrpcSession::Close() {
if (handle_.empty()) {
return errors::InvalidArgument("A session is not created yet....");
}
CloseSessionRequest req;
req.set_session_handle(handle_);
handle_.clear();
CloseSessionResponse resp;
return master_->CloseSession(&req, &resp);
}
std::vector<DeviceAttributes> GrpcSession::ListDevices() {
std::vector<DeviceAttributes> devices;
ListDevicesRequest req;
ListDevicesResponse resp;
Status s = master_->ListDevices(&req, &resp);
if (!s.ok()) {
LOG(ERROR) << "Could not list devices: " << s;
return devices;
}
for (const auto& device_attr : resp.local_device()) {
devices.push_back(device_attr);
}
for (const auto& device_attr : resp.remote_device()) {
devices.push_back(device_attr);
}
return devices;
}
class GrpcSessionFactory : public SessionFactory {
public:
bool AcceptsOptions(const SessionOptions& options) override {
return StringPiece(options.target).starts_with("grpc://");
}
Session* NewSession(const SessionOptions& options) override {
return new GrpcSession(options);
}
};
class GrpcSessionRegistrar {
public:
GrpcSessionRegistrar() {
SessionFactory::Register("GRPC_SESSION", new GrpcSessionFactory());
}
};
static GrpcSessionRegistrar registrar;
} // namespace tensorflow

View File

@ -0,0 +1,97 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SESSION_H_
#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SESSION_H_
#include <string>
#include <vector>
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/protobuf/master.pb.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
class MasterInterface;
// A Session instance lets the caller drive a TensorFlow graph
// computation on potentially remote sets of devices. This is a thin
// wrapper around tensorflow::grpc::MasterService.
//
// Multiple threads must synchronize their accesses to a single
// session.
class GrpcSession : public Session {
public:
// Do not use; just present for easier swig wrapping.
explicit GrpcSession(const SessionOptions& options);
~GrpcSession() override;
// Creates a session with the configured target. The session carries
// out the graph computation defined by "graph".
Status Create(const GraphDef& graph) override;
Status Run(const std::vector<std::pair<string, Tensor> >& inputs,
const std::vector<string>& output_names,
const std::vector<string>& target_nodes,
std::vector<Tensor>* outputs) override;
Status Extend(const GraphDef& graph) override;
Status Close() override;
// NOTE: This API is still experimental and may change.
::tensorflow::Status PRunSetup(const std::vector<string>& input_names,
const std::vector<string>& output_names,
const std::vector<string>& target_nodes,
string* handle) override;
// NOTE: This API is still experimental and may change.
::tensorflow::Status PRun(
const string& handle,
const std::vector<std::pair<string, Tensor> >& inputs,
const std::vector<string>& output_names,
std::vector<Tensor>* outputs) override;
std::vector<DeviceAttributes> ListDevices();
private:
SessionOptions options_;
std::unique_ptr<MasterInterface> master_;
mutex mu_;
// handle_ returned by the master to identify this session.
string handle_;
// The current version of the graph.
int64 current_graph_version_ GUARDED_BY(mu_);
Status RunProto(RunStepRequest* req, RunStepResponse* resp);
TF_DISALLOW_COPY_AND_ASSIGN(GrpcSession);
};
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SESSION_H_
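
A hedged client-side sketch; the target assumes a server like the one
above, and the graph and fetch name ("out:0") are placeholders.

#include <memory>
#include <vector>

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"

void RunRemoteGraph(const tensorflow::GraphDef& graph_def) {
  tensorflow::SessionOptions options;
  // The "grpc://" prefix routes NewSession() to the GrpcSession factory.
  options.target = "grpc://localhost:2222";
  std::unique_ptr<tensorflow::Session> session(
      tensorflow::NewSession(options));
  TF_CHECK_OK(session->Create(graph_def));
  std::vector<tensorflow::Tensor> outputs;
  TF_CHECK_OK(session->Run({}, {"out:0"}, {}, &outputs));
  TF_CHECK_OK(session->Close());
}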

View File

@ -0,0 +1,750 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/rpc/grpc_session.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_testlib.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/graph/default_device.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/testlib.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/master.pb.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/port.h"
namespace tensorflow {
static SessionOptions Devices(int num_cpus, int num_gpus) {
SessionOptions result;
(*result.config.mutable_device_count())["CPU"] = num_cpus;
(*result.config.mutable_device_count())["GPU"] = num_gpus;
return result;
}
void CreateGraphDef(GraphDef* graph_def, string node_names[3]) {
Graph graph(OpRegistry::Global());
Tensor a_tensor(DT_FLOAT, TensorShape({1, 2}));
test::FillValues<float>(&a_tensor, {1, 2});
Node* a = test::graph::Constant(&graph, a_tensor);
node_names[0] = a->name();
Tensor b_tensor(DT_FLOAT, TensorShape({2, 1}));
test::FillValues<float>(&b_tensor, {2, 1});
Node* b = test::graph::Constant(&graph, b_tensor);
node_names[1] = b->name();
Node* c = test::graph::Matmul(&graph, a, b, false, false);
node_names[2] = c->name();
test::graph::ToGraphDef(&graph, graph_def);
}
// Asserts that "val" is a single float tensor. The only float is
// "expected_val".
static void IsSingleFloatValue(const Tensor& val, float expected_val) {
ASSERT_EQ(val.dtype(), DT_FLOAT);
ASSERT_EQ(val.NumElements(), 1);
ASSERT_EQ(val.flat<float>()(0), expected_val);
}
static SessionOptions Options(const string& target, int placement_period) {
SessionOptions options;
// NOTE(mrry): GrpcSession requires a grpc:// scheme prefix in the target
// string.
options.target = strings::StrCat("grpc://", target);
options.config.set_placement_period(placement_period);
return options;
}
static Session* NewRemote(const SessionOptions& options) {
return CHECK_NOTNULL(NewSession(options));
}
TEST(GrpcSessionTest, BasicNonProtoAPI) {
GraphDef graph;
string node_names[3];
// c = a * b
CreateGraphDef(&graph, node_names);
std::unique_ptr<test::TestCluster> cluster;
TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));
std::unique_ptr<Session> session(
NewRemote(Options(cluster->targets()[0], 1)));
ASSERT_TRUE(session != nullptr);
for (int iters = 0; iters < 25; ++iters) {
TF_CHECK_OK(session->Create(graph));
{
std::vector<std::pair<string, Tensor>> inputs;
TF_CHECK_OK(session->Run(inputs, {}, {}, {}));
}
{
// Just run to target node
std::vector<std::pair<string, Tensor>> inputs;
std::vector<string> targets = {node_names[2]};
TF_CHECK_OK(session->Run(inputs, {}, targets, nullptr));
}
{
// Run to a target node and a real tensor
std::vector<std::pair<string, Tensor>> inputs;
std::vector<string> names = {node_names[2] + ":0"};
std::vector<string> targets = {node_names[1]};
std::vector<Tensor> outputs;
TF_CHECK_OK(session->Run(inputs, names, targets, &outputs));
ASSERT_TRUE(outputs[0].IsInitialized());
ASSERT_EQ(4.0, outputs[0].flat<float>()(0));
}
TF_CHECK_OK(session->Close());
}
}
TEST(GrpcSessionTest, BasicNonProtoAPIConsistentOrder) {
GraphDef graph;
string node_names[3];
// c = a * b
CreateGraphDef(&graph, node_names);
std::unique_ptr<test::TestCluster> cluster;
TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));
std::unique_ptr<Session> session(
NewRemote(Options(cluster->targets()[0], 1)));
ASSERT_TRUE(session != nullptr);
ASSERT_TRUE(session->Create(graph).ok());
// Test that the order of the output names matches the order of the
// returned Tensors.
std::vector<std::pair<string, Tensor>> inputs;
std::vector<string> names = {node_names[2] + ":0", node_names[0] + ":0",
node_names[1] + ":0"};
std::vector<string> target_ops = {node_names[1]};
std::vector<Tensor> outputs;
ASSERT_TRUE(session->Run(inputs, names, target_ops, &outputs).ok());
ASSERT_TRUE(outputs[0].IsInitialized());
ASSERT_EQ(4.0, outputs[0].flat<float>()(0));
ASSERT_TRUE(outputs[1].IsInitialized());
ASSERT_EQ(1.0, outputs[1].flat<float>()(0));
ASSERT_TRUE(outputs[2].IsInitialized());
ASSERT_EQ(2.0, outputs[2].flat<float>()(0));
ASSERT_TRUE(session->Close().ok());
}
TEST(GrpcSessionTest, NonLocalWithFilters) {
GraphDef graph;
string node_names[3];
// c = a * b
CreateGraphDef(&graph, node_names);
std::unique_ptr<test::TestCluster> cluster;
TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));
SessionOptions options;
options.target = strings::StrCat("grpc://", cluster->targets()[0]);
options.config.add_device_filters(cluster->devices()[0].name());
std::unique_ptr<Session> session(NewRemote(options));
ASSERT_TRUE(session != nullptr);
{
GraphDef graph_copy(graph);
graph::SetDefaultDevice(cluster->devices()[0].name(), &graph_copy);
TF_CHECK_OK(session->Create(graph_copy));
TF_CHECK_OK(session->Run({}, {}, {}, nullptr));
TF_CHECK_OK(session->Close());
}
{
GraphDef graph_copy(graph);
graph::SetDefaultDevice(cluster->devices()[1].name(), &graph_copy);
TF_CHECK_OK(session->Create(graph_copy));
auto status = session->Run({}, {}, {}, nullptr);
EXPECT_EQ(tensorflow::error::INVALID_ARGUMENT, status.code());
TF_CHECK_OK(session->Close());
}
}
// A = [3 2; -1 0]; x = rand(2, 1); We want to compute the largest
// eigenvalue for A, which is 2.0. Iteratively, we do
// repeat x = y / y.norm(); y = A * x; end
// At the end, we expect "lambda" to converge to 2.0.
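// (For A = [3 2; -1 0], det(A - lambda*I) = lambda^2 - 3*lambda + 2 =
// (lambda - 1)(lambda - 2), so the eigenvalues are 1 and 2 and the power
// iteration converges to the dominant eigenvalue 2.0.)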
void FindMaxEigen(const string& target) {
Graph graph(OpRegistry::Global());
Tensor a_tensor(DT_FLOAT, TensorShape({2, 2}));
// Store rows [3, 2] and [-1, 0] in row major format.
test::FillValues<float>(&a_tensor, {3, 2, -1, 0});
Node* a = test::graph::Constant(&graph, a_tensor);
// x is from the feed.
Tensor x_tensor(DT_FLOAT, TensorShape({2, 1}));
test::FillValues<float>(&x_tensor, {0, 0});
Node* x = test::graph::Constant(&graph, x_tensor);
// y = A * x
Node* y = test::graph::Matmul(&graph, a, x, false, false);
// y2 = y.^2
Node* y2 = test::graph::Unary(&graph, "Square", y);
// const tensor for reduction
Tensor rdim_tensor(DT_INT32, TensorShape({}));
rdim_tensor.scalar<int32>()() = 0;
Node* rdim = test::graph::Constant(&graph, rdim_tensor);
// y2_sum = sum(y2)
Node* y2_sum = test::graph::Reduce(&graph, "Sum", y2, rdim);
// y_norm = sqrt(y2_sum)
Node* y_norm = test::graph::Unary(&graph, "Sqrt", y2_sum);
// y_normalized = y ./ y_norm
Node* y_normalized = test::graph::Binary(&graph, "Div", y, y_norm);
GraphDef def;
test::graph::ToGraphDef(&graph, &def);
std::unique_ptr<Session> session(NewRemote(Options(target, 1)));
ASSERT_TRUE(session != nullptr);
TF_CHECK_OK(session->Create(def));
// Setup feeds and fetches.
float lambda;
Tensor feed_value(DT_FLOAT, TensorShape({2, 1}));
feed_value.matrix<float>()(0, 0) = -3.1415;
feed_value.matrix<float>()(1, 0) = +2.7183;
for (int i = 0; i < 25; ++i) {
std::vector<Tensor> outputs;
TF_CHECK_OK(session->Run({{x->name(), feed_value}},
{y->name(), y_normalized->name()}, {}, &outputs));
const Tensor& y = outputs[0];
const Tensor& y_normalized = outputs[1];
// Print out lambda, x, and y.
CHECK_EQ(2, feed_value.NumElements());
CHECK_EQ(2, y.NumElements());
lambda = y.flat<float>()(0) / feed_value.flat<float>()(0);
printf("%06d lambda = %8.6f x = [%8.6f %8.6f] y = [%8.6f %8.6f]\n", i,
lambda, feed_value.flat<float>()(0), feed_value.flat<float>()(1),
y.flat<float>()(0), y.flat<float>()(1));
    // Use y_normalized as the feed value for x in the next iteration.
feed_value = y_normalized;
}
EXPECT_NEAR(2.0, lambda, 1e-6);
}
TEST(FindMaxEigenTest, RemoteDevice) {
std::unique_ptr<test::TestCluster> cluster;
test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster);
FindMaxEigen(cluster->targets()[0]);
}
void SetDevice(GraphDef* graph, const string& name, const string& dev) {
for (int i = 0; i < graph->node_size(); ++i) {
if (graph->node(i).name() == name) {
graph->mutable_node(i)->set_device(dev);
return;
}
}
LOG(FATAL) << "Name '" << name << "' not found.";
}
TEST(GrpcSessionTest, MultiDevices) {
std::unique_ptr<test::TestCluster> cluster;
TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));
Graph graph(OpRegistry::Global());
const int kSize = 1048576;
// c = a * b = 2 * 3 * kSize
Tensor a_tensor(DT_FLOAT, TensorShape({1, kSize}));
Tensor b_tensor(DT_FLOAT, TensorShape({kSize, 1}));
for (int i = 0; i < kSize; ++i) {
a_tensor.flat<float>()(i) = 2;
b_tensor.flat<float>()(i) = 3;
}
Node* a = test::graph::Constant(&graph, a_tensor);
Node* b = test::graph::Constant(&graph, b_tensor);
Node* c = test::graph::Matmul(&graph, a, b, false, false);
GraphDef def;
test::graph::ToGraphDef(&graph, &def);
// In this test, we force each node (a, b, c) on every possible device.
// We test all possible cases.
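  // With the two single-CPU test processes used here there are two
  // devices, so the loops below cover 2^3 = 8 placements of (a, b, c).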
for (const auto& a_dev : cluster->devices()) {
for (const auto& b_dev : cluster->devices()) {
for (const auto& c_dev : cluster->devices()) {
LOG(INFO) << "a: " << a_dev.name() << " b: " << b_dev.name()
<< " c: " << c_dev.name();
SetDevice(&def, a->name(), a_dev.name());
SetDevice(&def, b->name(), b_dev.name());
SetDevice(&def, c->name(), c_dev.name());
std::unique_ptr<Session> session(
NewRemote(Options(cluster->targets()[0], 1000)));
ASSERT_TRUE(session != nullptr);
TF_CHECK_OK(session->Create(def));
{
std::vector<Tensor> outputs;
TF_CHECK_OK(session->Run({}, {c->name()}, {}, &outputs));
ASSERT_EQ(1, outputs.size());
IsSingleFloatValue(outputs[0], 6.0 * kSize);
}
TF_CHECK_OK(session->Close());
}
}
}
}
TEST(GrpcSessionTest, MultiDevices_String) {
std::unique_ptr<test::TestCluster> cluster;
TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 1), 2, &cluster));
std::unique_ptr<Session> session(
NewRemote(Options(cluster->targets()[0], 1000)));
ASSERT_TRUE(session != nullptr);
// b = a
Graph graph(OpRegistry::Global());
Tensor a_tensor(DT_STRING, TensorShape({2, 2}));
for (int i = 0; i < 4; ++i) {
a_tensor.flat<string>()(i) = "hello, world";
}
Node* a = test::graph::Constant(&graph, a_tensor);
Node* b = test::graph::Identity(&graph, a);
GraphDef def;
test::graph::ToGraphDef(&graph, &def);
// In this test, we force each node (a, b) on every possible device.
// We test all possible cases.
for (const auto& a_dev : cluster->devices()) {
for (const auto& b_dev : cluster->devices()) {
LOG(INFO) << "a: " << a_dev.name() << " b: " << b_dev.name();
SetDevice(&def, a->name(), a_dev.name());
SetDevice(&def, b->name(), b_dev.name());
TF_CHECK_OK(session->Create(def));
{
std::vector<Tensor> outputs;
Status s = session->Run({}, {b->name()}, {}, &outputs);
if (s.ok()) {
ASSERT_EQ(1, outputs.size());
ASSERT_EQ(outputs[0].dtype(), DT_STRING);
ASSERT_EQ(outputs[0].NumElements(), 4);
for (int i = 0; i < outputs[0].NumElements(); ++i) {
EXPECT_EQ(outputs[0].flat<string>()(i), "hello, world");
}
} else {
LOG(ERROR) << "Error: " << s;
ASSERT_TRUE((a_dev.device_type() == DEVICE_GPU) ||
(b_dev.device_type() == DEVICE_GPU));
ASSERT_FALSE(s.ok());
}
}
TF_CHECK_OK(session->Close());
}
}
}
TEST(GrpcSessionTest, SendRecv_Node_Naming) {
std::unique_ptr<test::TestCluster> cluster;
TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 3, &cluster));
std::unique_ptr<Session> session(
NewRemote(Options(cluster->targets()[0], 1)));
ASSERT_TRUE(session != nullptr);
// This test case needs at least 3 devices.
CHECK_GE(cluster->devices().size(), 3);
const DeviceAttributes& src = cluster->devices()[0];
const DeviceAttributes& dst0 = cluster->devices()[1];
const DeviceAttributes& dst1 = cluster->devices()[2];
LOG(INFO) << "src = " << src.name() << " dst0 = " << dst0.name()
<< " dst1 = " << dst1.name();
// Within the same session, we compute two subgraphs:
// 1) a on 'src' sends to b on 'dst0';
// 2) a on 'src' sends to c on 'dst1'.
Graph graph(OpRegistry::Global());
Tensor a_tensor(DT_FLOAT, TensorShape({1, 1}));
a_tensor.flat<float>()(0) = 100;
Node* a = test::graph::Constant(&graph, a_tensor);
Node* b = test::graph::Identity(&graph, a);
Node* c = test::graph::Identity(&graph, a);
GraphDef def;
test::graph::ToGraphDef(&graph, &def);
  // The base graph has a, b, and c assigned to devices explicitly.
SetDevice(&def, a->name(), src.name());
SetDevice(&def, b->name(), dst0.name());
SetDevice(&def, c->name(), dst1.name());
TF_CHECK_OK(session->Create(def));
// Run subgraph a -> b, and fetch b.
{
std::vector<Tensor> outputs;
TF_CHECK_OK(session->Run({}, {b->name()}, {}, &outputs));
ASSERT_EQ(1, outputs.size());
IsSingleFloatValue(outputs[0], 100);
}
// Run subgraph a -> c, and fetch c.
{
std::vector<Tensor> outputs;
TF_CHECK_OK(session->Run({}, {c->name()}, {}, &outputs));
ASSERT_EQ(1, outputs.size());
IsSingleFloatValue(outputs[0], 100);
}
TF_CHECK_OK(session->Close());
}
TEST(GrpcSessionTest, Error) {
std::unique_ptr<test::TestCluster> cluster;
TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));
const string& master = cluster->targets()[0];
const string& dev_a = cluster->devices()[0].name();
const string& dev_b = cluster->devices()[1].name();
LOG(INFO) << "master " << master << "dev_a " << dev_a << "dev_b " << dev_b;
GraphDef gdef;
std::vector<string> fetches;
{
Graph g(OpRegistry::Global());
// a2 = a + error(a)
//
// Subgraph for "a" fails. The master will cancel the subgraph for
// "b" and then returns the Session::Run.
auto a = test::graph::Constant(&g, Tensor());
a->set_assigned_device_name(dev_a);
auto a_err = test::graph::Error(&g, a, "fantasia!");
a_err->set_assigned_device_name(dev_a);
auto a2 = test::graph::Add(&g, a, a_err);
a2->set_assigned_device_name(dev_a);
fetches.push_back(a2->name());
// b2 = b + delay(b)
//
// Subgraph for "b" sleeps at the node "b_delay". When the sleep
// finishes, the subgraph "b" will continue execution till it
    // notices that it is cancelled. Meanwhile, the subgraph's executor
// and its related state (registered ops) should still be alive.
auto b = test::graph::Constant(&g, Tensor());
b->set_assigned_device_name(dev_b);
auto b_delay = test::graph::Delay(&g, b, Microseconds(1000000));
b_delay->set_assigned_device_name(dev_b);
auto b2 = test::graph::Add(&g, b, b_delay);
b2->set_assigned_device_name(dev_b);
fetches.push_back(b2->name());
test::graph::ToGraphDef(&g, &gdef);
}
std::unique_ptr<Session> session(NewRemote(Options(master, 1)));
ASSERT_TRUE(session != nullptr);
TF_CHECK_OK(session->Create(gdef));
{
Status status = session->Run({}, fetches, {}, nullptr);
EXPECT_FALSE(status.ok());
EXPECT_NE(status.ToString().find("fantasia!"), string::npos);
}
// session->Close() shall clean up all states related to the session->
// E.g., deregisters subgraph with workers, etc.
TF_CHECK_OK(session->Close());
  // Sleep a bit so that most of the asynchronous work finishes before
  // the test process finishes.
Env::Default()->SleepForMicroseconds(2000000);
}
TEST(SessionTest, SharedVar) {
std::unique_ptr<test::TestCluster> cluster;
TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 1, &cluster));
const string master = cluster->targets()[0];
CHECK_EQ(cluster->devices().size(), 1);
GraphDef gdef;
string init_name;
string inc_name;
string get_name;
{
Graph g(OpRegistry::Global());
Tensor one(DT_FLOAT, TensorShape({}));
one.scalar<float>()() = 1.0;
Node* var = test::graph::Var(&g, DT_FLOAT, one.shape());
Node* init = test::graph::Assign(&g, var, test::graph::Constant(&g, one));
init_name = init->name();
Node* update = test::graph::Assign(
&g, var, test::graph::Add(&g, var, test::graph::Constant(&g, one)));
inc_name = update->name();
get_name = var->name();
test::graph::ToGraphDef(&g, &gdef);
}
// Init a variable
{
Session* sess = NewRemote(Options(master, 1));
TF_CHECK_OK(sess->Create(gdef));
std::vector<std::pair<string, Tensor>> inp;
TF_CHECK_OK(sess->Run(inp, {}, {init_name}, nullptr));
TF_CHECK_OK(sess->Close());
delete sess;
}
for (int rep = 1; rep < 10; ++rep) {
// Update a variable
{
Session* sess = NewRemote(Options(master, 1));
TF_CHECK_OK(sess->Create(gdef));
std::vector<std::pair<string, Tensor>> inp;
TF_CHECK_OK(sess->Run(inp, {}, {inc_name}, nullptr));
TF_CHECK_OK(sess->Close());
delete sess;
}
// Gets the variable's value.
{
Session* sess = NewRemote(Options(master, 1));
TF_CHECK_OK(sess->Create(gdef));
std::vector<std::pair<string, Tensor>> inp;
std::vector<Tensor> ret;
TF_CHECK_OK(sess->Run(inp, {get_name}, {}, &ret));
ASSERT_EQ(ret.size(), 1);
EXPECT_EQ(ret[0].scalar<float>()(), 1.0 * (1 + rep));
TF_CHECK_OK(sess->Close());
delete sess;
}
}
}
void CreateInvalidGraph(const string& graph_def_ascii,
const string& error_substring) {
GraphDef graph;
CHECK(protobuf::TextFormat::ParseFromString(graph_def_ascii, &graph));
std::unique_ptr<test::TestCluster> cluster;
TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));
std::unique_ptr<Session> session(
NewRemote(Options(cluster->targets()[0], 1)));
Status s = session->Create(graph);
ASSERT_FALSE(s.ok());
EXPECT_NE(s.error_message().find(error_substring), string::npos);
}
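// The following cases are rejected because an op name may not contain
// ':' and may not start with '_' (names with a leading underscore are
// reserved for internal use).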
TEST(SessionTest, InvalidOpName) {
CreateInvalidGraph(R"(
node {
name: 'a:b' op: 'Const'
attr { key: 'dtype' value { type: DT_FLOAT } }
attr { key: 'value' value {
tensor { dtype: DT_FLOAT tensor_shape { dim [{size:1}, {size:1}] }
float_val: [100] }
} }
}
)",
"Illegal op name");
CreateInvalidGraph(R"(
node {
name: 'a:0' op: 'Const'
attr { key: 'dtype' value { type: DT_FLOAT } }
attr { key: 'value' value {
tensor { dtype: DT_FLOAT tensor_shape { dim [{size:1}, {size:1}] }
float_val: [100] }
} }
}
)",
"Illegal op name");
CreateInvalidGraph(R"(
node {
name: '_a' op: 'Const'
attr { key: 'dtype' value { type: DT_FLOAT } }
attr { key: 'value' value {
tensor { dtype: DT_FLOAT tensor_shape { dim [{size:1}, {size:1}] }
float_val: [100] }
} }
}
)",
"Illegal op name");
}
TEST(SessionTest, InvalidOpInputName) {
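  // An op input must be of the form "node" or "node:<port>", where
  // <port> is a decimal integer without a leading zero, and may not
  // reference a node whose name starts with '_'.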
CreateInvalidGraph(R"(
node {
name: 'a' op: 'const'
attr { key: 'dtype' value { type: DT_FLOAT } }
attr { key: 'value' value {
tensor { dtype: DT_FLOAT tensor_shape { dim [{size:1}, {size:1}] }
float_val: [100] }
} }
}
node {
name:'b' op:'MatMul' input:'a:first' input:'a'
attr { key: 'T' value { type: DT_FLOAT } }
attr { key: 'transpose_a' value { b: false } }
attr { key: 'transpose_b' value { b: false } }
attr { key: '_kernel' value { s: 'eigen' } }
}
)",
"Illegal op input name");
CreateInvalidGraph(R"(
node {
name: 'a' op: 'const'
attr { key: 'dtype' value { type: DT_FLOAT } }
attr { key: 'value' value {
tensor { dtype: DT_FLOAT tensor_shape { dim [{size:1}, {size:1}] }
float_val: [100] }
} }
}
node {
name:'b' op:'MatMul' input:'_a' input:'a'
attr { key: 'T' value { type: DT_FLOAT } }
attr { key: 'transpose_a' value { b: false } }
attr { key: 'transpose_b' value { b: false } }
attr { key: '_kernel' value { s: 'eigen' } }
}
)",
"Illegal op input name");
CreateInvalidGraph(R"(
node {
name: 'a' op: 'const'
attr { key: 'dtype' value { type: DT_FLOAT } }
attr { key: 'value' value {
tensor { dtype: DT_FLOAT tensor_shape { dim [{size:1}, {size:1}] }
float_val: [100] }
} }
}
node {
name:'b' op:'MatMul' input:'_a:0' input:'a'
attr { key: 'T' value { type: DT_FLOAT } }
attr { key: 'transpose_a' value { b: false } }
attr { key: 'transpose_b' value { b: false } }
attr { key: '_kernel' value { s: 'eigen' } }
}
)",
"Illegal op input name");
CreateInvalidGraph(R"(
node {
name: 'a' op: 'const'
attr { key: 'dtype' value { type: DT_FLOAT } }
attr { key: 'value' value {
tensor { dtype: DT_FLOAT tensor_shape { dim [{size:1}, {size:1}] }
float_val: [100] }
} }
}
node {
name:'b' op:'MatMul' input:'a:01' input:'a'
attr { key: 'T' value { type: DT_FLOAT } }
attr { key: 'transpose_a' value { b: false } }
attr { key: 'transpose_b' value { b: false } }
attr { key: '_kernel' value { s: 'eigen' } }
}
)",
"Illegal op input name");
}
TEST(SessionTest, ExtendValidation) {
GraphDef graph;
bool success = protobuf::TextFormat::ParseFromString(R"(
node {
name: 'a' op: 'Const'
attr { key: 'dtype' value { type: DT_FLOAT } }
attr { key: 'value' value {
tensor { dtype: DT_FLOAT tensor_shape { dim [{size:1}, {size:1}] }
float_val: [100] }
} }
}
)",
&graph);
// NOTE(mrry): CHECK not done inline to avoid a compilation error in
// open-source (due to a multi-line string in a macro argument).
ASSERT_TRUE(success);
std::unique_ptr<test::TestCluster> cluster;
TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));
std::unique_ptr<Session> session(
NewRemote(Options(cluster->targets()[0], 1)));
TF_CHECK_OK(session->Create(graph));
// 1. Fail with an unknown input name.
GraphDef extension;
success = protobuf::TextFormat::ParseFromString(R"(
node {
name:'b' op:'MatMul' input:'a:first' input:'a'
attr { key: 'T' value { type: DT_FLOAT } }
attr { key: 'transpose_a' value { b: false } }
attr { key: 'transpose_b' value { b: false } }
attr { key: '_kernel' value { s: 'eigen' } }
}
)",
&extension);
ASSERT_TRUE(success);
Status s = session->Extend(extension);
ASSERT_FALSE(s.ok());
EXPECT_NE(s.error_message().find("Illegal op input name"), string::npos);
// 2. Succeed with a valid node.
success = protobuf::TextFormat::ParseFromString(R"(
node {
name:'b' op:'MatMul' input:'a' input:'a'
attr { key: 'T' value { type: DT_FLOAT } }
attr { key: 'transpose_a' value { b: false } }
attr { key: 'transpose_b' value { b: false } }
attr { key: '_kernel' value { s: 'eigen' } }
}
)",
&extension);
ASSERT_TRUE(success);
TF_CHECK_OK(session->Extend(extension));
  // 3. Fail with a duplicate node.
success = protobuf::TextFormat::ParseFromString(R"(
node {
name:'b' op:'MatMul' input:'a' input:'a'
attr { key: 'T' value { type: DT_FLOAT } }
attr { key: 'transpose_a' value { b: false } }
attr { key: 'transpose_b' value { b: false } }
attr { key: '_kernel' value { s: 'eigen' } }
}
)",
&extension);
ASSERT_TRUE(success);
s = session->Extend(extension);
ASSERT_FALSE(s.ok());
EXPECT_NE(s.error_message().find("'b', which was created by a previous call"),
string::npos);
}
} // namespace tensorflow

View File

@ -0,0 +1,98 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <iostream>
#include "external/grpc/include/grpc++/grpc++.h"
#include "external/grpc/include/grpc++/security/credentials.h"
#include "external/grpc/include/grpc++/server_builder.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/util/command_line_flags.h"
// This binary starts a TensorFlow server (master and worker).
namespace tensorflow {
namespace {
Status ParseFlagsForTask(int argc, char* argv[], GrpcServerOptions* options) {
string cluster_spec;
const bool parse_result =
ParseFlags(&argc, argv, {Flag("cluster_spec", &cluster_spec), //
Flag("job_name", &options->job_name), //
Flag("task_id", &options->task_index)});
if (!parse_result) {
return errors::InvalidArgument("Error parsing command-line flags");
}
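  // For example, the flag
  //   --cluster_spec=worker|host0:2222;host1:2222,ps|host2:2222
  // (illustrative host:port values) defines a "worker" job with two
  // tasks and a "ps" job with one task.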
size_t my_num_tasks = 0;
for (const string& job_str : str_util::Split(cluster_spec, ',')) {
    // Split each entry in the flag into 2 pieces, separated by "|".
const std::vector<string> job_pieces = str_util::Split(job_str, '|');
CHECK_EQ(2, job_pieces.size()) << job_str;
const string& job = job_pieces[0];
    // job_str is of the form <job_name>|<host_ports>; job_pieces[1]
    // holds the ";"-separated list of host:port endpoints.
    const StringPiece spec = job_pieces[1];
const std::vector<string> host_ports = str_util::Split(spec, ';');
size_t num_tasks = host_ports.size();
if (job == options->job_name) {
my_num_tasks = num_tasks;
}
TF_RETURN_IF_ERROR(
options->channel_spec.AddHostPortsJob(job, host_ports, num_tasks));
LOG(INFO) << "Peer " << job << " " << num_tasks << " {"
<< str_util::Join(host_ports, ", ") << "}";
}
if (my_num_tasks == 0) {
return errors::InvalidArgument("Job name \"", options->job_name,
"\" does not appear in the cluster spec");
}
if (options->task_index >= my_num_tasks) {
return errors::InvalidArgument("Task index ", options->task_index,
" is invalid (job \"", options->job_name,
"\" contains ", my_num_tasks, " tasks");
}
return Status::OK();
}
} // namespace
} // namespace tensorflow
int main(int argc, char* argv[]) {
tensorflow::port::InitMain(argv[0], &argc, &argv);
tensorflow::GrpcServerOptions options;
tensorflow::Status s = tensorflow::ParseFlagsForTask(argc, argv, &options);
if (!s.ok()) {
std::cerr << "ERROR: " << s.error_message() << std::endl;
std::cerr << "Usage: " << argv[0]
<< " --cluster_spec=SPEC --job_name=NAME --task_id=ID"
<< std::endl;
std::cerr << "Where:" << std::endl;
std::cerr << " SPEC is <JOB>(,<JOB>)*" << std::endl;
std::cerr << " JOB is <NAME>|<HOST:PORT>(;<HOST:PORT>)*" << std::endl;
std::cerr << " NAME is a valid job name ([a-z][0-9a-z]*)" << std::endl;
std::cerr << " HOST is a hostname or IP address" << std::endl;
std::cerr << " PORT is a port number" << std::endl;
return -1;
}
tensorflow::StartTensorFlowServer(options);
}

View File

@ -0,0 +1,123 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "external/grpc/include/grpc++/grpc++.h"
#include "external/grpc/include/grpc++/security/credentials.h"
#include "external/grpc/include/grpc++/server_builder.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/distributed_runtime/graph_mgr.h"
#include "tensorflow/core/distributed_runtime/master_env.h"
#include "tensorflow/core/distributed_runtime/master_session.h"
#include "tensorflow/core/distributed_runtime/process_util.h"
#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_master_service.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/util/command_line_flags.h"
// This library starts a TensorFlow server (master and worker).
namespace tensorflow {
struct GrpcTaskOptions {
// This process belongs to the "job_name".
string job_name;
// This process is the task-th task within the replica. 0th, 1st,
// 2nd, etc.
int32 task = 0;
// Specification of peers.
GrpcChannelSpec channel_spec;
SessionOptions default_session_options;
};
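// For example, task 1 of a two-task "worker" job could be described by
// setting job_name = "worker", task = 1, and adding the job to the
// channel spec (illustrative host:port values, checking the returned
// Status):
//   channel_spec.AddHostPortsJob("worker", {"host0:2222", "host1:2222"}, 2);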
Status StartTensorFlowServer(const TaskOptions& task_options) {
thread::ThreadPool* thread_pool =
new thread::ThreadPool(Env::Default(), "server", 1);
thread_pool->Schedule([argc, argv, task_options]() {
// This process provides both the worker service and the master
// service. We let these two services share the same channel cache
// (rpc connections) and cpu devices (used by the master as the
// client device). These client devices require a worker service
// so that remote devices can copy the feeds from the client
// device in the master.
tensorflow::MasterEnv master_env;
    string name_prefix =
        strings::StrCat("/job:", task_options.job_name, "/replica:0",
                        "/task:", task_options.task);
DeviceFactory::AddDevices(task_options.default_session_options, name_prefix,
&master_env.local_devices);
// Create the DeviceMgr before initializing the RPC layer, because that
// needs to know how many devices of each kind exist.
WorkerEnv worker_env;
worker_env.device_mgr = new DeviceMgr(master_env.local_devices);
// Finish setting up Env for Worker service.
string donotcare;
CHECK(DeviceNameUtils::SplitDeviceName(master_env.local_devices[0]->name(),
&worker_env.worker_name,
&donotcare));
worker_env.env = Env::Default();
GrpcChannelCache* channel_cache =
NewGrpcChannelCache(task_options.channel_spec);
string server_address = channel_cache->TranslateTask(name_prefix);
worker_env.worker_cache = NewGrpcWorkerCache(channel_cache);
worker_env.graph_mgr = new GraphMgr(&worker_env);
worker_env.rendezvous_mgr = new RpcRendezvousMgr(&worker_env);
worker_env.compute_pool = ComputePool(task_options.default_session_options);
// Finish setting up Env for Master service.
master_env.env = Env::Default();
master_env.ops = OpRegistry::Global();
master_env.worker_cache = worker_env.worker_cache;
master_env.master_session_factory = internal::NewMasterSession;
::grpc::ServerBuilder builder;
builder.AddListeningPort(server_address,
::grpc::InsecureServerCredentials());
auto master_service = NewGrpcMasterService(&master_env, &builder);
auto worker_service = NewGrpcWorkerService(&worker_env, &builder);
// Finally assemble the server.
    auto server = builder.BuildAndStart();
std::unique_ptr<Thread> master_thread(Env::Default()->StartThread(
ThreadOptions(), "master_service_thread",
[master_service]() { master_service->HandleRPCsLoop(); }));
std::unique_ptr<Thread> worker_thread(Env::Default()->StartThread(
ThreadOptions(), "worker_service_thread",
[worker_service]() { worker_service->HandleRPCsLoop(); }));
});
// The ThreadPool destructor waits until all work is done (i.e. forever).
delete thread_pool;
return Status::OK();
}
} // end namespace tensorflow

View File

@ -0,0 +1,84 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/rpc/grpc_testlib.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_session.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
namespace test {
Status TestCluster::MakeTestCluster(const SessionOptions& options, int n,
std::unique_ptr<TestCluster>* out_cluster) {
CHECK_GE(n, 1);
std::unique_ptr<TestCluster> ret(new TestCluster);
ret->targets_.resize(n);
std::vector<int> port(n);
for (int i = 0; i < n; ++i) {
port[i] = testing::PickUnusedPortOrDie();
ret->targets_[i] = strings::StrCat("localhost:", port[i]);
}
const string tf_jobs = strings::StrCat("--tf_jobs=localhost|",
str_util::Join(ret->targets_, ";"));
int num_cpus = 1;
int num_gpus = 0;
auto iter = options.config.device_count().find("CPU");
if (iter != options.config.device_count().end()) {
num_cpus = iter->second;
}
iter = options.config.device_count().find("GPU");
if (iter != options.config.device_count().end()) {
num_gpus = iter->second;
}
for (int i = 0; i < n; ++i) {
const std::vector<string> argv(
{strings::StrCat(testing::TensorFlowSrcRoot(),
"/core/distributed_runtime/rpc/grpc_testlib_server"),
/* see grpc_testlib_server.cc for flags */
tf_jobs, "--tf_job=localhost", strings::StrCat("--tf_task=", i),
strings::StrCat("--num_cpus=", num_cpus),
strings::StrCat("--num_gpus=", num_gpus)});
ret->subprocesses_.emplace_back(testing::CreateSubProcess(argv));
bool success = ret->subprocesses_[i]->Start();
if (!success) {
return errors::Internal("Could not start subprocess");
}
}
SessionOptions options_copy(options);
options_copy.target = strings::StrCat("grpc://", ret->targets_[0]);
std::unique_ptr<GrpcSession> session(new GrpcSession(options_copy));
std::vector<DeviceAttributes> device_attributes;
ret->devices_ = session->ListDevices();
*out_cluster = std::move(ret);
return Status::OK();
}
TestCluster::~TestCluster() {
for (auto& subprocess : subprocesses_) {
subprocess->Kill(9);
}
}
} // end namespace test
} // end namespace tensorflow

View File

@ -0,0 +1,73 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_TESTLIB_H_
#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_TESTLIB_H_
#include <memory>
#include <string>
#include <vector>
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
class Device;
namespace test {
// Provides a handle to a set of TensorFlow servers (masters and
// workers) for testing purposes.
//
// This class currently runs the servers in separate processes; the
// lifetime of this object is coterminous with the lifetimes of those
// processes.
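//
// A sketch of typical usage (mirroring grpc_session_test.cc):
//
//   std::unique_ptr<TestCluster> cluster;
//   TF_CHECK_OK(TestCluster::MakeTestCluster(SessionOptions(), 2, &cluster));
//   // Connect a GrpcSession to strings::StrCat("grpc://",
//   // cluster->targets()[0]) and run graphs against cluster->devices().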
class TestCluster {
public:
// Creates a new test cluster based on the given `options` (which
// configure the number of devices of each type) and a count of
// processes `n`. On success, the test cluster is stored in
// *out_cluster, and this function returns OK. Otherwise an error is
// returned.
static Status MakeTestCluster(const SessionOptions& options, int n,
std::unique_ptr<TestCluster>* out_cluster);
~TestCluster();
// Returns a vector of string "<hostname>:<port>" pairs that may be
// used as targets to construct a GrpcSession.
const std::vector<string>& targets() const { return targets_; }
// Returns a vector of devices available in this test cluster.
const std::vector<DeviceAttributes>& devices() const { return devices_; }
private:
TestCluster() = default;
std::vector<std::unique_ptr<testing::SubProcess>> subprocesses_;
std::vector<string> targets_;
std::vector<DeviceAttributes> devices_;
TF_DISALLOW_COPY_AND_ASSIGN(TestCluster);
};
} // end namespace test
} // end namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_TESTLIB_H_

View File

@ -0,0 +1,91 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/macros.h"
namespace tensorflow {
namespace test {
// ErrorOp::Compute returns an error.
REGISTER_OP("Error")
.Input("in: T")
.Output("out: T")
.Attr("T: type")
.Attr("message: string");
class ErrorOp : public OpKernel {
public:
explicit ErrorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("message", &errmsg_));
}
void Compute(OpKernelContext* ctx) override {
ctx->SetStatus(errors::Internal(errmsg_));
}
private:
string errmsg_;
};
REGISTER_KERNEL_BUILDER(Name("Error").Device(DEVICE_CPU), ErrorOp);
REGISTER_OP("InvalidRefType")
.Output("out: Ref(TIn)")
.Attr("TIn: type")
.Attr("TOut: type");
class InvalidRefType : public OpKernel {
public:
explicit InvalidRefType(OpKernelConstruction* ctx) : OpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("TOut", &dtout_));
output_ = Tensor(dtout_, TensorShape({}));
}
void Compute(OpKernelContext* ctx) override {
ctx->set_output_ref(0, &mu_, &output_);
}
private:
DataType dtout_;
mutex mu_;
Tensor output_;
};
REGISTER_KERNEL_BUILDER(Name("InvalidRefType").Device(DEVICE_CPU),
InvalidRefType);
// DelayOp::AsyncCompute sleeps for "micros"-econd and then returns
// its input.
REGISTER_OP("Delay")
.Input("in: T")
.Output("out: T")
.Attr("T: type")
.Attr("micros: int");
class DelayOp : public AsyncOpKernel {
public:
explicit DelayOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("micros", &micros_));
}
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
ctx->set_output(0, ctx->input(0));
ctx->env()->SchedClosureAfter(micros_, done);
}
private:
int64 micros_;
};
REGISTER_KERNEL_BUILDER(Name("Delay").Device(DEVICE_CPU), DelayOp);
} // namespace test
} // namespace tensorflow

View File

@ -0,0 +1,92 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "external/grpc/include/grpc++/grpc++.h"
#include "external/grpc/include/grpc++/security/credentials.h"
#include "external/grpc/include/grpc++/server_builder.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/util/command_line_flags.h"
// This binary starts a TensorFlow server (master and worker) for test purposes.
namespace tensorflow {
namespace {
Status ParseFlagsForTask(int argc, char* argv[], GrpcServerOptions* options) {
string job_spec;
int num_cpus = 1;
int num_gpus = 0;
const bool parse_result =
ParseFlags(&argc, argv, {Flag("tf_jobs", &job_spec), //
Flag("tf_job", &options->job_name), //
Flag("tf_task", &options->task_index), //
Flag("num_cpus", &num_cpus), //
Flag("num_gpus", &num_gpus)});
if (!parse_result) {
return errors::InvalidArgument("Error parsing command-line flags");
}
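  // For example, grpc_testlib.cc launches each test server with flags
  // like (illustrative values):
  //   --tf_jobs=localhost|localhost:<port0>;localhost:<port1>
  //   --tf_job=localhost --tf_task=0 --num_cpus=1 --num_gpus=0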
uint32 my_tasks_per_replica = 0;
for (const string& job_str : str_util::Split(job_spec, ',')) {
    // Split each entry in the flag into 2 pieces, separated by "|".
const std::vector<string> job_pieces = str_util::Split(job_str, '|');
CHECK_EQ(2, job_pieces.size()) << job_str;
const string& job = job_pieces[0];
    // job_str is of the form <job_name>|<host_ports>; job_pieces[1]
    // holds the ";"-separated list of host:port endpoints.
    const StringPiece spec = job_pieces[1];
const std::vector<string> host_ports = str_util::Split(spec, ';');
uint32 tasks_per_replica = host_ports.size();
if (job == options->job_name) {
my_tasks_per_replica = tasks_per_replica;
}
TF_RETURN_IF_ERROR(options->channel_spec.AddHostPortsJob(
job, host_ports, tasks_per_replica));
LOG(INFO) << "Peer " << job << " " << tasks_per_replica << " {"
<< str_util::Join(host_ports, ", ") << "}";
}
if (my_tasks_per_replica == 0) {
return errors::InvalidArgument("Invalid job specification");
}
(*options->default_session_options.config.mutable_device_count())["CPU"] =
num_cpus;
(*options->default_session_options.config.mutable_device_count())["GPU"] =
num_gpus;
return Status::OK();
}
} // namespace
} // namespace tensorflow
int main(int argc, char* argv[]) {
tensorflow::port::InitMain(argv[0], &argc, &argv);
tensorflow::GrpcServerOptions options;
tensorflow::Status s = tensorflow::ParseFlagsForTask(argc, argv, &options);
if (!s.ok()) {
LOG(ERROR) << "Could not parse flags: " << s.error_message();
return -1;
}
tensorflow::StartTensorFlowServer(options);
// NOTE(mrry): Unreachable code.
return 0;
}

View File

@ -0,0 +1,48 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_UTIL_H_
#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_UTIL_H_
#include <memory>
#include "external/grpc/include/grpc++/grpc++.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
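// NOTE: The static_casts below assume that tensorflow::error::Code and
// ::grpc::StatusCode use the same numeric values; both follow the
// canonical Google RPC error codes (OK = 0, CANCELLED = 1, ...).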
inline Status FromGrpcStatus(const ::grpc::Status& s) {
if (s.ok()) {
return Status::OK();
} else {
return Status(static_cast<tensorflow::error::Code>(s.error_code()),
s.error_message());
}
}
inline ::grpc::Status ToGrpcStatus(const ::tensorflow::Status& s) {
if (s.ok()) {
return ::grpc::Status::OK;
} else {
return ::grpc::Status(static_cast<::grpc::StatusCode>(s.code()),
s.error_message());
}
}
typedef std::shared_ptr<::grpc::Channel> SharedGrpcChannelPtr;
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_UTIL_H_

View File

@ -0,0 +1,85 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h"
#include "tensorflow/core/distributed_runtime/worker_cache_logger.h"
#include "tensorflow/core/distributed_runtime/worker_cache_partial.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
#include "tensorflow/core/platform/env.h"
namespace tensorflow {
class GrpcWorkerCache : public WorkerCachePartial {
public:
explicit GrpcWorkerCache(GrpcChannelCache* channel_cache)
: channel_cache_(channel_cache) {
// TODO(mrry): Investigate possible performance improvements by
// replacing this thread with a threadpool.
polling_thread_ = Env::Default()->StartThread(
ThreadOptions(), "grpc_worker_cache", [this]() {
void* tag;
bool ok;
while (completion_queue_.Next(&tag, &ok)) {
GrpcClientCQTag* callback_tag = static_cast<GrpcClientCQTag*>(tag);
callback_tag->OnCompleted(ok);
delete callback_tag;
}
});
}
// Explicit destructor to control destruction order.
~GrpcWorkerCache() override {
completion_queue_.Shutdown();
delete polling_thread_; // Blocks until thread exits.
delete channel_cache_;
}
void ListWorkers(std::vector<string>* workers) override {
channel_cache_->ListWorkers(workers);
}
WorkerInterface* CreateWorker(const string& target) override {
SharedGrpcChannelPtr channel = channel_cache_->FindWorkerChannel(target);
CHECK(channel) << "Channel was null";
if (!channel) return nullptr;
WorkerInterface* ret =
NewGrpcRemoteWorker(channel, &completion_queue_, &logger_);
return ret;
}
void SetLogging(bool v) override { logger_.SetLogging(v); }
void ClearLogs() override { logger_.ClearLogs(); }
bool RetrieveLogs(int64 step_id, StepStats* ss) override {
return logger_.RetrieveLogs(step_id, ss);
}
private:
GrpcChannelCache* channel_cache_; // Owned.
::grpc::CompletionQueue completion_queue_;
Thread* polling_thread_; // Owned.
WorkerCacheLogger logger_;
};
WorkerCacheInterface* NewGrpcWorkerCache(GrpcChannelCache* cc) {
return new GrpcWorkerCache(cc);
}
} // namespace tensorflow

View File

@ -0,0 +1,28 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_CACHE_H_
#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_CACHE_H_
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
namespace tensorflow {
// The returned WorkerCacheInterface object takes the ownership of "cc".
WorkerCacheInterface* NewGrpcWorkerCache(GrpcChannelCache* cc);
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_CACHE_H_

View File

@ -0,0 +1,415 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
#include <deque>
#include "external/grpc/include/grpc++/server_builder.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
#include "tensorflow/core/common_runtime/gpu/process_state.h"
#include "tensorflow/core/common_runtime/gpu_device_context.h"
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/distributed_runtime/graph_mgr.h"
#include "tensorflow/core/distributed_runtime/process_util.h"
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_call.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/protobuf/worker.pb.h"
#include "tensorflow/core/protobuf/worker_service.grpc.pb.h"
#include "tensorflow/core/protobuf/worker_service.pb.h"
namespace tensorflow {
namespace {
static Tensor empty_tensor(DT_FLOAT);
class GrpcWorkerService : public AsyncServiceInterface {
public:
GrpcWorkerService(WorkerEnv* env, ::grpc::ServerBuilder* builder)
: env_(env), cancellation_manager_(new CancellationManager) {
builder->RegisterService(&worker_service_);
cq_ = builder->AddCompletionQueue().release();
}
~GrpcWorkerService() { delete cq_; }
// This macro creates a new request for the given RPC method name
// (e.g., `ENQUEUE_REQUEST(GetStatus);`), and enqueues it on
// `this->cq_`.
//
// This macro is invoked one or more times for each RPC method to
// ensure that there are sufficient completion queue entries to
// handle incoming requests without blocking.
//
// The implementation of the request handler for each RPC method
// must ensure that it calls ENQUEUE_REQUEST() for that RPC method,
// to keep accepting new requests.
#define ENQUEUE_REQUEST(method) \
do { \
Call<GrpcWorkerService, grpc::WorkerService::AsyncService, \
method##Request, method##Response>:: \
EnqueueRequest(&worker_service_, cq_, \
&grpc::WorkerService::AsyncService::Request##method, \
&GrpcWorkerService::method##Handler); \
} while (0)
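// For example, ENQUEUE_REQUEST(GetStatus) expands (roughly) to:
//
//   Call<GrpcWorkerService, grpc::WorkerService::AsyncService,
//        GetStatusRequest, GetStatusResponse>::
//       EnqueueRequest(&worker_service_, cq_,
//                      &grpc::WorkerService::AsyncService::RequestGetStatus,
//                      &GrpcWorkerService::GetStatusHandler);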
// This method blocks forever handling requests from the completion queue.
void HandleRPCsLoop() {
// TODO(mrry): This may require performance engineering. We can
// add more threads to service the completion queue, and add more
// of various request types if they are short and frequent.
// Currently we allow unbounded numbers of pending calls for each
// method, by re-enqueuing a request before the previous one
// completes, and we may decide to bound some of the request
// types.
ENQUEUE_REQUEST(GetStatus);
ENQUEUE_REQUEST(CleanupAll);
ENQUEUE_REQUEST(RegisterGraph);
ENQUEUE_REQUEST(DeregisterGraph);
// TODO(mrry): Consider enqueuing more of these request types.
ENQUEUE_REQUEST(RecvTensor);
ENQUEUE_REQUEST(RunGraph);
ENQUEUE_REQUEST(CleanupGraph);
ENQUEUE_REQUEST(Logging);
ENQUEUE_REQUEST(Tracing);
void* tag;
bool ok;
while (cq_->Next(&tag, &ok)) {
UntypedCall<GrpcWorkerService>::Tag* callback_tag =
static_cast<UntypedCall<GrpcWorkerService>::Tag*>(tag);
callback_tag->OnCompleted(this, ok);
delete callback_tag;
}
}
private:
WorkerEnv* env_; // Not owned.
::grpc::ServerCompletionQueue* cq_; // Owned.
grpc::WorkerService::AsyncService worker_service_;
mutex mu_;
CancellationManager* cancellation_manager_ GUARDED_BY(mu_);
// The following section contains one request handler method per
  // RPC. The `FooHandler` method is called (indirectly) by
// `HandleRPCsLoop()` when the next Foo RPC is received. Each
// `FooHandler` call schedules a closure on `env_->compute_pool`,
// and is responsible for requesting the next Foo call by calling
// `ENQUEUE_REQUEST(Foo)`.
template <class RequestMessage, class ResponseMessage>
using WorkerCall = Call<GrpcWorkerService, grpc::WorkerService::AsyncService,
RequestMessage, ResponseMessage>;
void GetStatusHandler(WorkerCall<GetStatusRequest, GetStatusResponse>* call) {
env_->compute_pool->Schedule([this, call]() {
DeviceMgr* dm = env_->device_mgr;
std::vector<DeviceAttributes> devices;
dm->ListDeviceAttributes(&devices);
call->response.mutable_device_attributes()->Reserve(devices.size());
for (size_t i = 0; i < devices.size(); i++) {
call->response.add_device_attributes()->Swap(&devices[i]);
}
call->SendResponse(::grpc::Status::OK);
});
ENQUEUE_REQUEST(GetStatus);
}
void CleanupAllHandler(
WorkerCall<CleanupAllRequest, CleanupAllResponse>* call) {
env_->compute_pool->Schedule([this, call]() {
std::vector<string> containers;
for (const auto& c : call->request.container()) containers.push_back(c);
env_->device_mgr->ClearContainers(containers);
call->SendResponse(::grpc::Status::OK);
});
ENQUEUE_REQUEST(CleanupAll);
}
void RegisterGraphHandler(
WorkerCall<RegisterGraphRequest, RegisterGraphResponse>* call) {
env_->compute_pool->Schedule([this, call]() {
Status s = env_->graph_mgr->Register(
call->request.session_handle(), call->request.graph_def(),
call->request.graph_options(), call->response.mutable_graph_handle());
call->SendResponse(ToGrpcStatus(s));
});
ENQUEUE_REQUEST(RegisterGraph);
}
void DeregisterGraphHandler(
WorkerCall<DeregisterGraphRequest, DeregisterGraphResponse>* call) {
env_->compute_pool->Schedule([this, call]() {
Status s = env_->graph_mgr->Deregister(call->request.graph_handle());
call->SendResponse(ToGrpcStatus(s));
});
ENQUEUE_REQUEST(DeregisterGraph);
}
void RunGraphHandler(WorkerCall<RunGraphRequest, RunGraphResponse>* call) {
env_->compute_pool->Schedule([this, call]() { DoRunGraph(call); });
ENQUEUE_REQUEST(RunGraph);
}
void RecvTensorHandler(
WorkerCall<RecvTensorRequest, RecvTensorResponse>* call) {
env_->compute_pool->Schedule([this, call]() { DoRecvTensor(call); });
ENQUEUE_REQUEST(RecvTensor);
}
void CleanupGraphHandler(
WorkerCall<CleanupGraphRequest, CleanupGraphResponse>* call) {
env_->compute_pool->Schedule([this, call]() {
const int64 step_id = call->request.step_id();
env_->rendezvous_mgr->Cleanup(step_id);
call->SendResponse(::grpc::Status::OK);
});
ENQUEUE_REQUEST(CleanupGraph);
}
void LoggingHandler(WorkerCall<LoggingRequest, LoggingResponse>* call) {
env_->compute_pool->Schedule([this, call]() {
Status s = DoLogging(call);
call->SendResponse(ToGrpcStatus(s));
});
ENQUEUE_REQUEST(Logging);
}
void TracingHandler(WorkerCall<TracingRequest, TracingResponse>* call) {
SchedClosure([this, call]() {
Status s = DoTracing(call);
call->SendResponse(ToGrpcStatus(s));
});
ENQUEUE_REQUEST(Tracing);
}
#undef ENQUEUE_REQUEST
private:
// The following section contains the implementation of RunGraph()
// RecvTensor(), Logging(), and Tracing(), which are the four
// non-trivial and potentially long-running RPCs performed by a
// TensorFlow worker.
void AbortStep(int64 step_id) {
Rendezvous* rendez = env_->rendezvous_mgr->Find(step_id);
SchedNonBlockingClosureAfter(1000000, [rendez, step_id]() {
// Delay a bit before aborting the step. This way, the root
// cause may return first back to the client instead of this
// cancellation generated abort error.
rendez->StartAbort(errors::Aborted("Step ", step_id));
rendez->Unref();
});
}
Status PrepareRunGraph(const RunGraphRequest& req, GraphMgr::NamedTensors* in,
GraphMgr::NamedTensors* out) {
if (req.send_size() > 0) {
// TODO(zhifengc): Let the caller decide on which device to
// allocate the tensor.
Device* cpu_dev = nullptr;
TF_RETURN_IF_ERROR(env_->device_mgr->LookupDevice("CPU:0", &cpu_dev));
AllocatorAttributes alloc_attrs;
Tensor val;
for (const NamedTensor& entry : req.send()) {
TF_RETURN_IF_ERROR(
cpu_dev->MakeTensorFromProto(entry.val(), alloc_attrs, &val));
in->insert({entry.key(), val});
}
}
for (const string& key : req.recv_key()) {
out->insert({key, empty_tensor});
}
return Status::OK();
}
void DoRunGraph(WorkerCall<RunGraphRequest, RunGraphResponse>* call) {
const int64 step_id = call->request.step_id();
TRACEPRINTF("RunGraph: %lld", step_id);
GraphMgr::NamedTensors in;
GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors;
Status s = PrepareRunGraph(call->request, &in, out);
if (!s.ok()) {
delete out;
call->SendResponse(ToGrpcStatus(s));
return;
}
StepStatsCollector* collector = nullptr;
// TODO(mrry): Collect results from a profiler if available.
CancellationManager* cm = new CancellationManager;
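    // "cm" can be cancelled along two paths: the client cancelling this
    // RPC (via SetCancelCallback below), or the service-wide
    // cancellation_manager_ firing the callback registered underneath.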
call->SetCancelCallback([this, cm, step_id]() {
cm->StartCancel();
AbortStep(step_id);
});
CancellationToken token;
{
mutex_lock l(mu_);
token = cancellation_manager_->get_cancellation_token();
cancellation_manager_->RegisterCallback(token,
[cm]() { cm->StartCancel(); });
}
env_->graph_mgr->ExecuteAsync(
call->request.graph_handle(), step_id, call->request.exec_opts(),
collector, cm, in, out, [this, call, cm, out, token](Status s) {
call->ClearCancelCallback();
{
mutex_lock l(mu_);
cancellation_manager_->DeregisterCallback(token);
}
delete cm;
if (s.ok()) {
for (const auto& p : *out) {
const string& key = p.first;
const Tensor& val = p.second;
auto* recv = call->response.add_recv();
recv->set_key(key);
// TODO(zhifengc): Deal with gpu -> cpu copy.
TensorProto* proto = recv->mutable_val();
val.AsProtoField(proto);
}
}
delete out;
call->SendResponse(ToGrpcStatus(s));
});
}
// Helper for RecvTensor. Validates "key" and returns the source
// device in "*src_dev".
Status PrepareRecvTensor(const string& key, Device** src_dev) {
// Validate the key.
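    // (A rendezvous key is created by Rendezvous::CreateKey() and is
    // roughly of the form
    // "src_device;src_incarnation;dst_device;tensor_name;frame_iter".)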
Rendezvous::ParsedKey parsed;
TF_RETURN_IF_ERROR(Rendezvous::ParseKey(key, &parsed));
// Figures out which device the tensor is hosted on.
TF_RETURN_IF_ERROR(
env_->device_mgr->LookupDevice(parsed.src_device, src_dev));
// Does the device have the right incarnation number we expect?
if ((*src_dev)->attributes().incarnation() != parsed.src_incarnation) {
return errors::Aborted(
"RecvTensor expects a different device incarnation: ",
parsed.src_incarnation, " vs. ",
(*src_dev)->attributes().incarnation(),
". Your worker job was probably restarted. Check your "
"worker job for the reason why it was restarted.");
}
return Status::OK();
}
void DoRecvTensor(WorkerCall<RecvTensorRequest, RecvTensorResponse>* call) {
const int64 step_id = call->request.step_id();
const string& key = call->request.rendezvous_key();
TRACEPRINTF("RecvTensor: %lld %s", step_id, key.c_str());
Device* src_dev = nullptr;
Status s = PrepareRecvTensor(key, &src_dev);
if (!s.ok()) {
call->SendResponse(ToGrpcStatus(s));
return;
}
// Request the tensor associated with the rendezvous key. Any time
// while waiting for the tensor to be produced, up until the start
// of execution of the callback lambda body below, an RPC
// cancellation should abort the rendezvous.
call->SetCancelCallback([this, step_id]() { AbortStep(step_id); });
env_->rendezvous_mgr->RecvLocalAsync(
step_id, key,
[this, call, src_dev](const Status& status,
const Rendezvous::Args& send_args,
const Rendezvous::Args& recv_args,
const Tensor& val, const bool is_dead) {
call->ClearCancelCallback();
Status s = status;
if (s.ok()) {
// DMA can only be used for Tensors that do not fall into
// the following three odd edge cases: 1) a zero-size
// buffer, 2) a dead tensor which has an uninit value, and
// 3) the tensor has the on_host allocation attribute,
// i.e. it's in CPU RAM *independent of its assigned
// device type*.
// const size_t bytes = is_dead ? 0 : val.TotalBytes();
const bool on_host = send_args.alloc_attrs.on_host();
const DeviceContext* send_dev_context = send_args.device_context;
call->response.set_is_dead(is_dead);
StatusCallback response_ready = [call](const Status& s) {
// The value is now ready to be returned on the wire.
call->response.set_send_start_micros(Env::Default()->NowMicros());
call->SendResponse(ToGrpcStatus(s));
};
{
            // Non-DMA cases: the tensor is serialized into the
            // response proto.
if (src_dev->tensorflow_gpu_device_info() && (!on_host)) {
CHECK(send_dev_context)
<< "send dev name: " << src_dev->name()
<< " gpu_info: " << src_dev->tensorflow_gpu_device_info();
// "val" is on a GPU. Uses GPUUtil to fill the response proto.
GPUUtil::SetProtoFromGPU(val, src_dev, send_dev_context,
call->response.mutable_tensor(),
is_dead, response_ready);
} else {
// "val" is in CPU memory.
TensorProto* proto = call->response.mutable_tensor();
val.AsProtoTensorContent(proto);
response_ready(Status::OK());
}
}
} else {
// !s.ok()
call->SendResponse(ToGrpcStatus(s));
}
});
}
Status DoLogging(WorkerCall<LoggingRequest, LoggingResponse>* call) {
    // TODO(mrry): Platform-specific logging support.
return errors::Unimplemented("Logging");
}
Status DoTracing(WorkerCall<TracingRequest, TracingResponse>* call) {
// TODO(mrry): Platform-specific tracing support.
return errors::Unimplemented("Tracing");
}
TF_DISALLOW_COPY_AND_ASSIGN(GrpcWorkerService);
};
} // namespace
AsyncServiceInterface* NewGrpcWorkerService(WorkerEnv* env,
::grpc::ServerBuilder* builder) {
return new GrpcWorkerService(env, builder);
}
} // namespace tensorflow

View File

@ -0,0 +1,34 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_H_
#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_H_
namespace grpc {
class ServerBuilder;
} // namespace grpc
namespace tensorflow {
class AsyncServiceInterface;
class WorkerEnv;
// Returns an implementation of the WorkerService RPC service.
AsyncServiceInterface* NewGrpcWorkerService(WorkerEnv* env,
::grpc::ServerBuilder* builder);
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_H_

View File

@ -0,0 +1,196 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
#include <unordered_set>
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/distributed_runtime/process_util.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace {
class RpcRemoteRendezvous : public BaseRemoteRendezvous {
public:
RpcRemoteRendezvous(const WorkerEnv* env, int64 step_id)
: BaseRemoteRendezvous(env, step_id, false) {}
protected:
void RecvFromRemoteAsync(const string& key,
const Rendezvous::ParsedKey& parsed,
const Rendezvous::Args& args,
DoneCallback done) override;
private:
~RpcRemoteRendezvous() override {}
TF_DISALLOW_COPY_AND_ASSIGN(RpcRemoteRendezvous);
};
// Used only to retrieve tensors from remote processes.
class RpcRecvTensorCall : public BaseRecvTensorCall {
public:
RpcRecvTensorCall(WorkerCacheInterface* wc, WorkerInterface* wi,
int64 step_id, const string& key,
const string& remote_dev, Allocator* allocator,
Device* dst_device)
: wi_(wi),
wc_(wc),
remote_dev_(remote_dev),
allocator_(allocator),
dst_(dst_device) {
req_.set_step_id(step_id);
req_.set_rendezvous_key(key);
}
~RpcRecvTensorCall() override { delete wi_; }
void Start(std::function<void()> recv_done) override {
StartRTCall(recv_done);
}
void StartAbort(const Status& s) override {
{
mutex_lock l(mu_);
status_.Update(s);
}
opts_.StartCancel();
}
Status status() const override {
mutex_lock l(mu_);
return status_;
}
const TensorProto& tensor_proto() const { return resp_.tensor(); }
const RecvTensorResponse& response() const { return resp_; }
bool is_dead() const { return resp_.is_dead(); }
private:
// Start the main RecvTensor call, checking for an async abort.
void StartRTCall(std::function<void()> recv_done) {
wi_->RecvTensorAsync(&opts_, &req_, &resp_,
nullptr /* TensorBufAllocator */,
// done callback
[this, recv_done](const Status& s) {
{
mutex_lock l(mu_);
status_.Update(s);
}
recv_done();
});
}
WorkerInterface* wi_; // Owned.
WorkerCacheInterface* wc_; // Not owned.
string remote_dev_;
Allocator* allocator_;
Device* dst_;
CallOptions opts_;
RecvTensorRequest req_;
RecvTensorResponse resp_;
mutable mutex mu_;
Status status_ GUARDED_BY(mu_);
TF_DISALLOW_COPY_AND_ASSIGN(RpcRecvTensorCall);
};
void RpcRemoteRendezvous::RecvFromRemoteAsync(
const string& key, const Rendezvous::ParsedKey& parsed,
const Rendezvous::Args& recv_args, DoneCallback done) {
Status s;
// key.src_device identifies a remote device.
string src_worker;
string src_rel_device;
if (!DeviceNameUtils::SplitDeviceName(parsed.src_device, &src_worker,
&src_rel_device)) {
    s = errors::Internal(parsed.src_device,
                         " is an invalid remote source device.");
}
WorkerCacheInterface* worker_cache = env_->worker_cache;
if (s.ok() && worker_cache == nullptr) {
s = errors::Internal("No remote worker cache available.");
}
  WorkerInterface* rwi = nullptr;
  if (s.ok()) {
    // Only consult the cache after the null check above; dereferencing a
    // null worker cache here would crash before the error status could be
    // propagated to "done".
    rwi = worker_cache->CreateWorker(src_worker);
    if (rwi == nullptr) {
      s = errors::Internal("No worker known as ", src_worker);
    }
  }
Device* dst_device;
if (s.ok()) {
s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_device);
}
if (!s.ok()) {
done(s, Args(), recv_args, Tensor{}, false);
return;
}
Allocator* allocator = dst_device->GetAllocator(recv_args.alloc_attrs);
// Prepare a RecvTensor call that can handle being aborted.
RpcRecvTensorCall* call =
new RpcRecvTensorCall(worker_cache, rwi, step_id_, key,
parsed.src_device, allocator, dst_device);
// Record "call" in active_ so that it can be aborted cleanly.
RegisterCall(call);
// Start "call".
call->Start([this, call, parsed, recv_args, done]() {
// Removes "call" from active_. Prevent StartAbort().
DeregisterCall(call);
// If StartAbort was called prior to DeregisterCall, then the
// current status should be bad.
Status s = call->status();
Tensor val;
if (s.ok()) {
Device* dst_device;
s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_device);
if (s.ok()) {
s = dst_device->MakeTensorFromProto(call->tensor_proto(),
recv_args.alloc_attrs, &val);
}
}
done(s, Args(), recv_args, val, call->is_dead());
delete call;
});
}
} // namespace
BaseRemoteRendezvous* RpcRendezvousMgr::Create(int64 step_id,
const WorkerEnv* worker_env) {
return new RpcRemoteRendezvous(worker_env, step_id);
}
} // end namespace tensorflow

View File

@ -0,0 +1,57 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_RPC_RENDEZVOUS_MGR_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_RPC_RENDEZVOUS_MGR_H_
#include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/platform/macros.h"
namespace tensorflow {
// RendezvousMgr keeps track of a set of local rendezvous instances.
// All tensors sent by this worker are buffered in a RendezvousMgr
// until the tensor is received. Each globally unique "step_id"
// corresponds to one local rendezvous instance managed by a
// RendezvousMgr.
//
// E.g.,
//   Rendezvous* rendez = worker_env->rendezvous_mgr->Find(0x8935);
//   fork execution of a graph executor using "rendez" on thread 1;
//   fork execution of another graph executor using "rendez" on thread 2;
//   ...
//   join threads 1 and 2;
//
// In the example above, the executions in threads 1 and 2 communicate
// with each other by send/recv operations through "rendez".
//
// Tensors sent and received through a rendezvous managed by this
// RendezvousMgr must have keys generated by Rendezvous::CreateKey.
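//
// A concrete sketch of the pattern above (a hypothetical fragment; it
// assumes a WorkerEnv* env and a key built with Rendezvous::CreateKey):
//
//   Rendezvous* rendez = env->rendezvous_mgr->Find(step_id);
//   core::ScopedUnref unref(rendez);
//   Rendezvous::Args args;
//   TF_CHECK_OK(rendez->Send(key, args, tensor, false /* is_dead */));
//   // ... a peer executor issues the matching Recv with the same key ...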
class RpcRendezvousMgr : public BaseRendezvousMgr {
public:
explicit RpcRendezvousMgr(const WorkerEnv* env) : BaseRendezvousMgr(env) {}
protected:
BaseRemoteRendezvous* Create(int64 step_id,
const WorkerEnv* worker_env) override;
private:
TF_DISALLOW_COPY_AND_ASSIGN(RpcRendezvousMgr);
};
} // end namespace tensorflow
#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_RPC_RENDEZVOUS_MGR_H_

View File

@ -0,0 +1,172 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
#include "tensorflow/core/distributed_runtime/process_util.h"
#include "tensorflow/core/framework/control_flow.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
// string -> Tensor<string>
Tensor V(const string& content) {
Tensor tensor(DT_STRING, TensorShape({}));
tensor.scalar<string>()() = content;
return tensor;
}
// Tensor<string> -> string
string V(const Tensor& tensor) {
CHECK_EQ(tensor.dtype(), DT_STRING);
CHECK(TensorShapeUtils::IsScalar(tensor.shape()));
return tensor.scalar<string>()();
}
TEST(RpcRendezvousMgrTest, LocalSendRecv) {
WorkerEnv env;
env.env = Env::Default();
env.worker_name = "/job:mnist/replica:1/task:2";
RpcRendezvousMgr rmgr(&env);
const int64 step_id = 123;
const string key = Rendezvous::CreateKey(
"/job:mnist/replica:1/task:2/cpu:0", 7890,
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0));
{
Rendezvous* rendez = rmgr.Find(step_id);
core::ScopedUnref unref(rendez);
Rendezvous::Args args;
TF_ASSERT_OK(rendez->Send(key, args, V("peach"), false));
}
{
Tensor val(DT_FLOAT);
bool val_dead = false;
TF_ASSERT_OK(rmgr.RecvLocal(step_id, key, &val, &val_dead));
EXPECT_EQ(V(val), "peach");
}
rmgr.Cleanup(step_id);
}
TEST(RpcRendezvousMgrTest, LocalAbort) {
WorkerEnv env;
env.env = Env::Default();
env.worker_name = "/job:mnist/replica:1/task:2";
RpcRendezvousMgr rmgr(&env);
const string key = Rendezvous::CreateKey(
"/job:mnist/replica:1/task:2/cpu:0", 7890,
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0));
{ // Explicit Abort().
const int64 step_id = 123;
Rendezvous* rendez = rmgr.Find(step_id);
core::ScopedUnref unref(rendez);
SchedClosure([env, rendez]() {
env.env->SleepForMicroseconds(100 * 1000);
rendez->StartAbort(errors::Aborted(""));
});
Tensor val(DT_STRING);
bool val_dead = false;
Rendezvous::Args args;
EXPECT_TRUE(errors::IsAborted(rendez->Recv(key, args, &val, &val_dead)));
}
{ // Cleanup causes Abort().
const int64 step_id = 321;
Rendezvous* rendez = rmgr.Find(step_id);
core::ScopedUnref unref(rendez);
SchedClosure([env, &rmgr, step_id]() {
env.env->SleepForMicroseconds(100 * 1000);
rmgr.Cleanup(step_id);
});
Tensor val(DT_STRING);
bool val_dead = false;
Rendezvous::Args args;
EXPECT_TRUE(errors::IsAborted(rendez->Recv(key, args, &val, &val_dead)));
}
}
TEST(RpcRendezvousMgrTest, CleanupAll) {
WorkerEnv env;
env.env = Env::Default();
env.worker_name = "/job:mnist/replica:1/task:2";
RpcRendezvousMgr rmgr(&env);
const string key = Rendezvous::CreateKey(
"/job:mnist/replica:1/task:2/cpu:0", 7890,
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0));
{
const int64 step_id = 123;
Rendezvous* rendez = rmgr.Find(step_id);
core::ScopedUnref unref(rendez);
Rendezvous::Args args;
TF_ASSERT_OK(rendez->Send(key, args, V("peach"), false));
rmgr.CleanupAll();
Tensor val(DT_STRING);
bool val_dead = false;
EXPECT_TRUE(errors::IsAborted(rendez->Recv(key, args, &val, &val_dead)));
}
}
class DummyDeviceContext : public DeviceContext {
public:
explicit DummyDeviceContext(int stream_id) : stream_id_(stream_id) {}
~DummyDeviceContext() override {}
int stream_id() const { return stream_id_; }
private:
const int stream_id_;
};
TEST(RpcRendezvousMgrTest, TransferDummyDeviceContext) {
DummyDeviceContext* dc = new DummyDeviceContext(123);
WorkerEnv env;
env.env = Env::Default();
env.worker_name = "/job:mnist/replica:1/task:2";
RpcRendezvousMgr rmgr(&env);
const int64 step_id = 123;
const string key = Rendezvous::CreateKey(
"/job:mnist/replica:1/task:2/cpu:0", 7890,
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0));
{
Rendezvous* rendez = rmgr.Find(step_id);
core::ScopedUnref unref(rendez);
Rendezvous::Args args;
args.device_context = dc;
TF_ASSERT_OK(rendez->Send(key, args, V("peach"), false));
}
{
Notification n;
rmgr.RecvLocalAsync(
step_id, key, [&n](const Status& s, const Rendezvous::Args send_args,
const Rendezvous::Args recv_args, const Tensor& val,
bool is_dead) {
auto send_dev_context =
static_cast<DummyDeviceContext*>(send_args.device_context);
CHECK_EQ(123, send_dev_context->stream_id());
CHECK_EQ(V(val), "peach");
n.Notify();
});
n.WaitForNotification();
}
rmgr.Cleanup(step_id);
dc->Unref();
}
// NOTE: Remote Send/Recv is better tested in worker_test.cc
} // namespace tensorflow

View File

@ -0,0 +1,309 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/simple_graph_execution_state.h"
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/simple_placer.h"
#include "tensorflow/core/distributed_runtime/process_util.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/graph/costmodel.h"
#include "tensorflow/core/graph/dot.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/worker.pb.h"
#include "tensorflow/core/util/util.h"
namespace tensorflow {
string BuildGraphOptions::DebugString() const {
string rv = "Feed endpoints: ";
for (auto& s : feed_endpoints) {
strings::StrAppend(&rv, s, ", ");
}
strings::StrAppend(&rv, "\nFetch endpoints: ");
for (auto& s : fetch_endpoints) {
strings::StrAppend(&rv, s, ", ");
}
strings::StrAppend(&rv, "\nTarget nodes: ");
for (auto& s : target_nodes) {
strings::StrAppend(&rv, s, ", ");
}
return rv;
}
SimpleGraphExecutionState::SimpleGraphExecutionState(
const OpRegistryInterface* ops,
const SimpleGraphExecutionStateOptions& options)
: ops_(ops),
device_set_(options.device_set),
session_options_(options.session_options),
base_(nullptr),
placed_(nullptr) {
// TODO(mrry): Publish placement visualizations or handle the log
// placement option.
}
SimpleGraphExecutionState::~SimpleGraphExecutionState() {
mutex_lock l(mu_);
delete base_;
delete placed_;
}
Status SimpleGraphExecutionState::Create(GraphDef* graph_def) {
if (original_graph_def_.node_size() > 0) {
return errors::InvalidArgument(
"Cannot call Create on SimpleGraphExecutionState twice");
}
original_graph_def_.Swap(graph_def);
VLOG(2) << "Incoming def: " << original_graph_def_.DebugString();
return AddDefaultAttrsToGraphDef(&original_graph_def_, *ops_, 0);
}
Status SimpleGraphExecutionState::Extend(
const GraphDef& extension_def, SimpleGraphExecutionState** out) const {
std::unordered_set<string> new_names;
// 1. Build an index of the new node names.
for (const NodeDef& node : extension_def.node()) {
new_names.insert(node.name());
}
// 2. Add the non-duplicates from the old graph to the new graph.
// Return an error if the same node name appears in both the
// old graph and the extension.
GraphDef gdef;
for (const NodeDef& node : original_graph_def_.node()) {
if (new_names.count(node.name()) == 0) {
*gdef.add_node() = node;
} else {
return errors::InvalidArgument(tensorflow::strings::Printf(
"GraphDef argument to Extend includes node '%s', which was created "
"by a previous call to Create or Extend in this session.",
node.name().c_str()));
}
}
int old_node_size = gdef.node_size();
gdef.mutable_node()->MergeFrom(extension_def.node());
TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(&gdef, *ops_, old_node_size));
// 3. Add the extension.
SimpleGraphExecutionStateOptions combined_options;
combined_options.device_set = device_set_;
SimpleGraphExecutionState* new_execution_state =
new SimpleGraphExecutionState(ops_, combined_options);
Status new_execution_state_status = new_execution_state->Create(&gdef);
if (!new_execution_state_status.ok()) {
delete new_execution_state;
return new_execution_state_status;
}
*out = new_execution_state;
// Ensure that any state created in the precursor is accessible in the
// new graph.
{
mutex_lock l(mu_);
for (const auto& placement : stateful_placements_) {
(*out)->stateful_placements_.insert(placement);
}
}
// TODO(mrry): This is likely to be used for non-throughput-sensitive
// interactive workloads, but in future we may want to transfer other
// parts of the placement and/or cost model.
return Status::OK();
}
Status SimpleGraphExecutionState::InitBaseGraph() {
std::unique_ptr<Graph> new_base(new Graph(ops_));
GraphConstructorOptions opts;
TF_RETURN_IF_ERROR(
ConvertGraphDefToGraph(opts, original_graph_def_, new_base.get()));
for (const Node* n : new_base->nodes()) {
VLOG(2) << "Mapping " << n->name() << " to " << n->cost_id();
node_name_to_cost_id_map_[n->name()] = n->cost_id();
}
Status status = PreliminaryPlace(*new_base);
if (!status.ok()) {
node_name_to_cost_id_map_.clear();
return status;
}
base_ = new_base.release();
return Status::OK();
}
Status SimpleGraphExecutionState::GlobalNodeDefByName(const string& name,
NodeDef* out) {
NodeNameToCostIdMap::const_iterator iter =
node_name_to_cost_id_map_.find(name);
if (iter != node_name_to_cost_id_map_.end()) {
mutex_lock l(mu_); // could use reader lock
const Node* node = placed_->FindNodeId(iter->second);
if (node) {
*out = node->def();
return Status::OK();
}
}
return errors::NotFound("Node name: ", name);
}
Status SimpleGraphExecutionState::PreliminaryPlace(const Graph& base) {
VLOG(1) << "PreliminaryPlace";
Graph* ng = new Graph(ops_);
CopyGraph(base, ng);
Status status = DoPlace(ng);
if (!status.ok()) {
delete ng;
} else {
delete placed_;
placed_ = ng;
FreezeStatefulNodes(true /*is_prelim*/);
}
return status;
}
void SimpleGraphExecutionState::FreezeStatefulNodes(bool is_prelim) {
if (is_prelim) {
// During the preliminary placement every stateful Node got placed
// somewhere, and we need to remember where, so it doesn't move.
for (Node* n : placed_->nodes()) {
if (n->op_def().is_stateful()) {
stateful_placements_[n->name()] = n->assigned_device_name();
}
}
} else {
// During later placements it's possible for new stateful nodes to
// appear. They are noticed while we're pinning the pre-existing
// stateful nodes to their prior positions, and after they've been
// placed this function is entered to record their placements.
for (Node* n : missing_stateful_placements_) {
CHECK(n->op_def().is_stateful());
stateful_placements_[n->name()] = n->assigned_device_name();
}
missing_stateful_placements_.clear();
}
}
void SimpleGraphExecutionState::PlaceStatefulNodes(Graph* graph) {
for (Node* n : graph->nodes()) {
if (n->op_def().is_stateful()) {
PlaceMap::const_iterator iter = stateful_placements_.find(n->name());
if (iter == stateful_placements_.end()) {
// NOTE(tucker): I don't understand why this can occur. So far,
// I've only seen it in eval instances, started from a checkpoint.
missing_stateful_placements_.push_back(n);
} else {
n->set_assigned_device_name(iter->second);
}
}
}
}
Status SimpleGraphExecutionState::DoPlace(Graph* graph) {
Status status;
// TODO(mrry): Port other placement algorithms from whitepaper.
return SimplePlacement(graph);
}
Status SimpleGraphExecutionState::BuildGraph(const BuildGraphOptions& options,
ClientGraph** out) {
VLOG(1) << "BuildGraph";
mutex_lock l(mu_);
// Lazily initialize the base graph.
if (base_ == nullptr) {
TF_RETURN_IF_ERROR(InitBaseGraph());
}
if (!base_ || !placed_) {
return ::tensorflow::errors::Internal(
"There was a problem building the graph.");
}
std::unique_ptr<ClientGraph> cgraph(new ClientGraph(ops_));
CopyGraph(*placed_, &cgraph->graph);
// Extract the subset of the graph that needs to be run, adding feed/fetch
// ops as needed.
TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
&cgraph->graph, options.feed_endpoints, options.fetch_endpoints,
options.target_nodes, device_set_->client_device()->attributes()));
// Copy the extracted graph in order to make its node ids dense,
// since the local CostModel used to record its stats is sized by
// the largest node id.
{
std::unique_ptr<ClientGraph> dense_copy(new ClientGraph(ops_));
CopyGraph(cgraph->graph, &dense_copy->graph);
cgraph = std::move(dense_copy);
}
// TODO(vrv): We should check invariants of the graph here.
*out = cgraph.release();
return Status::OK();
}
Status SimpleGraphExecutionState::DeviceIsCompatible(
Node* n, const Device* device) const {
if (!n->def().device().empty()) {
DeviceNameUtils::ParsedName pname;
if (!DeviceNameUtils::ParseFullName(n->def().device(), &pname)) {
return AttachDef(
errors::InvalidArgument("Malformed device specification '",
n->def().device(), "'"),
n->def());
}
std::vector<Device*> devices;
device_set_->FindMatchingDevices(pname, &devices);
for (auto d : devices) {
if (d == device) {
return Status::OK();
}
}
return AttachDef(
errors::InvalidArgument(
"Specified device '", n->def().device(),
"' not compatible with device of ref connection: ", device->name()),
n->def());
}
return Status::OK();
}
Status SimpleGraphExecutionState::SimplePlacement(Graph* graph) {
SimplePlacer placer(graph, device_set_, &node_name_to_cost_id_map_,
session_options_);
// TODO(mrry): Consider making the SimplePlacer cancelable.
return placer.Run();
}
} // namespace tensorflow

View File

@ -0,0 +1,156 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SIMPLE_GRAPH_EXECUTION_STATE_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SIMPLE_GRAPH_EXECUTION_STATE_H_
#include <functional>
#include <memory>
#include <string>
#include <vector>
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/distributed_runtime/build_graph_options.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/graph/costmodel.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
class SessionOptions;
class StepStats;
class Timeline;
struct SimpleGraphExecutionStateOptions {
const DeviceSet* device_set = nullptr;
const SessionOptions* session_options = nullptr;
};
// A ClientGraph is simply a sub-graph of the full graph as induced by
// BuildGraphOptions.
struct ClientGraph {
Graph graph;
explicit ClientGraph(const OpRegistryInterface* ops) : graph(ops) {}
int32 placement_version;
};
// SimpleGraphExecutionState is responsible for generating an
// executable ClientGraph from the original GraphDef that specifies
// the complete graph and from BuildGraphOptions which specifies
// input/output nodes.
//
// An executable Graph differs from a GraphDef by being Placed,
// meaning that each Node is assigned to a single Device in the
// available set.
//
// When SimpleGraphExecutionState is first constructed it instantiates
// a full Graph from the provided GraphDef, and places it, using only
// the static device assignments from the GraphDef. Nodes without a static
// assignment are currently placed in a very naive way. Since stateful Nodes cannot
// be moved after initial placement, it is important that stateful
// Nodes get sensible initial device assignments in the graph
// definition.
//
// Subsequently, SimpleGraphExecutionState generates a ClientGraph on
// demand, which is a sub-graph of the latest placement of the full
// Graph. MasterSession uses such a ClientGraph to execute one or
// more similar client requests.
//
// SimpleGraphExecutionState is thread-safe.
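//
// A minimal usage sketch (assuming a populated GraphDef `gdef`, an
// initialized DeviceSet `devices` and SessionOptions `session_options`,
// and a fetchable tensor named "out:0"; all names are illustrative):
//
//   SimpleGraphExecutionStateOptions opts;
//   opts.device_set = &devices;
//   opts.session_options = &session_options;
//   SimpleGraphExecutionState* state =
//       new SimpleGraphExecutionState(OpRegistry::Global(), opts);
//   TF_CHECK_OK(state->Create(&gdef));  // May modify gdef (contents swapped).
//   BuildGraphOptions bgopts;
//   bgopts.fetch_endpoints.push_back("out:0");
//   ClientGraph* cgraph = nullptr;
//   TF_CHECK_OK(state->BuildGraph(bgopts, &cgraph));  // Caller owns *cgraph.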
class SimpleGraphExecutionState {
public:
SimpleGraphExecutionState(const OpRegistryInterface* ops,
const SimpleGraphExecutionStateOptions& options);
virtual ~SimpleGraphExecutionState();
// Initializes the SimpleGraphExecutionState with 'graph_def'. Can only be
// called once on a given SimpleGraphExecutionState. This call may modify
// 'graph_def' (its contents are swapped into this object).
Status Create(GraphDef* graph_def);
// Creates a new SimpleGraphExecutionState representing the
// concatenation of this graph, and the graph defined by
// "extension_def". The same name may not be used to define a node
// in both this graph and "extension_def".
//
// If successful, returns OK and the caller takes ownership of "*out".
// Otherwise returns an error and does not modify "*out".
//
// NOTE(mrry): This method respects the placement of stateful nodes in
// *this, but currently does not transfer any other placement
// or cost model information to the new graph.
Status Extend(const GraphDef& extension_def,
SimpleGraphExecutionState** out) const;
// Builds a ClientGraph (a sub-graph of the full graph as induced by
// the Node set specified in "options"). If successful, returns OK
// and the caller takes the ownership of "*out". Otherwise, returns
// an error.
Status BuildGraph(const BuildGraphOptions& options, ClientGraph** out);
// Returns OK if the named node is found in the placed full graph owned
// by this execution_state, and sets *out to the NodeDef for that node.
// The node may not exist if 'name' refers to a Node added for a particular
// subgraph execution, e.g. a send, recv, or feed node.
Status GlobalNodeDefByName(const string& name, NodeDef* out);
private:
mutable mutex mu_;
Status InitBaseGraph() EXCLUSIVE_LOCKS_REQUIRED(mu_);
Status PreliminaryPlace(const Graph& graph) EXCLUSIVE_LOCKS_REQUIRED(mu_);
void FreezeStatefulNodes(bool is_prelim) EXCLUSIVE_LOCKS_REQUIRED(mu_);
void PlaceStatefulNodes(Graph* graph) EXCLUSIVE_LOCKS_REQUIRED(mu_);
Status DoPlace(Graph* graph);
Status SimplePlacement(Graph* graph);
// Return an OK status if "n" can be assigned to "device".
Status DeviceIsCompatible(Node* n, const Device* device) const;
const OpRegistryInterface* const ops_; // Not owned
GraphDef original_graph_def_; // Immutable after ctor.
const DeviceSet* device_set_; // Not owned
const SessionOptions* session_options_; // Not owned
// Original graph before we make any placement decisions.
Graph* base_ GUARDED_BY(mu_);
// Full graph, placed on the complete set of devices, as a whole.
Graph* placed_ GUARDED_BY(mu_);
// Map of placed stateful nodes, i.e. nodes for which is_stateful()
// is true, such as "params" and "queue" nodes. Once placed these
// nodes can not be moved to a different device. Maps node names to
// device names.
typedef std::unordered_map<string, string> PlaceMap;
PlaceMap stateful_placements_ GUARDED_BY(mu_);
std::vector<Node*> missing_stateful_placements_ GUARDED_BY(mu_);
// Map from name to Node for the full graph in placed_.
NodeNameToCostIdMap node_name_to_cost_id_map_;
TF_DISALLOW_COPY_AND_ASSIGN(SimpleGraphExecutionState);
};
} // namespace tensorflow
#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SIMPLE_GRAPH_EXECUTION_STATE_H_

View File

@ -0,0 +1,75 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_CACHE_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_CACHE_H_
#include <string>
#include <vector>
#include "tensorflow/core/distributed_runtime/worker_interface.h" // for CallOptions
#include "tensorflow/core/framework/device_attributes.pb.h" // for BusAdjacency
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
typedef std::function<void(const Status&)> StatusCallback;
class ChannelCache;
class StepStats;
class WorkerInterface;
class WorkerCacheInterface {
public:
virtual ~WorkerCacheInterface() {}
// Updates *workers with strings naming the remote worker tasks to
// which open channels have been established.
virtual void ListWorkers(std::vector<string>* workers) = 0;
// If "target" names a remote task for which an RPC channel exists
// or can be constructed, returns a new WorkerInterface object
// wrapping that channel. Ownership passes to the caller.
// TODO(tucker): give this a name that makes it more obvious it is a
// constructor that transfers ownership of the returned object, not a
// cache lookup.
virtual WorkerInterface* CreateWorker(const string& target) = 0;
// Set *ba with the BusAdjacency of the specified remote device
// within its local environment. Returns true if the device bus
// affinity was set, using only locally cached data. Returns false
// if status data for that device was not available. Never blocks.
// TODO(mrry,tucker): Maybe remove.
virtual bool GetDeviceBusNonBlocking(const string& device,
BusAdjacency* ba) = 0;
// Set *ba with the BusAdjacency of the specified remote device
// within its local environment. Callback gets Status::OK if the
// device bus affinity was set.
// TODO(mrry,tucker): Maybe remove.
virtual void GetDeviceBusAsync(const string& device, BusAdjacency* ba,
StatusCallback done) = 0;
// Start/stop logging activity.
virtual void SetLogging(bool active) {}
// Discard any saved log data.
virtual void ClearLogs() {}
// Return logs for the identified step in *ss. Any returned data will no
// longer be stored.
virtual bool RetrieveLogs(int64 step_id, StepStats* ss) { return false; }
};
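// Usage sketch (hedged; `cache` is assumed to be a pointer to a concrete
// WorkerCacheInterface implementation):
//
//   std::vector<string> workers;
//   cache->ListWorkers(&workers);
//   for (const string& target : workers) {
//     std::unique_ptr<WorkerInterface> wi(cache->CreateWorker(target));
//     // ... issue RPCs via wi; it is owned by this scope ...
//   }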
} // namespace tensorflow
#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_CACHE_H_

View File

@ -0,0 +1,110 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/worker_cache_logger.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace {
// Maximum number of step_ids for which RPC logs can be maintained.
// TODO(mrry): Make this configurable if necessary.
const int32 kWorkerCacheLoggerLimit = 1 << 10;
} // namespace
void WorkerCacheLogger::SetLogging(bool v) {
mutex_lock l(count_mu_);
if (v) {
++want_logging_count_;
} else {
--want_logging_count_;
// If RPCs get cancelled, it may be possible for the count
// to go negative. This should not be a fatal error, since
// logging is non-critical.
if (want_logging_count_ < 0) want_logging_count_ = 0;
}
}
void WorkerCacheLogger::ClearLogs() {
mutex_lock l(mu_);
ClearLogsWithLock();
}
void WorkerCacheLogger::ClearLogsWithLock() {
for (auto& iter : log_map_) {
delete iter.second.collector;
}
log_map_.clear();
}
bool WorkerCacheLogger::RetrieveLogs(int64 step_id, StepStats* ss) {
mutex_lock l(mu_);
LogMap::iterator iter = log_map_.find(step_id);
if (iter != log_map_.end()) {
iter->second.collector->Swap(ss);
delete iter->second.collector;
log_map_.erase(iter);
return true;
}
return false;
}
void WorkerCacheLogger::Save(const string& device, int64 step_id,
NodeExecStats* ns) {
mutex_lock l(mu_);
StepLog* sl = &log_map_[step_id];
if (!sl->collector) {
sl->collector = new StepStatsCollector(&sl->step_stats);
}
sl->collector->Save(device, ns);
if (log_map_.size() > kWorkerCacheLoggerLimit) {
// Something's gone wrong. Just empty the cache.
ClearLogsWithLock();
}
}
void WorkerCacheLogger::RecordRecvTensor(int64 step_id, int64 start_usecs,
int64 end_usecs,
const string& tensor_name,
const string& src_device,
const string& dst_device,
int64 bytes) {
NodeExecStats* ns = new NodeExecStats;
ns->set_node_name("RecvTensor");
string byte_string = strings::StrCat("[", bytes, "B] ");
if (bytes >= 0.1 * 1048576.0) {
byte_string = strings::Printf("[%.1fMB] ", bytes / 1048576.0);
}
ns->set_timeline_label(strings::StrCat(byte_string, tensor_name, " from ",
src_device, " to ", dst_device));
ns->set_all_start_micros(start_usecs);
ns->set_op_start_rel_micros(0);
ns->set_op_end_rel_micros(end_usecs - start_usecs);
NodeOutput* no = ns->add_output();
no->set_slot(0);
// TODO(tucker): Maybe set the dimensions too, but then they'll
// need to be passed in.
no->mutable_tensor_description()
->mutable_allocation_description()
->set_requested_bytes(bytes);
Save(dst_device, step_id, ns);
}
} // namespace tensorflow

View File

@ -0,0 +1,81 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_CACHE_LOGGER_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_CACHE_LOGGER_H_
#include <string>
#include <unordered_map>
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
class StepStatsCollector;
// WorkerCacheLogger is a thread-safe utility for use by a WorkerCache
// to optionally log some selected RPC activity. A single instance
// should be owned by a WorkerCache, for use by its RemoteWorker
// instances.
class WorkerCacheLogger {
public:
// Start/Stop logging activity. This function increments/decrements
// a counter so that if two separate steps turn logging on/off,
// logging should be on for the union of the durations of both,
// regardless of relative timing.
void SetLogging(bool v);
// Discard any saved log data.
void ClearLogs();
// Return logs for the identified step in *ss. Any returned data will no
// longer be stored. Returns true iff *ss was modified.
bool RetrieveLogs(int64 step_id, StepStats* ss);
// Return true if there is any outstanding request for logging on
// the RPC channels.
bool LoggingActive() {
mutex_lock l(count_mu_);
return want_logging_count_ > 0;
}
// Generates a NodeExecStats record with the given data, and saves for
// later retrieval by RetrieveLogs().
void RecordRecvTensor(int64 step_id, int64 start_usecs, int64 end_usecs,
const string& tensor_name, const string& src_device,
const string& dst_device, int64 bytes);
private:
mutex count_mu_;
int32 want_logging_count_ GUARDED_BY(count_mu_);
struct StepLog {
StepStats step_stats;
StepStatsCollector* collector;
};
typedef std::unordered_map<int64, StepLog> LogMap;
mutex mu_;
LogMap log_map_ GUARDED_BY(mu_);
// Records "ns" in log_map_ under the given device and step.
void Save(const string& device, int64 step_id, NodeExecStats* ns);
void ClearLogsWithLock() EXCLUSIVE_LOCKS_REQUIRED(mu_);
};
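// Usage sketch (illustrative; `logger` is a WorkerCacheLogger owned by a
// WorkerCache, and `step_id` identifies a completed step):
//
//   logger.SetLogging(true);
//   // ... RecordRecvTensor() is called as RecvTensor RPCs complete ...
//   StepStats ss;
//   if (logger.RetrieveLogs(step_id, &ss)) {
//     // Inspect ss; the data is no longer stored by the logger.
//   }
//   logger.SetLogging(false);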
} // namespace tensorflow
#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_CACHE_LOGGER_H_

View File

@ -0,0 +1,98 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/worker_cache_partial.h"
#include "tensorflow/core/distributed_runtime/process_util.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
bool WorkerCachePartial::GetDeviceBusNonBlocking(const string& device_name,
BusAdjacency* ba) {
mutex_lock lock(mu_); // could use reader lock
const auto& iter = device_status_cache_.find(device_name);
if (iter != device_status_cache_.end()) {
*ba = iter->second.bus_adjacency();
return true;
}
return false;
}
void WorkerCachePartial::GetDeviceBusAsync(const string& device_name,
BusAdjacency* ba,
StatusCallback done) {
if (!GetDeviceBusNonBlocking(device_name, ba)) {
// If cache entry was empty, make one try to fill it by RPC.
SchedClosure([this, &device_name, ba, done]() {
Status s = RefreshDeviceStatus(device_name);
if (s.ok()) {
if (!GetDeviceBusNonBlocking(device_name, ba)) {
mutex_lock lock(mu_);
const auto& iter = device_status_cache_.find(device_name);
if (iter == device_status_cache_.end()) {
s = errors::Unavailable("No known remote device: ", device_name);
} else {
s = errors::Internal("Failed to find bus_adjacency for ",
device_name);
}
}
}
done(s);
});
return;
}
done(Status::OK());
}
Status WorkerCachePartial::RefreshDeviceStatus(const string& device_name) {
string task;
string device;
Status s;
if (!DeviceNameUtils::SplitDeviceName(device_name, &task, &device)) {
s = errors::InvalidArgument("Bad device name to RefreshDeviceStatus: ",
device_name);
}
std::unique_ptr<WorkerInterface> rwi(CreateWorker(task));
if (s.ok() && !rwi.get()) {
s = errors::Internal("RefreshDeviceStatus, unknown worker task: ", task);
}
if (s.ok()) {
GetStatusRequest req;
GetStatusResponse resp;
s = rwi->GetStatus(&req, &resp);
if (s.ok()) {
mutex_lock lock(mu_);
for (auto& dev_attr : resp.device_attributes()) {
device_status_cache_[dev_attr.name()] = dev_attr;
}
}
}
return s;
}
void WorkerCachePartial::FlushStatusCache() {
mutex_lock lock(mu_);
device_status_cache_.clear();
}
} // namespace tensorflow

View File

@ -0,0 +1,56 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_CACHE_PARTIAL_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_CACHE_PARTIAL_H_
#include <string>
#include <unordered_map>
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/worker.pb.h"
namespace tensorflow {
// Implements the part of the interface that caches and returns remote
// device status attributes.
class WorkerCachePartial : public WorkerCacheInterface {
public:
bool GetDeviceBusNonBlocking(const string& device, BusAdjacency* ba) override;
void GetDeviceBusAsync(const string& device, BusAdjacency* ba,
StatusCallback) override;
~WorkerCachePartial() override {}
// Clear all entries from the DeviceStatus cache.
void FlushStatusCache();
private:
mutex mu_;
// Initiate a GetStatusAsync to the remote task named by "task", and
// update the cache with all the DeviceAttributes reported.
Status RefreshDeviceStatus(const string& device_name);
typedef std::unordered_map<string, DeviceAttributes> StatusMap;
StatusMap device_status_cache_ GUARDED_BY(mu_);
};
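// Usage sketch (illustrative; `cache` is a concrete subclass and the
// device name is hypothetical):
//
//   BusAdjacency ba;
//   const string dev = "/job:mnist/replica:0/task:1/gpu:0";
//   if (!cache.GetDeviceBusNonBlocking(dev, &ba)) {
//     cache.GetDeviceBusAsync(dev, &ba, [](const Status& s) {
//       // ba has been filled in here iff s.ok().
//     });
//   }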
} // namespace tensorflow
#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_CACHE_PARTIAL_H_

View File

@ -0,0 +1,62 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_ENV_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_ENV_H_
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace thread {
class ThreadPool;
} // namespace thread
class DeviceMgr;
class Env;
class GraphMgr;
class RendezvousMgrInterface;
class WorkerCacheInterface;
// The worker environment class, which holds a bag of pointers to
// per-worker singletons.
//
// WorkerEnv does not own its member pointers.
struct WorkerEnv {
Env* env = nullptr;
// The name of the worker. E.g., /job:mnist/replica:1/task:0.
string worker_name;
// Object from which WorkerInterface instances can be obtained.
WorkerCacheInterface* worker_cache = nullptr;
// device_mgr manages local devices (cpu and gpu). The WorkerService
// is the network interface for managed devices.
DeviceMgr* device_mgr = nullptr;
// graph_mgr keeps track of registered graphs of this worker.
GraphMgr* graph_mgr = nullptr;
// A set of rendezvous instances keyed by step ids.
RendezvousMgrInterface* rendezvous_mgr = nullptr;
// A pool of threads for scheduling compute work.
thread::ThreadPool* compute_pool = nullptr;
};
} // end namespace tensorflow
#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_ENV_H_

View File

@ -0,0 +1,129 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_INTERFACE_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_INTERFACE_H_
#include <functional>
#include "tensorflow/core/distributed_runtime/call_options.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/worker.pb.h"
namespace tensorflow {
// Status callback.
typedef std::function<void(const Status&)> StatusCallback;
// Allocator callback for out-of-band transfers.
class TensorShape;
typedef std::function<void*(size_t, const DataType&, const TensorShape&)>
TensorBufAllocator;
// Interface for talking with the TensorFlow Worker service.
class WorkerInterface {
public:
virtual ~WorkerInterface() {}
virtual void GetStatusAsync(const GetStatusRequest* request,
GetStatusResponse* response,
StatusCallback done) = 0;
virtual void RegisterGraphAsync(const RegisterGraphRequest* request,
RegisterGraphResponse* response,
StatusCallback done) = 0;
virtual void DeregisterGraphAsync(const DeregisterGraphRequest* request,
DeregisterGraphResponse* response,
StatusCallback done) = 0;
virtual void RunGraphAsync(CallOptions* opts, const RunGraphRequest* request,
RunGraphResponse* response,
StatusCallback done) = 0;
virtual void CleanupGraphAsync(const CleanupGraphRequest* request,
CleanupGraphResponse* response,
StatusCallback done) = 0;
virtual void CleanupAllAsync(const CleanupAllRequest* request,
CleanupAllResponse* response,
StatusCallback done) = 0;
virtual void RecvTensorAsync(CallOptions* opts,
const RecvTensorRequest* request,
RecvTensorResponse* response,
TensorBufAllocator allocator,
StatusCallback done) = 0;
virtual void LoggingAsync(const LoggingRequest* request,
LoggingResponse* response, StatusCallback done) = 0;
virtual void TracingAsync(const TracingRequest* request,
TracingResponse* response, StatusCallback done) = 0;
Status GetStatus(const GetStatusRequest* request,
GetStatusResponse* response) {
return CallAndWait(&ME::GetStatusAsync, request, response);
}
Status RegisterGraph(const RegisterGraphRequest* request,
RegisterGraphResponse* response) {
return CallAndWait(&ME::RegisterGraphAsync, request, response);
}
Status DeregisterGraph(const DeregisterGraphRequest* request,
DeregisterGraphResponse* response) {
return CallAndWait(&ME::DeregisterGraphAsync, request, response);
}
Status CleanupGraph(const CleanupGraphRequest* request,
CleanupGraphResponse* response) {
return CallAndWait(&ME::CleanupGraphAsync, request, response);
}
Status CleanupAll(const CleanupAllRequest* request,
CleanupAllResponse* response) {
return CallAndWait(&ME::CleanupAllAsync, request, response);
}
Status Logging(const LoggingRequest* request, LoggingResponse* response) {
return CallAndWait(&ME::LoggingAsync, request, response);
}
Status Tracing(const TracingRequest* request, TracingResponse* response) {
return CallAndWait(&ME::TracingAsync, request, response);
}
private:
typedef WorkerInterface ME;
template <typename Method, typename Req, typename Resp>
Status CallAndWait(Method func, const Req* req, Resp* resp) {
Status ret;
Notification n;
(this->*func)(req, resp, [&ret, &n](const Status& s) {
ret = s;
n.Notify();
});
n.WaitForNotification();
return ret;
}
};
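// The blocking wrappers above let callers use the async-only interface
// synchronously. A sketch (assuming `wi` was obtained from a
// WorkerCacheInterface::CreateWorker call):
//
//   GetStatusRequest req;
//   GetStatusResponse resp;
//   Status s = wi->GetStatus(&req, &resp);  // Blocks on a Notification.
//   if (s.ok()) {
//     // resp.device_attributes() describes the remote worker's devices.
//   }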
} // namespace tensorflow
#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_INTERFACE_H_

View File

@ -65,7 +65,7 @@ Status LoadLibrary(const char* library_filename, void** result,
string str;
GetOpList(&str);
char* str_buf = reinterpret_cast<char*>(operator new(str.length()));
strncpy(str_buf, str.data(), str.length());
memcpy(str_buf, str.data(), str.length());
*buf = str_buf;
*len = str.length();

View File

@ -25,37 +25,58 @@ def tf_deps(deps, suffix):
return tf_deps
def tf_proto_library(name, srcs = [], has_services = False,
deps = [], visibility = [], testonly = 0,
cc_api_version = 2, go_api_version = 2,
java_api_version = 2,
py_api_version = 2):
def tf_proto_library_cc(name, srcs = [], has_services = None,
deps = [], visibility = [], testonly = 0,
cc_libs = [],
cc_stubby_versions = None,
cc_grpc_version = None,
cc_api_version = 2, go_api_version = 2,
java_api_version = 2,
py_api_version = 2):
native.filegroup(name=name + "_proto_srcs",
srcs=srcs + tf_deps(deps, "_proto_srcs"),
testonly=testonly,)
use_grpc_plugin = None
if cc_grpc_version:
use_grpc_plugin = True
cc_proto_library(name=name + "_cc",
srcs=srcs + tf_deps(deps, "_proto_srcs"),
deps=deps + ["//google/protobuf:cc_wkt_protos"],
cc_libs = ["//google/protobuf:protobuf"],
cc_libs = cc_libs + ["//google/protobuf:protobuf"],
use_grpc_plugin = use_grpc_plugin,
testonly=testonly,
visibility=visibility,)
py_proto_library(name=name + "_py",
srcs=srcs + tf_deps(deps, "_proto_srcs"),
srcs_version="PY2AND3",
deps=deps + ["//google/protobuf:protobuf_python"],
testonly=testonly,
visibility=visibility,)
def tf_proto_library_py(name, srcs=[], deps=[], visibility=[], testonly=0):
def tf_proto_library_py(name, srcs=[], deps=[], visibility=[], testonly=0,
srcs_version="PY2AND3"):
py_proto_library(name = name + "_py",
srcs = srcs,
srcs_version = "PY2AND3",
srcs_version = srcs_version,
deps = deps,
visibility = visibility,
testonly = testonly)
def tf_proto_library(name, srcs = [], has_services = None,
deps = [], visibility = [], testonly = 0,
cc_libs = [],
cc_api_version = 2, go_api_version = 2,
java_api_version = 2,
py_api_version = 2):
tf_proto_library_cc(name=name,
srcs=srcs + tf_deps(deps, "_proto_srcs"),
deps=deps,
cc_libs=cc_libs,
testonly=testonly,
visibility=visibility,)
tf_proto_library_py(name=name,
srcs=srcs + tf_deps(deps, "_proto_srcs"),
srcs_version="PY2AND3",
deps=deps + ["//google/protobuf:protobuf_python"],
testonly=testonly,
visibility=visibility,)
def tf_additional_lib_srcs():
return [
"platform/default/*.h",

View File

@ -0,0 +1,190 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto3";
package tensorflow;
// option cc_enable_arenas = true;
option java_outer_classname = "DistributedRuntimeProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.distruntime";
import "tensorflow/core/framework/config.proto";
import "tensorflow/core/framework/device_attributes.proto";
import "tensorflow/core/framework/graph.proto";
import "tensorflow/core/framework/tensor.proto";
////////////////////////////////////////////////////////////////////////////////
//
// CreateSession method request/response protos.
//
////////////////////////////////////////////////////////////////////////////////
message CreateSessionRequest {
// The initial graph definition.
GraphDef graph_def = 1;
// Configuration options.
ConfigProto config = 2;
}
message CreateSessionResponse {
// The session handle to be used in subsequent calls for the created session.
//
// The client must arrange to call CloseSession with this returned
// session handle to close the session.
string session_handle = 1;
// The initial version number for the graph, to be used in the next call
// to ExtendSession.
int64 graph_version = 2;
}
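// A C++ sketch of building the request (hedged; `graph_def` is an existing
// GraphDef, and the RPC issuing mechanism is elided):
//
//   CreateSessionRequest req;
//   *req.mutable_graph_def() = graph_def;
//   CreateSessionResponse resp;
//   // ... issue CreateSession(req, &resp) against the master ...
//   const string& handle = resp.session_handle();
//   int64 version = resp.graph_version();  // For the next ExtendSession.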
////////////////////////////////////////////////////////////////////////////////
//
// ExtendSession method request/response protos.
//
// The "graph_def" specifies a set of nodes to be added to the session's graph.
//
// A typical "graph_def" will contain:
//
// * Zero or more new nodes with names that do not exist in the server-side
// graph. These will be added to the graph.
//
// PRECONDITION: The server-side current version is req.current_version.
// None of the names in req.graph_def appeared in previous successful calls to
// CreateSession or ExtendSession with the same session_handle.
// POSTCONDITION: The server-side current version is resp.new_version.
//
////////////////////////////////////////////////////////////////////////////////
message ExtendSessionRequest {
// REQUIRED: session_handle must be returned by a CreateSession call
// to the same master service.
string session_handle = 1;
// REQUIRED: The nodes to be added to the session's graph. If any node has
// the same name as an existing node, the operation will fail with
// ILLEGAL_ARGUMENT.
GraphDef graph_def = 2;
// REQUIRED: The version number of the graph to be extended. This will be
// tested against the current server-side version number, and the operation
// will fail with FAILED_PRECONDITION if they do not match.
int64 current_graph_version = 3;
}
message ExtendSessionResponse {
// TODO(mrry): Return something about the operation?
// The new version number for the extended graph, to be used in the next call
// to ExtendSession.
int64 new_graph_version = 4;
}
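// A C++ sketch of the version handshake (illustrative; `handle` and
// `version` come from the CreateSession sketch above):
//
//   ExtendSessionRequest req;
//   req.set_session_handle(handle);
//   *req.mutable_graph_def() = extension_def;
//   req.set_current_graph_version(version);
//   ExtendSessionResponse resp;
//   // ... on success, remember resp.new_graph_version() for the next call.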
////////////////////////////////////////////////////////////////////////////////
//
// RunStep method request/response protos.
//
// The caller should provide the feeds needed by the graph and specify
// what nodes should be fetched.
//
////////////////////////////////////////////////////////////////////////////////
// A pair of tensor name and tensor values.
message NamedTensorProto {
// Name of the tensor.
string name = 1;
// The client can populate a TensorProto using a tensorflow::Tensor, or
// directly using the protobuf field accessors.
//
// The client specifies whether the returned tensor values should be
// filled tensor fields (float_val, int_val, etc.) or encoded in a
// compact form in tensor.tensor_content.
TensorProto tensor = 2;
}
message RunStepRequest {
// REQUIRED: session_handle must be returned by a CreateSession call
// to the same master service.
string session_handle = 1;
// Tensors to be fed in the step. Each feed is a named tensor.
repeated NamedTensorProto feed = 2;
// Fetches. A list of tensor names. The caller expects a tensor to
// be returned for each fetch[i] (see RunStepResponse.tensor). The
// order of specified fetches does not change the execution order.
repeated string fetch = 3;
// Target Nodes. A list of node names. The named nodes will be run,
// but their outputs will not be fetched.
repeated string target = 4;
}
message RunStepResponse {
// NOTE: The order of the returned tensors may or may not match
// the fetch order specified in RunStepRequest.
repeated NamedTensorProto tensor = 1;
// TODO(mrry): Optionally aggregate StepStats in some form here.
}
////////////////////////////////////////////////////////////////////////////////
//
// CloseSession method request/response protos.
//
////////////////////////////////////////////////////////////////////////////////
message CloseSessionRequest {
// REQUIRED: session_handle must be returned by a CreateSession call
// to the same master service.
string session_handle = 1;
}
message CloseSessionResponse {
}
message ResetRequest {
// A list of container names, which may be empty.
//
// If 'container' is not empty, releases resources in the given
// containers in all devices.
//
// If 'container' is empty, releases resources in the default
// container in all devices.
repeated string container = 1;
}
message ResetResponse {
}
////////////////////////////////////////////////////////////////////////////////
//
// ListDevices method request/response protos.
//
// Returns information about the TensorFlow devices that are available
// to this master.
//
////////////////////////////////////////////////////////////////////////////////
message ListDevicesRequest {
}
message ListDevicesResponse {
repeated DeviceAttributes local_device = 1;
repeated DeviceAttributes remote_device = 2;
}
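
To illustrate the two TensorProto encodings that the `NamedTensorProto.tensor` comment above describes, here is a minimal Python sketch of building a RunStep feed both ways. The generated-module paths, the tensor names, and the session handle are illustrative assumptions, not part of this change:

import numpy as np
from tensorflow.core.framework import tensor_pb2, types_pb2
from tensorflow.core.protobuf import master_pb2  # generated-module path is an assumption

# Field-based encoding: values go into the typed repeated field.
t1 = tensor_pb2.TensorProto(dtype=types_pb2.DT_FLOAT)
t1.tensor_shape.dim.add(size=3)
t1.float_val.extend([1.0, 2.0, 3.0])

# Compact encoding: the same values as raw bytes in tensor_content.
t2 = tensor_pb2.TensorProto(dtype=types_pb2.DT_FLOAT)
t2.tensor_shape.dim.add(size=3)
t2.tensor_content = np.asarray([1.0, 2.0, 3.0], np.float32).tobytes()

# Either form can then be fed in a RunStepRequest.
req = master_pb2.RunStepRequest(session_handle="sess-handle")  # hypothetical handle
req.feed.add(name="x:0", tensor=t1)
req.fetch.append("y:0")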

View File

@ -0,0 +1,105 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto3";
package tensorflow.grpc;
option java_outer_classname = "MasterServiceProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.distruntime";
import "tensorflow/core/protobuf/master.proto";
////////////////////////////////////////////////////////////////////////////////
//
// MasterService defines a TensorFlow service with which a client can
// interact to execute a distributed TensorFlow computation.
//
// A master service keeps track of multiple "master sessions". Each
// session encapsulates a computation graph and its associated state,
// and typically corresponds to a single "client session" (e.g. a
// `tensorflow::Session` instance).
//
// A session is responsible for the following:
// * assigning each node to a device (locally or remotely) using a
// placement algorithm. This may make decisions based on collected
// statistics from the workers in the system (e.g., memory usage,
// bandwidth consumption, etc.)
//
// * inserting intermediate nodes and edges to support cross-device
// and cross-process data flows and resource management.
//
// * issuing commands to workers to execute the subgraphs associated
// with those workers.
//
// Typically, a client carries out an iterative computation
// (e.g. training) by invoking RPCs against the master in a
// client-side loop. The client first creates a client session that
// connects to a particular master (using gRPC for example). The
// master creates a corresponding master session that is hosted on
// the master and caches state between the client's invocations.
//
// After the session is established, the master returns an opaque
// handle to the client that can be used to associate the client and
// master sessions.
//
// The client may send an initial graph to the master in the
// CreateSession call, and add nodes to the graph using ExtendSession.
//
// The most frequent operation on a master is "RunStep", which implements
// the `Session::Run()` API. It supports feeding in arguments,
// executing a dataflow computation, and fetching results.
//
// Finally, when the client no longer needs the session, it should
// close the session by invoking CloseSession, which allows the master
// to reclaim resources associated with the session. The master may
// implement a garbage collection scheme that closes sessions that
// have been inactive for some time.
//
// For example, the following pseudo-code illustrates how a client
// interacts with a master:
//
// stub = NewStub("/job:mnist/replica:0/task:0")
// {handle} = stub->CreateSession({graph_def})
// do {
// stub->RunStep({handle, {feeds}, {fetches}})
// // The client can evaluate a predicate locally, based on the
// // result of `fetches`, to determine whether to terminate. For
// // example, it might fetch the loss and evaluate whether it is less
// // than some threshold.
// } while (!should_stop({fetches}));
// stub->CloseSession({handle})
//
////////////////////////////////////////////////////////////////////////////////
service MasterService {
// Creates a session.
rpc CreateSession(CreateSessionRequest) returns (CreateSessionResponse);
// Extends a session.
rpc ExtendSession(ExtendSessionRequest) returns (ExtendSessionResponse);
// Drives the graph computation.
rpc RunStep(RunStepRequest) returns (RunStepResponse);
// Closes a session.
rpc CloseSession(CloseSessionRequest) returns (CloseSessionResponse);
// Lists the devices usable by the master.
rpc ListDevices(ListDevicesRequest) returns (ListDevicesResponse);
// Closes all existing sessions.
rpc Reset(ResetRequest) returns (ResetResponse);
}
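
For comparison with the pseudo-code above, a Python client exercises the same CreateSession/RunStep/CloseSession sequence implicitly by pointing a session at a gRPC target, which routes `Session::Run()` through this MasterService. A minimal sketch against the TF API of this era; the server address and the stopping threshold are illustrative assumptions:

import tensorflow as tf

w = tf.Variable(0.0)
loss = tf.square(w - 3.0)
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

# Opening the session issues CreateSession with the graph built above.
with tf.Session("grpc://localhost:2222") as sess:  # address is an assumption
  sess.run(tf.initialize_all_variables())
  while True:
    _, loss_val = sess.run([train_op, loss])  # each call is one RunStep
    if loss_val < 1e-4:  # client-side stopping predicate, as in the pseudo-code
      break
# Exiting the `with` block issues CloseSession for the handle.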

View File

@ -0,0 +1,311 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto3";
package tensorflow;
// option cc_enable_arenas = true;
option java_outer_classname = "WorkerProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.distruntime";
import "google/protobuf/any.proto";
import "tensorflow/core/framework/config.proto";
import "tensorflow/core/framework/step_stats.proto";
import "tensorflow/core/framework/device_attributes.proto";
import "tensorflow/core/framework/graph.proto";
import "tensorflow/core/framework/tensor.proto";
////////////////////////////////////////////////////////////////////////////////
//
// GetStatus method request/response messages
//
////////////////////////////////////////////////////////////////////////////////
message GetStatusRequest {
}
message GetStatusResponse {
repeated DeviceAttributes device_attributes = 1;
}
////////////////////////////////////////////////////////////////////////////////
//
// RegisterGraph method request/response messages
//
// For each session, after the master has placed every node on a device,
// it partitions the whole graph into many subgraphs. All the nodes in
// a subgraph are on the same worker, but potentially on many devices
// owned by that worker (e.g. cpu0, plus gpu0, gpu1, ..., gpu7). The
// master registers subgraphs for a worker before running any steps. A
// successful registration returns a graph handle to be used in later
// RunGraph requests.
//
////////////////////////////////////////////////////////////////////////////////
message RegisterGraphRequest {
// Subgraphs are scoped within one session.
string session_handle = 1;
// "graph_def" has the subgraph of nodes for this worker, with each node
// having its device_name filled in.
GraphDef graph_def = 2;
// True iff the graph (before partitioning) contains control flow nodes.
//
// As of 01/11/2015, this is no longer set by clients.
bool has_control_flow = 3 [deprecated = true];
// Configuration options for the session in which this graph was created.
GraphOptions graph_options = 4;
}
message RegisterGraphResponse {
// If the registration succeeds, returns an opaque graph_handle to
// the master. The master calls RunGraph with graph_handle to
// compute different steps.
string graph_handle = 1;
}
////////////////////////////////////////////////////////////////////////////////
//
// DeregisterGraph method request/response messages
//
// The master deregisters the given graph_handle when the graph is no
// longer needed (e.g., the overall graph is re-scheduled and nodes
// are re-placed).
//
// The worker deregisters a graph_handle automatically according to
// a TTL-based policy in case of master restarts.
//
////////////////////////////////////////////////////////////////////////////////
message DeregisterGraphRequest {
// REQUIRED: graph_handle must be returned by a RegisterGraph call
// to the same WorkerService.
string graph_handle = 1;
}
message DeregisterGraphResponse {
// TODO(mrry): Optionally add summary stats for the graph.
}
////////////////////////////////////////////////////////////////////////////////
//
// CleanupAll method request/response messages
//
////////////////////////////////////////////////////////////////////////////////
message CleanupAllRequest {
// A list of container names.
//
// If 'container' is not empty, releases resources in the given
// containers in all devices.
//
// If 'container' is empty, releases resources in the default
// container in all devices.
repeated string container = 1;
}
message CleanupAllResponse {
}
////////////////////////////////////////////////////////////////////////////////
//
// RunGraph request / response messages
//
// The worker executes all subgraphs registered under graph_handle.
// RunGraph returns after the execution finishes or an error is
// encountered.
//
////////////////////////////////////////////////////////////////////////////////
// A pair of tensor name and tensor values.
message NamedTensor {
// The name of the named tensor.
string key = 1;
// The value of the named tensor.
TensorProto val = 2;
}
// Options specific to the execution of a single step.
message ExecutorOpts {
bool record_costs = 1;
bool record_timeline = 3;
};
message RunGraphRequest {
// REQUIRED: graph_handle must be returned by a RegisterGraph call
// to the same WorkerService.
string graph_handle = 1;
// A unique ID to distinguish different runs of the same graph.
//
// The master generates a globally unique `step_id` to distinguish
// different runs of the graph computation. Subgraphs communicate
// (e.g., send/recv ops) with each other using `step_id` to
// distinguish tensors generated by different runs.
int64 step_id = 2;
// Options for this step.
ExecutorOpts exec_opts = 5;
// Runs the graph.
//
// Sends the tensors in "send" into the graph before the run, and
// fetches the tensors named by `recv_key` into `RunGraphResponse.recv`
// after the run.
repeated NamedTensor send = 3;
repeated string recv_key = 4;
}
message RunGraphResponse {
// A list of tensors corresponding to those requested by
// `RunGraphRequest.recv_key`.
repeated NamedTensor recv = 1;
// If the request asked for execution stats, these are returned here.
StepStats step_stats = 2;
}
////////////////////////////////////////////////////////////////////////////////
//
// CleanupGraph method request/response messages
//
// After the master receives RunGraph responses from all workers, the
// master instructs every worker to clean up any remaining state of a
// step (e.g. tensors buffered by a `Send` op but not picked up by
// other workers). The master does not necessarily need to wait for
// completion of CleanupGraph calls.
//
// Workers should clean up step states automatically according to a
// TTL-based policy in case of master restarts.
//
////////////////////////////////////////////////////////////////////////////////
message CleanupGraphRequest {
int64 step_id = 1;
}
message CleanupGraphResponse {
}
////////////////////////////////////////////////////////////////////////////////
//
// RecvTensor method request/response messages
//
////////////////////////////////////////////////////////////////////////////////
message RecvTensorRequest {
// The step in which the tensor will be produced.
//
// REQUIRED: This must eventually correspond to the `step_id` passed
// into a RunGraph call on the same WorkerService.
int64 step_id = 1;
// A key that identifies the tensor to be received.
string rendezvous_key = 2;
// If true, use an out-of-band DMA mechanism to transfer the
// received tensor.
bool dma_ok = 3;
// NIC bus preference on the request originator side
BusAdjacency client_bus_adjacency = 4;
// NIC bus preference on the request receiver side
BusAdjacency server_bus_adjacency = 5;
}
message RecvTensorResponse {
// The tensor as a proto.
TensorProto tensor = 1;
// If true, this tensor was the output of a dead node, and the
// content is invalid.
bool is_dead = 2;
// The time at which the tensor was available and started to be returned.
int64 send_start_micros = 3;
// Optional additional information about how to receive the tensor,
// in the event that `RecvTensorRequest.dma_ok` was true.
google.protobuf.Any transport_options = 4;
}
////////////////////////////////////////////////////////////////////////////////
//
// Logging method request/response messages
//
// NOTE(mrry): This feature is not supported in the open-source
// version, and these messages are expected to change.
//
////////////////////////////////////////////////////////////////////////////////
// Out-of-band request to begin or end logging, or
// to retrieve logs for particular steps.
message LoggingRequest {
// If true, RPC logging will be activated.
bool rpc_logging = 1;
// If true, discard any saved logging data (for all steps).
bool clear = 2;
// When set, requests all saved log data pertaining to the given steps.
// Any log data retrieved is eliminated from the store and cannot be
// retrieved again.
repeated int64 fetch_step_id = 3;
}
message LabeledStepStats {
int64 step_id = 1;
StepStats step_stats = 2;
}
message LoggingResponse {
repeated LabeledStepStats step = 1;
}
////////////////////////////////////////////////////////////////////////////////
//
// Tracing method request/response messages
//
// NOTE(mrry): This feature is not supported in the open-source
// version, and these messages are expected to change.
//
////////////////////////////////////////////////////////////////////////////////
message TraceOpts {
// Length of the trace to be taken, in seconds.
double duration = 1;
// If true, capture step profile locally in each worker. Currently
// unimplemented.
bool use_step_profiler = 2;
// If true, capture kernel events from each worker.
bool use_kernel_profiler = 3;
// If true, capture extended profiling events from TensorFlow process.
bool use_extended_profiler = 4;
// If true, capture GPU profiling events locally on each
// machine. Currently unimplemented.
bool use_gpu_profiler = 5;
// If true, collect sampled profile events. Currently unimplemented.
bool use_sample_profiler = 6;
}
// Out-of-band request to configure distributed tracing.
message TracingRequest {
TraceOpts options = 1;
}
message TracingResponse {
}
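
To make the send/recv pairing in RunGraph concrete, the sketch below assembles a RunGraphRequest roughly the way a master would when dispatching one step to a worker. The generated-module path, the handle, and the key strings are illustrative assumptions; real rendezvous keys are produced by the runtime:

from tensorflow.core.framework import tensor_pb2, types_pb2
from tensorflow.core.protobuf import worker_pb2  # generated-module path is an assumption

feed = tensor_pb2.TensorProto(dtype=types_pb2.DT_FLOAT, float_val=[1.0])

req = worker_pb2.RunGraphRequest(
    graph_handle="g-123",  # hypothetical handle from a RegisterGraphResponse
    step_id=7)             # globally unique per run of the graph
req.exec_opts.record_costs = True
# Tensors pushed into the subgraph before the run...
req.send.add(key="feed_x", val=feed)
# ...and the keys whose tensors come back in RunGraphResponse.recv.
req.recv_key.append("fetch_y")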

View File

@ -0,0 +1,67 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto3";
package tensorflow.grpc;
option java_outer_classname = "WorkerServiceProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.distruntime";
import "tensorflow/core/protobuf/worker.proto";
////////////////////////////////////////////////////////////////////////////////
//
// WorkerService defines a TensorFlow service that executes dataflow
// graphs on a set of local devices, on behalf of a MasterService.
//
// A worker service keeps track of multiple "registered graphs". Each
// registered graph is a subgraph of a client's graph, corresponding to
// only the nodes that should execute on this worker (and any
// additional nodes necessary for inter-process communication using
// the `RecvTensor` method).
//
////////////////////////////////////////////////////////////////////////////////
service WorkerService {
// See worker.proto for details.
rpc GetStatus(GetStatusRequest) returns (GetStatusResponse);
// See worker.proto for details.
rpc RegisterGraph(RegisterGraphRequest) returns (RegisterGraphResponse);
// See worker.proto for details.
rpc DeregisterGraph(DeregisterGraphRequest) returns (DeregisterGraphResponse);
// See worker.proto for details.
rpc RunGraph(RunGraphRequest) returns (RunGraphResponse);
// See worker.proto for details.
rpc CleanupGraph(CleanupGraphRequest) returns (CleanupGraphResponse);
// See worker.proto for details.
rpc CleanupAll(CleanupAllRequest) returns (CleanupAllResponse);
// See worker.proto for details.
rpc RecvTensor(RecvTensorRequest) returns (RecvTensorResponse) {
// RecvTensor Method
}
// See worker.proto for details.
rpc Logging(LoggingRequest) returns (LoggingResponse);
// See worker.proto for details.
rpc Tracing(TracingRequest) returns (TracingResponse);
}

View File

@ -920,6 +920,7 @@ tf_py_wrap_cc(
":py_record_writer_lib",
":python_op_gen",
":tf_session_helper",
"//tensorflow/core/distributed_runtime/rpc:grpc_session",
"//util/python:python_headers",
],
)

View File

@ -750,8 +750,9 @@ class SessionTest(test_util.TensorFlowTestCase):
self.assertEqual(c_list[1], out[1].decode('utf-8'))
def testInvalidTargetFails(self):
with self.assertRaisesRegexp(RuntimeError,
'Registered factories are {DIRECT_SESSION}'):
with self.assertRaisesRegexp(
RuntimeError,
'No session factory registered for the given session options.'):
session.Session('INVALID_TARGET')
def testFetchByNameDifferentStringTypes(self):

View File

@ -135,7 +135,7 @@ have varying scale, and to aid generalization.
@@l2_normalize
@@local_response_normalization
@@sufficient_statistics
@@aggregate_moments
@@normalize_moments
@@moments
## Losses
@ -561,7 +561,7 @@ def sufficient_statistics(x, axes, shift=True, keep_dims=False, name=None):
return counts, m_ss, v_ss, shift_value
def aggregate_moments(counts, mean_ss, variance_ss, shift, name=None):
def normalize_moments(counts, mean_ss, variance_ss, shift, name=None):
"""Calculate the mean and variance of based on the sufficient statistics.
Args:
@ -577,7 +577,7 @@ def aggregate_moments(counts, mean_ss, variance_ss, shift, name=None):
Returns:
Two `Tensor` objects: `mean` and `variance`.
"""
with ops.op_scope([counts, mean_ss, variance_ss, shift], name, "aggregate"):
with ops.op_scope([counts, mean_ss, variance_ss, shift], name, "normalize"):
divisor = math_ops.inv(counts, name="divisor")
if shift is not None:
shifted_mean = math_ops.mul(mean_ss, divisor, name="shifted_mean")
@ -620,7 +620,7 @@ def moments(x, axes, name=None, keep_dims=False):
axes,
keep_dims=keep_dims,
name=name)
return aggregate_moments(counts, m_ss, v_ss, shift, name=name)
return normalize_moments(counts, m_ss, v_ss, shift, name=name)
def batch_normalization(x,
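
Since this file renames `aggregate_moments` to `normalize_moments`, a short example of the renamed two-stage API may help; it computes the same result as `tf.nn.moments`. A minimal sketch against the TF API of this era, with illustrative values:

import tensorflow as tf

x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
counts, mean_ss, var_ss, shift = tf.nn.sufficient_statistics(x, axes=[0])
mean, variance = tf.nn.normalize_moments(counts, mean_ss, var_ss, shift)

with tf.Session() as sess:
  # Matches tf.nn.moments(x, [0]) up to floating-point error.
  print(sess.run([mean, variance]))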

View File

@ -826,19 +826,19 @@ class SufficientStatisticsTest(tf.test.TestCase):
self._testSuffStats([1, 2, 3], [0, 2], shift, keep_dims, has_shape)
class AggregateMomentsTest(tf.test.TestCase):
class NormalizeMomentsTest(tf.test.TestCase):
def _npAggregateMoments(self, counts, mean_ss, variance_ss, shift):
def _npNormalizeMoments(self, counts, mean_ss, variance_ss, shift):
mean = mean_ss / counts
variance = variance_ss / counts - mean * mean
if shift is not None:
mean += shift
return mean, variance
def _opAggregateMoments(self, counts, mean_ss, variance_ss, shift):
return tf.nn.aggregate_moments(counts, mean_ss, variance_ss, shift)
def _opNormalizeMoments(self, counts, mean_ss, variance_ss, shift):
return tf.nn.normalize_moments(counts, mean_ss, variance_ss, shift)
def _testAggregateMoments(self, shape, shift):
def _testNormalizeMoments(self, shape, shift):
counts = np.ones([1]).astype(np.float32)
mean_ss = np.random.random_sample(shape).astype(np.float32)
variance_ss = np.random.random_sample(shape).astype(np.float32)
@ -847,7 +847,7 @@ class AggregateMomentsTest(tf.test.TestCase):
shift_v = np.random.random_sample(shape).astype(np.float32)
else:
shift_v = None
npm, npv = self._npAggregateMoments(counts, mean_ss, variance_ss, shift_v)
npm, npv = self._npNormalizeMoments(counts, mean_ss, variance_ss, shift_v)
for use_gpu in [True, False]:
with self.test_session(use_gpu=use_gpu) as sess:
tf_counts = tf.constant(counts, name="counts")
@ -857,16 +857,16 @@ class AggregateMomentsTest(tf.test.TestCase):
tf_shift_v = tf.constant(shift_v, name="shift")
else:
tf_shift_v = None
opm, opv = self._opAggregateMoments(tf_counts, tf_mean_ss,
opm, opv = self._opNormalizeMoments(tf_counts, tf_mean_ss,
tf_variance_ss, tf_shift_v)
tfm, tfv = sess.run([opm, opv])
self.assertAllClose(npm, tfm, atol=0.000001)
self.assertAllClose(npv, tfv, atol=0.000001)
def testAggregateMoments(self):
def testNormalizeMoments(self):
for shift in [True, False]:
self._testAggregateMoments([3], shift)
self._testAggregateMoments([2, 3], shift)
self._testNormalizeMoments([3], shift)
self._testNormalizeMoments([2, 3], shift)
class MomentsTest(tf.test.TestCase):
@ -971,15 +971,15 @@ class MomentsTest(tf.test.TestCase):
"""Make sure the output names are stable."""
with self.test_session():
mean, var = tf.nn.moments(tf.constant([1]), [0], keep_dims=False)
self.assertEquals(mean.op.name, "moments/aggregate/mean")
self.assertEquals(var.op.name, "moments/aggregate/variance")
self.assertEquals(mean.op.name, "moments/normalize/mean")
self.assertEquals(var.op.name, "moments/normalize/variance")
def testOutputNamesKeep(self):
"""Make sure the output names are stable."""
with self.test_session():
mean, var = tf.nn.moments(tf.constant([1]), [0], keep_dims=True)
self.assertEquals(mean.op.name, "moments/aggregate/mean")
self.assertEquals(var.op.name, "moments/aggregate/variance")
self.assertEquals(mean.op.name, "moments/normalize/mean")
self.assertEquals(var.op.name, "moments/normalize/variance")
class ComputeSampledLogitsTest(tf.test.TestCase):