From 00986d48bb646daab659503ad3a713919865f32d Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Thu, 25 Feb 2016 20:10:09 -0800 Subject: [PATCH] Initial version of the open-source distributed TensorFlow runtime. This includes a gRPC server (grpc_tensorflow_server) that can serve as both the master of a distributed TensorFlow computation, and an individual worker in the computation. The GrpcSession class is included to allow client programs (including Python clients) to interact with a server. See tensorflow/core/distributed_runtime/README.md for usage instructions. This change partially addresses issue #23. Change: 115634191 --- WORKSPACE | 31 + tensorflow/core/BUILD | 59 +- tensorflow/core/distributed_runtime/BUILD | 306 ++++++ tensorflow/core/distributed_runtime/README.md | 197 ++++ .../base_rendezvous_mgr.cc | 318 ++++++ .../distributed_runtime/base_rendezvous_mgr.h | 212 ++++ .../build_graph_options.cc | 38 + .../distributed_runtime/build_graph_options.h | 38 + .../core/distributed_runtime/call_options.cc | 44 + .../core/distributed_runtime/call_options.h | 72 ++ .../distributed_runtime/call_options_test.cc | 39 + .../core/distributed_runtime/executor_test.cc | 407 ++++++++ .../core/distributed_runtime/graph_mgr.cc | 368 +++++++ .../core/distributed_runtime/graph_mgr.h | 147 +++ tensorflow/core/distributed_runtime/master.cc | 413 ++++++++ tensorflow/core/distributed_runtime/master.h | 98 ++ .../core/distributed_runtime/master_env.h | 66 ++ .../distributed_runtime/master_interface.h | 52 + .../distributed_runtime/master_session.cc | 942 ++++++++++++++++++ .../core/distributed_runtime/master_session.h | 38 + .../master_session_interface.h | 76 ++ .../core/distributed_runtime/master_test.cc | 423 ++++++++ .../core/distributed_runtime/process_util.cc | 69 ++ .../core/distributed_runtime/process_util.h | 39 + .../core/distributed_runtime/remote_device.cc | 91 ++ .../core/distributed_runtime/remote_device.h | 48 + .../distributed_runtime/remote_device_test.cc | 89 ++ 
.../rendezvous_mgr_interface.h | 79 ++ tensorflow/core/distributed_runtime/rpc/BUILD | 341 +++++++ .../rpc/async_service_interface.h | 37 + .../core/distributed_runtime/rpc/grpc_call.h | 227 +++++ .../distributed_runtime/rpc/grpc_channel.cc | 314 ++++++ .../distributed_runtime/rpc/grpc_channel.h | 98 ++ .../rpc/grpc_channel_test.cc | 137 +++ .../rpc/grpc_client_cq_tag.h | 56 ++ .../rpc/grpc_master_service.cc | 181 ++++ .../rpc/grpc_master_service.h | 33 + .../rpc/grpc_remote_master.cc | 79 ++ .../rpc/grpc_remote_master.h | 27 + .../rpc/grpc_remote_worker.cc | 203 ++++ .../rpc/grpc_remote_worker.h | 38 + .../rpc/grpc_server_lib.cc | 116 +++ .../distributed_runtime/rpc/grpc_server_lib.h | 53 + .../distributed_runtime/rpc/grpc_session.cc | 233 +++++ .../distributed_runtime/rpc/grpc_session.h | 97 ++ .../rpc/grpc_session_test.cc | 750 ++++++++++++++ .../rpc/grpc_tensorflow_server.cc | 98 ++ .../rpc/grpc_tensorflow_server_lib.cc | 123 +++ .../distributed_runtime/rpc/grpc_testlib.cc | 84 ++ .../distributed_runtime/rpc/grpc_testlib.h | 73 ++ .../rpc/grpc_testlib_ops.cc | 91 ++ .../rpc/grpc_testlib_server.cc | 92 ++ .../core/distributed_runtime/rpc/grpc_util.h | 48 + .../rpc/grpc_worker_cache.cc | 85 ++ .../rpc/grpc_worker_cache.h | 28 + .../rpc/grpc_worker_service.cc | 415 ++++++++ .../rpc/grpc_worker_service.h | 34 + .../rpc/rpc_rendezvous_mgr.cc | 196 ++++ .../rpc/rpc_rendezvous_mgr.h | 57 ++ .../rpc/rpc_rendezvous_mgr_test.cc | 172 ++++ .../simple_graph_execution_state.cc | 309 ++++++ .../simple_graph_execution_state.h | 156 +++ .../core/distributed_runtime/worker_cache.h | 75 ++ .../worker_cache_logger.cc | 110 ++ .../distributed_runtime/worker_cache_logger.h | 81 ++ .../worker_cache_partial.cc | 98 ++ .../worker_cache_partial.h | 56 ++ .../core/distributed_runtime/worker_env.h | 62 ++ .../distributed_runtime/worker_interface.h | 129 +++ tensorflow/core/framework/load_library.cc | 2 +- .../core/platform/default/build_config.bzl | 51 +- 
tensorflow/core/protobuf/master.proto | 190 ++++ tensorflow/core/protobuf/master_service.proto | 105 ++ tensorflow/core/protobuf/worker.proto | 311 ++++++ tensorflow/core/protobuf/worker_service.proto | 67 ++ tensorflow/python/BUILD | 1 + tensorflow/python/client/session_test.py | 5 +- tensorflow/python/ops/nn.py | 8 +- tensorflow/python/ops/nn_test.py | 28 +- 79 files changed, 11222 insertions(+), 37 deletions(-) create mode 100644 tensorflow/core/distributed_runtime/BUILD create mode 100644 tensorflow/core/distributed_runtime/README.md create mode 100644 tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc create mode 100644 tensorflow/core/distributed_runtime/base_rendezvous_mgr.h create mode 100644 tensorflow/core/distributed_runtime/build_graph_options.cc create mode 100644 tensorflow/core/distributed_runtime/build_graph_options.h create mode 100644 tensorflow/core/distributed_runtime/call_options.cc create mode 100644 tensorflow/core/distributed_runtime/call_options.h create mode 100644 tensorflow/core/distributed_runtime/call_options_test.cc create mode 100644 tensorflow/core/distributed_runtime/executor_test.cc create mode 100644 tensorflow/core/distributed_runtime/graph_mgr.cc create mode 100644 tensorflow/core/distributed_runtime/graph_mgr.h create mode 100644 tensorflow/core/distributed_runtime/master.cc create mode 100644 tensorflow/core/distributed_runtime/master.h create mode 100644 tensorflow/core/distributed_runtime/master_env.h create mode 100644 tensorflow/core/distributed_runtime/master_interface.h create mode 100644 tensorflow/core/distributed_runtime/master_session.cc create mode 100644 tensorflow/core/distributed_runtime/master_session.h create mode 100644 tensorflow/core/distributed_runtime/master_session_interface.h create mode 100644 tensorflow/core/distributed_runtime/master_test.cc create mode 100644 tensorflow/core/distributed_runtime/process_util.cc create mode 100644 tensorflow/core/distributed_runtime/process_util.h create mode 
100644 tensorflow/core/distributed_runtime/remote_device.cc create mode 100644 tensorflow/core/distributed_runtime/remote_device.h create mode 100644 tensorflow/core/distributed_runtime/remote_device_test.cc create mode 100644 tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h create mode 100644 tensorflow/core/distributed_runtime/rpc/BUILD create mode 100644 tensorflow/core/distributed_runtime/rpc/async_service_interface.h create mode 100644 tensorflow/core/distributed_runtime/rpc/grpc_call.h create mode 100644 tensorflow/core/distributed_runtime/rpc/grpc_channel.cc create mode 100644 tensorflow/core/distributed_runtime/rpc/grpc_channel.h create mode 100644 tensorflow/core/distributed_runtime/rpc/grpc_channel_test.cc create mode 100644 tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h create mode 100644 tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc create mode 100644 tensorflow/core/distributed_runtime/rpc/grpc_master_service.h create mode 100644 tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc create mode 100644 tensorflow/core/distributed_runtime/rpc/grpc_remote_master.h create mode 100644 tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc create mode 100644 tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h create mode 100644 tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc create mode 100644 tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h create mode 100644 tensorflow/core/distributed_runtime/rpc/grpc_session.cc create mode 100644 tensorflow/core/distributed_runtime/rpc/grpc_session.h create mode 100644 tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc create mode 100644 tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc create mode 100644 tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server_lib.cc create mode 100644 tensorflow/core/distributed_runtime/rpc/grpc_testlib.cc create mode 100644 
tensorflow/core/distributed_runtime/rpc/grpc_testlib.h create mode 100644 tensorflow/core/distributed_runtime/rpc/grpc_testlib_ops.cc create mode 100644 tensorflow/core/distributed_runtime/rpc/grpc_testlib_server.cc create mode 100644 tensorflow/core/distributed_runtime/rpc/grpc_util.h create mode 100644 tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc create mode 100644 tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h create mode 100644 tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc create mode 100644 tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h create mode 100644 tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc create mode 100644 tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h create mode 100644 tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc create mode 100644 tensorflow/core/distributed_runtime/simple_graph_execution_state.cc create mode 100644 tensorflow/core/distributed_runtime/simple_graph_execution_state.h create mode 100644 tensorflow/core/distributed_runtime/worker_cache.h create mode 100644 tensorflow/core/distributed_runtime/worker_cache_logger.cc create mode 100644 tensorflow/core/distributed_runtime/worker_cache_logger.h create mode 100644 tensorflow/core/distributed_runtime/worker_cache_partial.cc create mode 100644 tensorflow/core/distributed_runtime/worker_cache_partial.h create mode 100644 tensorflow/core/distributed_runtime/worker_env.h create mode 100644 tensorflow/core/distributed_runtime/worker_interface.h create mode 100644 tensorflow/core/protobuf/master.proto create mode 100644 tensorflow/core/protobuf/master_service.proto create mode 100644 tensorflow/core/protobuf/worker.proto create mode 100644 tensorflow/core/protobuf/worker_service.proto diff --git a/WORKSPACE b/WORKSPACE index 2e1b018e14f..26bfa1f15f5 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -15,6 +15,37 @@ load("//tensorflow:workspace.bzl", "tf_workspace") tf_workspace() +# grpc 
expects //external:protobuf_clib and //external:protobuf_compiler +# to point to the protobuf's compiler library. +bind( + name = "protobuf_clib", + actual = "//google/protobuf:protoc_lib", +) + +bind( + name = "protobuf_compiler", + actual = "//google/protobuf:protoc_lib", +) + +git_repository( + name = "grpc", + commit = "73979f4", + init_submodules = True, + remote = "https://github.com/grpc/grpc.git", +) + +# protobuf expects //external:grpc_cpp_plugin to point to grpc's +# C++ plugin code generator. +bind( + name = "grpc_cpp_plugin", + actual = "@grpc//:grpc_cpp_plugin", +) + +bind( + name = "grpc_lib", + actual = "@grpc//:grpc++_unsecure", +) + # TENSORBOARD_BOWER_AUTOGENERATED_BELOW_THIS_LINE_DO_NOT_EDIT new_git_repository( diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 649b21edcb4..54c32706408 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -61,6 +61,7 @@ load("//tensorflow:tensorflow.bzl", "tf_gpu_kernel_library") load( "//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library", + "tf_proto_library_cc", "tf_additional_lib_srcs", "tf_additional_stream_executor_srcs", "tf_additional_test_deps", @@ -77,7 +78,15 @@ load( tf_proto_library( name = "protos_all", - srcs = glob(["**/*.proto"]), + srcs = glob( + ["**/*.proto"], + exclude = [ + "protobuf/worker.proto", + "protobuf/worker_service.proto", + "protobuf/master.proto", + "protobuf/master_service.proto", + ], + ), cc_api_version = 2, go_api_version = 2, java_api_version = 2, @@ -85,6 +94,54 @@ tf_proto_library( visibility = ["//visibility:public"], ) +tf_proto_library_cc( + name = "worker_proto", + srcs = ["protobuf/worker.proto"], + cc_api_version = 2, + cc_libs = [":protos_all_cc"], + visibility = [ + "//tensorflow:internal", + ], +) + +tf_proto_library_cc( + name = "worker_service_proto", + srcs = ["protobuf/worker_service.proto"], + has_services = 1, + cc_api_version = 2, + cc_grpc_version = 1, + cc_libs = [":worker_proto_cc"], + cc_stubby_versions = 
["2"], + visibility = [ + "//tensorflow:internal", + ], +) + +tf_proto_library_cc( + name = "master_proto", + srcs = ["protobuf/master.proto"], + cc_api_version = 2, + cc_libs = [":protos_all_cc"], + py_api_version = 2, + visibility = [ + "//tensorflow:internal", + ], +) + +tf_proto_library_cc( + name = "master_service_proto", + srcs = ["protobuf/master_service.proto"], + has_services = 1, + cc_api_version = 2, + cc_grpc_version = 1, + cc_libs = [":master_proto_cc"], + cc_stubby_versions = ["2"], + py_api_version = 2, + visibility = [ + "//tensorflow:internal", + ], +) + cc_library( name = "lib", hdrs = [ diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD new file mode 100644 index 00000000000..00d97a6ef90 --- /dev/null +++ b/tensorflow/core/distributed_runtime/BUILD @@ -0,0 +1,306 @@ +# Description: +# A distributed runtime for TensorFlow, which allows graph execution +# to be distributed and performed in parallel across multiple +# processes. 
+ +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +filegroup( + name = "c_srcs", + data = glob([ + "**/*.cc", + "**/*.h", + ]), +) + +load("//tensorflow:tensorflow.bzl", "tf_cc_tests") + +# For platform specific build config +load( + "//tensorflow/core:platform/default/build_config.bzl", + "tf_kernel_tests_linkstatic", +) +load( + "//tensorflow/core:platform/default/build_config_root.bzl", + "tf_cuda_tests_tags", +) + +package(default_visibility = [ + "//tensorflow:internal", +]) + +cc_library( + name = "worker_env", + hdrs = ["worker_env.h"], + deps = [], +) + +cc_library( + name = "worker_interface", + hdrs = ["worker_interface.h"], + deps = [ + ":call_options", + "//tensorflow/core:lib", + "//tensorflow/core:worker_proto_cc", + ], +) + +cc_library( + name = "call_options", + srcs = ["call_options.cc"], + hdrs = ["call_options.h"], + deps = [ + "//tensorflow/core:lib", + ], +) + +cc_test( + name = "call_options_test", + size = "small", + srcs = ["call_options_test.cc"], + deps = [ + ":call_options", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "worker_cache", + hdrs = ["worker_cache.h"], + deps = ["//tensorflow/core:protos_all_cc"], +) + +cc_library( + name = "remote_device", + srcs = ["remote_device.cc"], + hdrs = ["remote_device.h"], + deps = [ + ":process_util", + ":worker_cache", + ":worker_interface", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:lib", + "//tensorflow/core:tensorflow_opensource", + "//tensorflow/core:worker_proto_cc", + ], +) + +cc_library( + name = "master_interface", + hdrs = ["master_interface.h"], + deps = [ + "//tensorflow/core:lib", + "//tensorflow/core:master_proto_cc", + ], +) + +cc_library( + name = "master", + srcs = ["master.cc"], + hdrs = ["master.h"], + deps = [ + 
":call_options", + ":master_env", + ":master_session_interface", + ":process_util", + ":remote_device", + ":worker_cache", + ":worker_interface", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:master_proto_cc", + "//tensorflow/core:tensorflow_opensource", + "//tensorflow/core:worker_proto_cc", + ], +) + +cc_library( + name = "master_session", + srcs = ["master_session.cc"], + hdrs = ["master_session.h"], + deps = [ + ":master_env", + ":master_session_interface", + ":process_util", + ":simple_graph_execution_state", + ":worker_cache", + ":worker_interface", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:master_proto_cc", + "//tensorflow/core:protos_all_cc", + ], +) + +cc_library( + name = "build_graph_options", + srcs = ["build_graph_options.cc"], + hdrs = ["build_graph_options.h"], + deps = [ + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "simple_graph_execution_state", + srcs = ["simple_graph_execution_state.cc"], + hdrs = ["simple_graph_execution_state.h"], + deps = [ + ":build_graph_options", + ":process_util", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:worker_proto_cc", + ], +) + +cc_library( + name = "rendezvous_mgr_interface", + srcs = [], + hdrs = ["rendezvous_mgr_interface.h"], + deps = [ + "//tensorflow/core:framework", + ], +) + +cc_library( + name = "master_session_interface", + srcs = [], + hdrs = ["master_session_interface.h"], + deps = ["//tensorflow/core:lib"], +) + +cc_library( + name = "base_rendezvous_mgr", + srcs = ["base_rendezvous_mgr.cc"], + hdrs = ["base_rendezvous_mgr.h"], + deps = [ + ":process_util", + ":rendezvous_mgr_interface", + ":worker_cache", + ":worker_env", + ":worker_interface", + 
"//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:tensorflow_opensource", + ], +) + +cc_library( + name = "master_env", + hdrs = ["master_env.h"], +) + +cc_library( + name = "graph_mgr", + srcs = ["graph_mgr.cc"], + hdrs = ["graph_mgr.h"], + deps = [ + ":process_util", + ":rendezvous_mgr_interface", + ":worker_env", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensorflow_opensource", + "//tensorflow/core:worker_proto_cc", + ], +) + +cc_library( + name = "process_util", + srcs = ["process_util.cc"], + hdrs = ["process_util.h"], + deps = [ + "//tensorflow/core:core_cpu", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:tensorflow_opensource", + ], +) + +cc_library( + name = "worker_cache_partial", + srcs = ["worker_cache_partial.cc"], + hdrs = ["worker_cache_partial.h"], + deps = [ + ":process_util", + ":worker_cache", + ":worker_interface", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:worker_proto_cc", + ], +) + +cc_library( + name = "worker_cache_logger", + srcs = ["worker_cache_logger.cc"], + hdrs = ["worker_cache_logger.h"], + deps = [ + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], +) + +# TODO(mrry): Move executor_test.cc to ../common_runtime when once it no longer depends +# on grpc_testlib. 
+tf_cc_tests( + linkstatic = tf_kernel_tests_linkstatic(), + tags = tf_cuda_tests_tags(), + tests = [ + "executor_test.cc", + "master_test.cc", + "remote_device_test.cc", + ], + deps = [ + "@grpc//:grpc++_unsecure", + ":master", + ":process_util", + ":remote_device", + ":worker_interface", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:master_proto_cc", + "//tensorflow/core:master_service_proto_cc", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensorflow_opensource", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/distributed_runtime/rpc:grpc_channel", + "//tensorflow/core/distributed_runtime/rpc:grpc_testlib", + "//tensorflow/core/distributed_runtime/rpc:grpc_util", + "//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache", + ], +) diff --git a/tensorflow/core/distributed_runtime/README.md b/tensorflow/core/distributed_runtime/README.md new file mode 100644 index 00000000000..66433e352ae --- /dev/null +++ b/tensorflow/core/distributed_runtime/README.md @@ -0,0 +1,197 @@ +# Distributed TensorFlow + +This directory contains the initial open-source implementation of the +distributed TensorFlow runtime, using [gRPC](http://grpc.io) for inter-process +communication. + +## Quick start + +To get started, you will need to build the TensorFlow server binary +(`grpc_tensorflow_server`) and a gRPC-based client. Currently this is only +available using the source-based installation of TensorFlow, but it will be +included in future binary releases. You can build the server binary using one of +the following commands: + +```shell +# CPU-only build. +$ bazel build -c opt //tensorflow/core/distributed_runtime/rpc:grpc_tensorflow_server + +# GPU build. 
+$ bazel build -c opt --config=cuda //tensorflow/core/distributed_runtime/rpc:grpc_tensorflow_server +``` + +If you build the latest Python (PIP) package from source, it will contain a +gRPC-based client. If you are using a previous binary release, you may need to +rebuild and install an up-to-date PIP package by following +[these installation instructions](https://www.tensorflow.org/versions/master/get_started/os_setup.html#create-the-pip-package-and-install). + +Once you have successfully built the distributed TensorFlow components, you can +test your installation by starting a server as follows: + +```shell +# Start a TensorFlow server as a single-process "cluster". +$ bazel-bin/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server \ + --cluster_spec='local|localhost:2222' --job_name=local --task_index=0 & +``` + +...then start a Python interpreter and create a remote session: + +```python +$ python +>>> import tensorflow as tf +>>> c = tf.constant("Hello, distributed TensorFlow!") +>>> sess = tf.Session("grpc://localhost:2222") +>>> sess.run(c) +'Hello, distributed TensorFlow!' +``` + +## Cluster definition + +The command-line arguments to `grpc_tensorflow_server` define the membership of a TensorFlow cluster. The `--cluster_spec` flag determines the set of processes in the cluster, as a list of *jobs*, each of which contains a list of *task* endpoints. All processes in the cluster must be started with the same `--cluster_spec`. Example values include: + + + + + + + + + + + + +
--cluster_spec='...'Available tasks
local|localhost:2222/job:local/task:0
local|localhost:2222;localhost:2223/job:local/task:0
/job:local/task:1
worker|worker0:2222;worker1:2222;worker2:2222,
ps|ps0:2222;ps1:2222
/job:worker/task:0
/job:worker/task:1
/job:worker/task:2
/job:ps/task:0
/job:ps/task:1
+ +The `--job_name` and `--task_index` flags indicate which task will run in this +process, out of the jobs and tasks defined in `--cluster_spec`. For example, +`--job_name=local --task_index=0` means that the process will be task +`/job:local/task:0`, and TensorFlow devices in the process will have names +starting with that prefix. + +**N.B.** Manually specifying these command lines can be tedious, especially for +large clusters. We are working on tools for launching tasks programmatically, +e.g. using a cluster manager like [Kubernetes](http://kubernetes.io). If there +are particular cluster managers for which you'd like to see support, please +raise a [GitHub issue](https://github.com/tensorflow/tensorflow/issues). + +## Specifying distributed devices in your model + +To place operations on a particular process, you can use the same +[`tf.device()`](https://www.tensorflow.org/versions/master/api_docs/python/framework.html#device) +function that is used to specify whether ops run on the CPU or GPU. For example: + +```python +with tf.device("/job:ps/task:0"): + weights_1 = tf.Variable(...) + biases_1 = tf.Variable(...) + +with tf.device("/job:ps/task:1"): + weights_2 = tf.Variable(...) + biases_2 = tf.Variable(...) + +with tf.device("/job:worker/task:7"): + input, labels = ... + layer_1 = tf.nn.relu(tf.matmul(input, weights_1) + biases_1) + logits = tf.nn.relu(tf.matmul(layer_1, weights_2) + biases_2) + # ... + train_op = ... + +with tf.Session("grpc://worker7:2222") as sess: + for _ in range(10000): + sess.run(train_op) +``` + +In the above example, the variables are created on two tasks in the `ps` job, +and the compute-intensive part of the model is created in the `worker` +job. TensorFlow will insert the appropriate data transfers between the jobs +(from `ps` to `worker` for the forward pass, and from `worker` to `ps` for +applying gradients). 
+ +## Replicated training + +A common training configuration ("data parallel training") involves multiple +tasks in a `worker` job training the same model, using shared parameters hosted +in one or more tasks in a `ps` job. Each task will typically run on a +different machine. There are many ways to specify this structure in TensorFlow, +and we are building libraries that will simplify the work of specifying a +replicated model. Possible approaches include: + +* Building a single graph containing one set of parameters (in `tf.Variable` + nodes pinned to `/job:ps`), and multiple copies of the "model" pinned to + different tasks in `/job:worker`. Each copy of the model can have a different + `train_op`, and one or more client threads can call `sess.run(train_ops[i])` + for each worker `i`. This implements *asynchronous* training. + + This approach uses a single `tf.Session` whose target is one of the workers in + the cluster. + +* As above, but where the gradients from all workers are averaged. See the + [CIFAR-10 multi-GPU trainer](https://www.tensorflow.org/code/tensorflow/models/image/cifar10/cifar10_multi_gpu_train.py) + for an example of this form of replication. This implements *synchronous* training. + +* The "distributed trainer" approach uses multiple graphs—one per + worker—where each graph contains one set of parameters (pinned to + `/job:ps`) and one copy of the model (pinned to a particular + `/job:worker/task:i`). The "container" mechanism is used to share variables + between different graphs: when each variable is constructed, the optional + `container` argument is specified with the same value in each copy of the + graph. For large models, this can be more efficient, because the overall graph + is smaller. + + This approach uses multiple `tf.Session` objects: one per worker process, + where the `target` of each is the address of a different worker. 
The + `tf.Session` objects can all be created in a single Python client, or you can + use multiple Python clients to better distribute the trainer load. + +## Glossary + +
+
Client
+
+ A client is typically a program that builds a TensorFlow graph and + constructs a `tensorflow::Session` to interact with a cluster. Clients are + typically written in Python or C++. A single client process can directly + interact with multiple TensorFlow servers (see "Replicated training" above), + and a single server can serve multiple clients. +
+
Cluster
+
+ A TensorFlow cluster comprises one or more TensorFlow servers, divided into + a set of named jobs, which in turn comprise lists of tasks. A cluster is + typically dedicated to a particular high-level objective, such as training a + neural network, using many machines in parallel. +
+
Job
+
+ A job comprises a list of "tasks", which typically serve a common + purpose. For example, a job named `ps` (for "parameter server") typically + hosts nodes that store and update variables; while a job named `worker` + typically hosts stateless nodes that perform compute-intensive tasks. + The tasks in a job typically run on different machines. +
+
Master service
+
+ An RPC service that provides remote access to a set of distributed + devices. The master service implements the tensorflow::Session + interface, and is responsible for coordinating work across one or more + "worker services". +
+
Task
+
+ A task typically corresponds to a single TensorFlow server process, + belonging to a particular "job" and with a particular index within that + job's list of tasks. +
+ +
TensorFlow server
+
+ A process running the grpc_tensorflow_server binary, which is a + member of a cluster, and exports a "master service" and "worker service". +
+
Worker service
+
+ An RPC service that executes parts of a TensorFlow graph using its local + devices. A worker service implements worker_service.proto. +
+
diff --git a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc new file mode 100644 index 00000000000..af5d1272487 --- /dev/null +++ b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc @@ -0,0 +1,318 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h" + +#include +#include + +#include "tensorflow/core/common_runtime/copy_tensor.h" +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/common_runtime/dma_helper.h" +#include "tensorflow/core/distributed_runtime/process_util.h" +#include "tensorflow/core/distributed_runtime/worker_cache.h" +#include "tensorflow/core/distributed_runtime/worker_interface.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +BaseRendezvousMgr::BaseRendezvousMgr(const WorkerEnv* env) : worker_env_(env) {} + +BaseRendezvousMgr::~BaseRendezvousMgr() { + for (auto& p : table_) { + BaseRemoteRendezvous* 
rendez = p.second; + rendez->StartAbort(errors::Aborted("Shutdown")); + rendez->Unref(); + } +} + +Rendezvous* BaseRendezvousMgr::Find(int64 step_id) { + return FindOrCreate(step_id); +} + +BaseRemoteRendezvous* BaseRendezvousMgr::FindOrCreate(int64 step_id) { + mutex_lock l(mu_); + Table::iterator iter = table_.find(step_id); + if (iter == table_.end()) { + auto rr = Create(step_id, worker_env_); + iter = table_.insert({step_id, rr}).first; + } + iter->second->Ref(); + return iter->second; +} + +void BaseRendezvousMgr::RecvLocalAsync(int64 step_id, const string& key, + Rendezvous::DoneCallback done) { + BaseRemoteRendezvous* rendez = FindOrCreate(step_id); + rendez->RecvLocalAsync( + key, [rendez, done](const Status& s, const Rendezvous::Args& send_args, + const Rendezvous::Args& recv_args, const Tensor& v, + bool dead) { + rendez->Unref(); + done(s, send_args, recv_args, v, dead); + }); +} + +Status BaseRendezvousMgr::RecvLocal(int64 step_id, const string& key, + Tensor* val, bool* is_dead) { + Status ret; + Notification n; + RecvLocalAsync(step_id, key, + [val, is_dead, &ret, &n](const Status& s, + const Rendezvous::Args& send_args, + const Rendezvous::Args& recv_args, + const Tensor& v, const bool dead) { + ret = s; + *val = v; + *is_dead = dead; + n.Notify(); + }); + n.WaitForNotification(); + return ret; +} + +void BaseRendezvousMgr::Cleanup(int64 step_id) { + Rendezvous* rendez = nullptr; + { + mutex_lock l(mu_); + Table::iterator iter = table_.find(step_id); + if (iter != table_.end()) { + rendez = iter->second; + table_.erase(iter); + } + } + if (!rendez) return; + rendez->StartAbort(errors::Aborted("Cleanup ", step_id)); + rendez->Unref(); +} + +void BaseRendezvousMgr::CleanupAll() { + std::vector rendezs; + { + mutex_lock l(mu_); + for (const auto& entry : table_) { + rendezs.push_back(entry.second); + } + table_.clear(); + } + for (auto rendez : rendezs) { + rendez->StartAbort(errors::Aborted("Shutdown")); + rendez->Unref(); + } +} + 
// Constructs a rendezvous for "step_id".  Buffering of tensor values is
// delegated to a process-local rendezvous created by NewLocalRendezvous().
BaseRemoteRendezvous::BaseRemoteRendezvous(const WorkerEnv* env, int64 step_id,
                                           bool tolerate_dup_recv)
    : env_(env),
      step_id_(step_id),
      tolerate_dup_recv_(tolerate_dup_recv),
      local_(NewLocalRendezvous(tolerate_dup_recv)) {}

BaseRemoteRendezvous::~BaseRemoteRendezvous() {
  // Every outstanding RecvTensor call must have been deregistered by now.
  CHECK(active_.empty());
  local_->Unref();
}

// Returns true if "device_name" is a valid full name of local device
// of the "worker". This helper is purely based on the worker name
// and device name and does no lookups in the worker->device_mgr.
static bool IsLocalDevice(const WorkerEnv& worker,
                          const StringPiece device_name) {
  return device_name.starts_with(worker.worker_name);
}

// Buffers "val" in local_ after validating that the key's source device
// belongs to this worker.  Fails fast if this rendezvous was aborted.
Status BaseRemoteRendezvous::Send(const string& key,
                                  const Rendezvous::Args& args,
                                  const Tensor& val, const bool is_dead) {
  VLOG(1) << "BaseRemoteRendezvous Send " << this << " " << key;
  {
    mutex_lock l(mu_);
    if (!status_.ok()) return status_;
  }
  Rendezvous::ParsedKey parsed;
  TF_RETURN_IF_ERROR(Rendezvous::ParseKey(key, &parsed));
  if (!IsLocalDevice(*env_, parsed.src_device)) {
    return errors::InvalidArgument("Invalid rendezvous key (src): ", key, " @ ",
                                   env_->worker_name);
  }
  // Buffers "val" and "device_context" in local_.
  return local_->Send(key, args, val, is_dead);
}

// Parses "key" into "parsed" and verifies that the relevant endpoint
// (src when "is_src", dst otherwise) is a device in this process.
Status BaseRemoteRendezvous::ParseKey(const string& key, bool is_src,
                                      Rendezvous::ParsedKey* parsed) {
  {
    mutex_lock l(mu_);
    if (!status_.ok()) return status_;
  }
  TF_RETURN_IF_ERROR(Rendezvous::ParseKey(key, parsed));
  if (is_src && !IsLocalDevice(*env_, parsed->src_device)) {
    return errors::InvalidArgument("Invalid rendezvous key (src): ", key, " @ ",
                                   env_->worker_name);
  }
  if (!is_src && !IsLocalDevice(*env_, parsed->dst_device)) {
    return errors::InvalidArgument("Invalid rendezvous key (dst): ", key, " @ ",
                                   env_->worker_name);
  }
  return Status::OK();
}

// Completes a recv whose producer and consumer are both in this process:
// copies "in" into "out" (possibly across devices) and invokes "done".
void BaseRemoteRendezvous::SameWorkerRecvDone(
    const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& send_args,
    const Rendezvous::Args& recv_args, const Tensor& in, Tensor* out,
    StatusCallback done) {
  // Do a quick copy (sharing the underlying buffer) if both tensors
  // are on host memory.
  const bool src_host =
      (send_args.alloc_attrs.on_host() || parsed.src.type == "CPU");
  const bool dst_host =
      (recv_args.alloc_attrs.on_host() || parsed.dst.type == "CPU");
  if (src_host && dst_host) {
    *out = in;
    done(Status::OK());
    return;
  }

  // This copy must involve a GPU. Hence, "in" must support DMA
  // (e.g., string tensors do not work on GPU).
  if (!DMAHelper::CanUseDMA(&in)) {
    done(errors::InvalidArgument("Non-DMA-safe ", DataTypeString(in.dtype()),
                                 " tensor may not be copied from/to a GPU."));
    return;
  }

  Device* src_device;
  Status s = env_->device_mgr->LookupDevice(parsed.src_device, &src_device);
  if (!s.ok()) {
    done(s);
    return;
  }
  Device* dst_device;
  s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_device);
  if (!s.ok()) {
    done(s);
    return;
  }

  // Allocate the destination tensor with attributes compatible with both
  // endpoints, then DMA into it.
  AllocatorAttributes attr = recv_args.alloc_attrs;
  attr.set_gpu_compatible(send_args.alloc_attrs.gpu_compatible() ||
                          recv_args.alloc_attrs.gpu_compatible());
  Allocator* out_allocator = dst_device->GetAllocator(attr);
  Tensor copy(out_allocator, in.dtype(), in.shape());
  *out = copy;

  // The following function takes care of cpu->gpu, gpu->cpu, gpu->gpu copies,
  // etc.
  CopyTensor::ViaDMA(parsed.edge_name, send_args.device_context,
                     recv_args.device_context, src_device, dst_device,
                     send_args.alloc_attrs, recv_args.alloc_attrs, &in, out,
                     done);
}

// Two devices are "the same worker" iff they share an address space.
bool BaseRemoteRendezvous::IsSameWorker(DeviceNameUtils::ParsedName src,
                                        DeviceNameUtils::ParsedName dst) {
  return DeviceNameUtils::IsSameAddressSpace(src, dst);
}

// Receives the tensor for "key": from local_ when the producer is in this
// process, otherwise via the subclass's remote (RPC) path.
void BaseRemoteRendezvous::RecvAsync(const string& key,
                                     const Rendezvous::Args& recv_args,
                                     DoneCallback done) {
  VLOG(1) << "RemoteRendezvous Recv " << this << " " << key;

  Rendezvous::ParsedKey parsed;
  Status s = ParseKey(key, false /*!is_src*/, &parsed);
  if (!s.ok()) {
    done(s, Args(), recv_args, Tensor(), false);
    return;
  }

  // Are src and dst in the same worker?
  if (IsSameWorker(parsed.src, parsed.dst)) {
    // Recv the tensor from local_.
    local_->RecvAsync(
        key, recv_args, [this, parsed, done](const Status& status,
                                             const Rendezvous::Args& send_args,
                                             const Rendezvous::Args& recv_args,
                                             const Tensor& in, bool is_dead) {
          Status s = status;
          // "out" is heap-allocated because it must outlive this lambda;
          // final_callback deletes it after invoking "done".
          Tensor* out = new Tensor;
          StatusCallback final_callback = [done, send_args, recv_args, out,
                                           is_dead](const Status& s) {
            done(s, send_args, recv_args, *out, is_dead);
            delete out;
          };

          if (s.ok()) {
            SameWorkerRecvDone(parsed, send_args, recv_args, in, out,
                               final_callback);
          } else {
            final_callback(s);
          }
        });
    return;
  } else {
    // Producer is in a different process: defer to the subclass.
    RecvFromRemoteAsync(key, parsed, recv_args, done);
  }
}

// Retrieves a locally-buffered value from local_ for the RecvTensor
// handler.  "key"'s source must be a device in this process.
void BaseRemoteRendezvous::RecvLocalAsync(const string& key,
                                          DoneCallback done) {
  Rendezvous::ParsedKey parsed;
  Status s = ParseKey(key, true /* is_src */, &parsed);
  if (!s.ok()) {
    done(s, Args(), Args(), Tensor(), false);
    return;
  }
  local_->RecvAsync(key, Args(), done);
}

// Records the first abort status and propagates it to local_ and to all
// outstanding RecvTensor calls.  Later aborts are ignored.
void BaseRemoteRendezvous::StartAbort(const Status& s) {
  CHECK(!s.ok());
  local_->StartAbort(s);
  {
    // Aborts all active RecvTensor calls.
    mutex_lock l(mu_);
    if (status_.ok()) {
      status_ = s;
      for (BaseRecvTensorCall* call : active_) {
        call->StartAbort(s);
      }
      active_.clear();
    }
  }
}

// Tracks "call" in active_, or aborts it immediately if this rendezvous
// has already been aborted.
void BaseRemoteRendezvous::RegisterCall(BaseRecvTensorCall* call) {
  mutex_lock l(mu_);
  if (!status_.ok()) {
    call->StartAbort(status_);
  } else {
    CHECK(active_.insert(call).second);
  }
}

void BaseRemoteRendezvous::DeregisterCall(BaseRecvTensorCall* call) {
  mutex_lock l(mu_);
  active_.erase(call);
}

}  // end namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.h b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.h
new file mode 100644
index 00000000000..26748174265
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.h
@@ -0,0 +1,212 @@
/* Copyright 2016 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_BASE_RENDEZVOUS_MGR_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_BASE_RENDEZVOUS_MGR_H_

// NOTE(review): the three standard-library header names below were lost in
// extraction (presumably <unordered_map>, <unordered_set>, <vector> given the
// types used in this header) — confirm against the original file.
#include
#include
#include

#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/control_flow.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {

class BaseRemoteRendezvous;
class BaseRecvTensorCall;

// RendezvousMgr keeps track of a set of local rendezvous instances.
// All tensors sent by this worker are buffered in a RendezvousMgr
// until the tensor is received. Each global unique "step_id"
// corresponds to one local rendezvous instance managed by a
// RendezvousMgr.
//
// E.g.,
//   Rendezvous* rendez = worker_env->rendezvous_mgr->Find(0x8935);
//   fork execution of a graph executor using "rendez" on thread 1;
//   fork execution of another graph executor using "rendez" on thread 2;
//   ...
//   join threads 1 and 2;
//
// In the example above, execution in thread 1 and 2 communicates with
// each other by send/recv operations through `rendez`.
//
// Tensors sent and received through a rendezvous managed by this
// RendezvousMgr must have keys generated by Rendezvous::CreateKey().
class BaseRendezvousMgr : public RendezvousMgrInterface {
 public:
  explicit BaseRendezvousMgr(const WorkerEnv* worker_env);
  ~BaseRendezvousMgr() override;

  // Returns Rendezvous supporting send and recv among workers in the
  // "step_id". The caller takes ownership of one reference on the
  // returned Rendezvous instance.
  Rendezvous* Find(int64 step_id) override;

  // Finds the local rendezvous instance for the "step_id". Runs
  // "done" when the tensor for "key" is produced or an error occurs.
  //
  // This method is used by the rpc handler of RecvTensor.
  void RecvLocalAsync(int64 step_id, const string& key,
                      Rendezvous::DoneCallback done) override;

  // Synchronous wrapper for RecvLocalAsync.
  Status RecvLocal(int64 step_id, const string& key, Tensor* val,
                   bool* is_dead) override;

  // Removes rendezvous for "step_id".
  //
  // TODO(zhifengc): Have a background thread in worker that
  // periodically calls CleanupAll().
  void Cleanup(int64 step_id) override;

  // Removes all rendezvous.
  void CleanupAll() override;

 protected:
  // Factory hook: subclasses return the concrete rendezvous type for a step.
  virtual BaseRemoteRendezvous* Create(int64 step_id,
                                       const WorkerEnv* worker_env) = 0;

 private:
  // Maps step_id to rendezvous.
  // NOTE(review): template arguments lost in extraction; presumably
  // std::unordered_map<int64, BaseRemoteRendezvous*> — confirm.
  typedef std::unordered_map Table;

  // Not owned.
  const WorkerEnv* const worker_env_;

  mutex mu_;
  Table table_ GUARDED_BY(mu_);

  BaseRemoteRendezvous* FindOrCreate(int64 step_id);

  TF_DISALLOW_COPY_AND_ASSIGN(BaseRendezvousMgr);
};

// RemoteRendezvous is a Rendezvous which can handle either
// the producer or consumer being in a remote process.
//
// Buffering of Tensor values is delegated to a "local" Rendezvous
// obtained from NewLocalRendezvous().
// This class just adds
// functionality to coordinate with remote workers.
class BaseRemoteRendezvous : public Rendezvous {
 public:
  BaseRemoteRendezvous(const WorkerEnv* env, int64 step_id,
                       bool tolerate_dup_recv);

  // Forwards to local_, where the Tensor "val" will be buffered and
  // any waiting callback stored.
  Status Send(const string& key, const Rendezvous::Args& args,
              const Tensor& val, const bool is_dead) override;

  // This method is called only by the RecvOp.  It tests to see
  // whether the value will be produced by a local or remote device
  // and handles accordingly.  In the local case it forwards to
  // local_, in the remote case it initiates an RPC request.
  void RecvAsync(const string& key, const Rendezvous::Args& args,
                 DoneCallback done) override;

  void StartAbort(const Status& status) override;

  // This method is called only by the local Worker, forwarded through
  // the same method on RendezvousMgr.  This occurs when the Worker
  // has received a RecvTensor request, either locally or over the
  // network.  In either case it needs to retrieve a locally buffered
  // value from local_, and give it to its caller.
  //
  // Runs "done" as soon as the tensor for "key" is available or an error
  // is detected.
  //
  // REQUIRES: "key" is one that will be Saved into the local rendezvous.
  void RecvLocalAsync(const string& key, DoneCallback done);

 protected:
  // Subclass hook implementing the cross-process recv (e.g. via RPC).
  virtual void RecvFromRemoteAsync(const string& key,
                                   const Rendezvous::ParsedKey& parsed,
                                   const Rendezvous::Args& args,
                                   DoneCallback done) = 0;

  // Returns true if "src" and "dst" are located in the same worker,
  // and hence may use a local rendezvous.
  virtual bool IsSameWorker(DeviceNameUtils::ParsedName src,
                            DeviceNameUtils::ParsedName dst);

  // If aborted, aborts "call". Otherwise, adds "call" into active_.
  void RegisterCall(BaseRecvTensorCall* call);

  // Removes "call" from active_ if "call" is in active_.
  void DeregisterCall(BaseRecvTensorCall* call);

  ~BaseRemoteRendezvous() override;

  const WorkerEnv* const env_;  // Not owned.
  const int64 step_id_;

 private:
  const bool tolerate_dup_recv_;
  Rendezvous* local_;  // Owns a Ref on this object.

  mutable mutex mu_;

  // Status given by StartAbort() if any.
  Status status_ GUARDED_BY(mu_);

  // Active outstanding RecvTensor calls.
  // NOTE(review): template argument lost in extraction; presumably
  // std::unordered_set<BaseRecvTensorCall*> — confirm.
  std::unordered_set active_ GUARDED_BY(mu_);

  // Parses "key" into "parsed". If "is_src" is true, checks that the
  // rendezvous key's source is in this process. If "is_src" is false,
  // checks that the rendezvous key's destination is in this process.
  Status ParseKey(const string& key, bool is_src,
                  Rendezvous::ParsedKey* parsed);

  // Callback handling the case when a rendezvous has been
  // accomplished in local_ and the consumer is local to this process.
  // Tensor "in" will be copied into "out". The key "parsed" encodes
  // the src and dst devices.
  void SameWorkerRecvDone(const Rendezvous::ParsedKey& parsed,
                          const Rendezvous::Args& in_args,
                          const Rendezvous::Args& out_args, const Tensor& in,
                          Tensor* out, StatusCallback done);

  TF_DISALLOW_COPY_AND_ASSIGN(BaseRemoteRendezvous);
};

// Interface for a single outstanding RecvTensor RPC tracked by a
// BaseRemoteRendezvous (see RegisterCall/DeregisterCall above).
class BaseRecvTensorCall {
 public:
  BaseRecvTensorCall() {}
  virtual ~BaseRecvTensorCall() {}

  // Starts the call; "recv_done" is invoked on completion.
  // NOTE(review): the std::function template argument was lost in
  // extraction; presumably std::function<void()> — confirm.
  virtual void Start(std::function recv_done) = 0;

  virtual void StartAbort(const Status& s) = 0;

  virtual Status status() const = 0;

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(BaseRecvTensorCall);
};

}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_BASE_RENDEZVOUS_MGR_H_
diff --git a/tensorflow/core/distributed_runtime/build_graph_options.cc b/tensorflow/core/distributed_runtime/build_graph_options.cc
new file mode 100644
index 00000000000..05c42e89ba6
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/build_graph_options.cc
@@ -0,0 +1,38 @@
/* Copyright 2016 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/build_graph_options.h"

#include "tensorflow/core/lib/strings/strcat.h"

namespace tensorflow {

// Human-readable dump of the feed, fetch, and target-node lists, one
// labeled section per line (for logging/debugging only).
string BuildGraphOptions::DebugString() const {
  string rv = "Feed endpoints: ";
  for (auto& s : feed_endpoints) {
    strings::StrAppend(&rv, s, ", ");
  }
  strings::StrAppend(&rv, "\nFetch endpoints: ");
  for (auto& s : fetch_endpoints) {
    strings::StrAppend(&rv, s, ", ");
  }
  strings::StrAppend(&rv, "\nTarget nodes: ");
  for (auto& s : target_nodes) {
    strings::StrAppend(&rv, s, ", ");
  }
  return rv;
}

}  // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/build_graph_options.h b/tensorflow/core/distributed_runtime/build_graph_options.h
new file mode 100644
index 00000000000..438912642d1
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/build_graph_options.h
@@ -0,0 +1,38 @@
/* Copyright 2016 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_BUILD_GRAPH_OPTIONS_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_BUILD_GRAPH_OPTIONS_H_

// NOTE(review): header name lost in extraction; presumably <vector> — confirm.
#include

#include "tensorflow/core/platform/types.h"

namespace tensorflow {

// Options describing which endpoints a client graph feeds, fetches, and
// runs, used when building/partitioning a graph for one step.
// NOTE(review): the std::vector element types below were lost in
// extraction; presumably std::vector<string> — confirm.
struct BuildGraphOptions {
  // Endpoints whose values are supplied by the caller.
  std::vector feed_endpoints;
  // Endpoints whose values are returned to the caller.
  std::vector fetch_endpoints;

  // TODO(vrv): Remove this when we unify target_nodes and fetch_endpoint,
  // the former via "ref" fetch_endpoints.
  std::vector target_nodes;

  string DebugString() const;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_BUILD_GRAPH_OPTIONS_H_
diff --git a/tensorflow/core/distributed_runtime/call_options.cc b/tensorflow/core/distributed_runtime/call_options.cc
new file mode 100644
index 00000000000..b9d583b754e
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/call_options.cc
@@ -0,0 +1,44 @@
/* Copyright 2016 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/call_options.h"

#include "tensorflow/core/platform/mutex.h"

namespace tensorflow {

CallOptions::CallOptions() {}

// Invokes the registered cancellation callback, if any.  Safe to call at
// any time while this object is alive.
void CallOptions::StartCancel() {
  mutex_lock l(mu_);
  if (cancel_func_ != nullptr) {
    // NOTE: We must call the cancel_func_ with mu_ held. This ensure
    // that ClearCancelCallback() does not race with StartCancel().
    cancel_func_();
    // NOTE: We can clear cancel_func_ if needed.
  }
}

// Registers (or replaces) the cancellation callback.
void CallOptions::SetCancelCallback(CancelFunction cancel_func) {
  mutex_lock l(mu_);
  cancel_func_ = cancel_func;
}

// Removes the cancellation callback; later StartCancel() calls are no-ops.
void CallOptions::ClearCancelCallback() {
  mutex_lock l(mu_);
  cancel_func_ = nullptr;
}

}  // end namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/call_options.h b/tensorflow/core/distributed_runtime/call_options.h
new file mode 100644
index 00000000000..de0b85f6926
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/call_options.h
@@ -0,0 +1,72 @@
/* Copyright 2016 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CALL_OPTIONS_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CALL_OPTIONS_H_

// NOTE(review): header name lost in extraction; presumably <functional> —
// confirm.
#include

#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

// Options passed to interface calls. This class provides portable
// functionality across different RPC systems on top of
// platform-specific mechanisms (for client and server contexts,
// cancellation, etc.).
//
// TODO(zhifengc): Maybe change all RPC methods to take CallOptions.
class CallOptions {
 public:
  CallOptions();

  // Cancellation.
  //
  // The caller may call StartCancel() anytime as long as this
  // CallOptions object is alive. The callee may or may not receive
  // the cancellation notification depending on the rpc layer
  // implementation.
  void StartCancel();

  // The callee (the rpc layer implementation) must set a cancellation
  // notifier before its blocking operation and clear the notifier
  // before the call returns.
  //
  // "cancel_func" may be called zero, once or more time. Therefore, it
  // should _not_ be responsible for memory management of any objects.
  //
  // "cancel_func" must be very light-weight. It should not block on
  // IO or locking. Typically, it just calls the rpc implementation
  // layer's specific cancellation mechanism and does nothing else.
  //
  // NOTE: "cancel_func" itself is pass-by-value. Therefore, we do not
  // worry about its ownership here.
  // NOTE(review): the std::function template argument was lost in
  // extraction; presumably std::function<void()> — confirm.
  typedef std::function CancelFunction;
  void SetCancelCallback(CancelFunction cancel_func);
  void ClearCancelCallback();

 private:
  mutex mu_;
  CancelFunction cancel_func_ GUARDED_BY(mu_);

  TF_DISALLOW_COPY_AND_ASSIGN(CallOptions);
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CALL_OPTIONS_H_
diff --git a/tensorflow/core/distributed_runtime/call_options_test.cc b/tensorflow/core/distributed_runtime/call_options_test.cc
new file mode 100644
index 00000000000..62fe21341c0
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/call_options_test.cc
@@ -0,0 +1,39 @@
/* Copyright 2016 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/call_options.h"

#include "tensorflow/core/platform/test.h"

namespace tensorflow {

// Verifies that StartCancel() only fires a callback while one is
// registered: before SetCancelCallback and after ClearCancelCallback it
// must be a no-op.
TEST(CallOptions, Cancel) {
  int num_calls = 0;
  CallOptions opts;
  // No callback registered yet: must be a no-op.
  opts.StartCancel();
  EXPECT_EQ(num_calls, 0);
  opts.SetCancelCallback([&num_calls]() { num_calls++; });
  EXPECT_EQ(num_calls, 0);
  // Each StartCancel() invokes the registered callback once.
  opts.StartCancel();
  EXPECT_EQ(num_calls, 1);
  opts.StartCancel();
  EXPECT_EQ(num_calls, 2);
  opts.ClearCancelCallback();
  EXPECT_EQ(num_calls, 2);
  // Cleared: no-op again.
  opts.StartCancel();
  EXPECT_EQ(num_calls, 2);
}

}  // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/executor_test.cc b/tensorflow/core/distributed_runtime/executor_test.cc
new file mode 100644
index 00000000000..be46c73aa2e
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/executor_test.cc
@@ -0,0 +1,407 @@
/* Copyright 2016 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// NOTE(review): header name lost in extraction — confirm against the
// original file.
#include

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/distributed_runtime/process_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {

// Test fixture that runs graphs on a single local CPU device through
// NewLocalExecutor, exchanging inputs/outputs via a local rendezvous.
class ExecutorTest : public ::testing::Test {
 protected:
  ExecutorTest()
      : device_(DeviceFactory::NewDevice("CPU", {},
                                         "/job:localhost/replica:0/task:0")),

        step_stats_collector_(&step_stats_) {
    SessionOptions options;
    thread_pool_ = ComputePool(options);
  }

  ~ExecutorTest() override {
    // There should always be exactly one Ref left on the Rendezvous
    // when the test completes.
    CHECK(rendez_->Unref());
    delete exec_;
    delete device_;
  }

  // Resets executor_ with a new executor based on a graph 'gdef'.
  void Create(const Graph* graph) {
    const int version = graph->versions().producer();
    LocalExecutorParams params;
    params.device = device_;
    params.create_kernel = [this, version](const NodeDef& ndef,
                                           OpKernel** kernel) {
      return CreateNonCachedKernel(device_, nullptr, ndef, version, kernel);
    };
    params.delete_kernel = [](OpKernel* kernel) {
      DeleteNonCachedKernel(kernel);
    };
    delete exec_;
    TF_CHECK_OK(NewLocalExecutor(params, graph, &exec_));
    // NOTE(review): the std::function template argument was lost in
    // extraction; presumably std::function<void()> — confirm.
    runner_ = [this](std::function fn) { thread_pool_->Schedule(fn); };
    rendez_ = NewLocalRendezvous();
  }

  // Runs the current executor once, using "rendez" for sends/recvs.
  Status Run(Rendezvous* rendez) {
    Executor::Args args;
    args.rendezvous = rendez;
    args.stats_collector = &step_stats_collector_;
    args.runner = runner_;
    return exec_->Run(args);
  }

  thread::ThreadPool* thread_pool_ = nullptr;
  Device* device_ = nullptr;
  Executor* exec_ = nullptr;
  StepStatsCollector step_stats_collector_;
  StepStats step_stats_;
  Executor::Args::Runner runner_;
  Rendezvous* rendez_ = nullptr;
};

// NOTE(review): the scalar<T>() template arguments in the helpers below were
// lost in extraction (float/int32/bool/double respectively) — confirm.

// A float val -> Tensor
Tensor V(const float val) {
  Tensor tensor(DT_FLOAT, TensorShape({}));
  tensor.scalar()() = val;
  return tensor;
}

// A int32 val -> Tensor
Tensor VI(const int32 val) {
  Tensor tensor(DT_INT32, TensorShape({}));
  tensor.scalar()() = val;
  return tensor;
}

// A bool val -> Tensor
Tensor VB(const bool val) {
  Tensor tensor(DT_BOOL, TensorShape({}));
  tensor.scalar()() = val;
  return tensor;
}

// A double val -> Tensor
Tensor VD(const double val) {
  Tensor tensor(DT_DOUBLE, TensorShape({}));
  tensor.scalar()() = val;
  return tensor;
}

// Tensor -> a float val.
float V(const Tensor& tensor) {
  CHECK_EQ(tensor.dtype(), DT_FLOAT);
  CHECK(TensorShapeUtils::IsScalar(tensor.shape()));
  return tensor.scalar()();
}

static uint64 kIncarnation = 1;  // Used in the following tests.

// Builds a rendezvous key for a tensor named "name" sent from "sender"
// to "receiver" in the outermost frame.
string Key(const string& sender, const uint64 incarnation,
           const string& receiver, const string& name) {
  return Rendezvous::CreateKey(sender, incarnation, receiver, name,
                               FrameAndIter(0, 0));
}

#define ALICE "/job:j/replica:0/task:0/cpu:0"
#define BOB "/job:j/replica:0/task:0/gpu:0"

TEST_F(ExecutorTest, SimpleAdd) {
  // c = a + b
  Graph* g = new Graph(OpRegistry::Global());
  auto in0 = test::graph::Recv(g, "a", "float", ALICE, 1, BOB);
  auto in1 = test::graph::Recv(g, "b", "float", ALICE, 1, BOB);
  auto tmp = test::graph::Add(g, in0, in1);
  test::graph::Send(g, tmp, "c", BOB, 1, ALICE);
  Create(g);
  Rendezvous::Args args;
  TF_ASSERT_OK(rendez_->Send(Key(ALICE, kIncarnation, BOB, "a"), args, V(1.0),
                             false));  // in0 = 1.0
  TF_ASSERT_OK(rendez_->Send(Key(ALICE, kIncarnation, BOB, "b"), args, V(1.0),
                             false));  // in1 = 1.0
  TF_ASSERT_OK(Run(rendez_));
  Tensor out = V(-1);
  bool is_dead = false;
  TF_ASSERT_OK(
      rendez_->Recv(Key(BOB, kIncarnation, ALICE, "c"), args, &out, &is_dead));
  EXPECT_EQ(2.0, V(out));  // out = 1.0 + 1.0 = 2.0
}

TEST_F(ExecutorTest, SelfAdd) {
  // v0 <- a
  // v1 = v0 + v0
  // v2 = v1 + v1
  // ... ...
  // v10 = v9 + v9
  //
  // b <- v10
  // All nodes are executed by one thread.
  Graph* g = new Graph(OpRegistry::Global());
  auto v = test::graph::Recv(g, "a", "float", ALICE, 1, BOB);
  const int N = 10;
  for (int i = 1; i <= N; ++i) {
    v = test::graph::Add(g, v, v);
  }
  // out <- v10
  test::graph::Send(g, v, "b", BOB, 1, ALICE);
  Create(g);
  Rendezvous::Args args;
  // a = 1.0
  TF_ASSERT_OK(
      rendez_->Send(Key(ALICE, kIncarnation, BOB, "a"), args, V(1.0), false));
  TF_ASSERT_OK(Run(rendez_));
  Tensor out = V(-1);
  bool is_dead = false;
  TF_ASSERT_OK(
      rendez_->Recv(Key(BOB, kIncarnation, ALICE, "b"), args, &out, &is_dead));
  EXPECT_EQ(1024.0, V(out));  // b=v10=2*v9=4*v8=...=1024*a=1024.0
}

// Builds a graph which adds N copies of one variable "in". I.e.,
//     a + a + a + ... + a
// The returned graph is parenthesized randomly. I.e.,
//     a + ((a + a) + a)
//     (a + a) + (a + a)
//     ((a + a) + a) + a
// are all possibly generated.
void BuildTree(int N, Graph* g) {
  CHECK_GT(N, 1);
  // A single input node "in".
  auto in = test::graph::Recv(g, "a", "float", ALICE, 1, BOB);
  // NOTE(review): element type lost in extraction; presumably
  // std::vector<Node*> — confirm.
  std::vector nodes;
  int i = 0;
  // Duplicate "in" N times. Each copy is named l0, l1, l2, ....
  for (; i < N; ++i) {
    nodes.push_back(test::graph::Identity(g, in, 0));
  }
  random::PhiloxRandom philox(testing::RandomSeed(), 17);
  random::SimplePhilox rnd(&philox);
  while (nodes.size() > 1) {
    // Randomly pick two from nodes and add them. The resulting node
    // is named like n10, n11, ....  and is put back into "nodes".
    int x = rnd.Uniform(nodes.size());
    auto in0 = nodes[x];
    nodes[x] = nodes.back();
    nodes.resize(nodes.size() - 1);
    x = rnd.Uniform(nodes.size());
    auto in1 = nodes[x];
    // node = in0 + in1.
    nodes[x] = test::graph::Add(g, in0, in1);
  }
  // The final output node "out".
  test::graph::Send(g, nodes.back(), "b", BOB, 1, ALICE);
}

TEST_F(ExecutorTest, RandomTree) {
  Graph* g = new Graph(OpRegistry::Global());
  BuildTree(4096, g);
  Create(g);
  Rendezvous::Args args;
  TF_ASSERT_OK(
      rendez_->Send(Key(ALICE, kIncarnation, BOB, "a"), args, V(1.0), false));
  TF_ASSERT_OK(Run(rendez_));
  Tensor out = V(-1);
  bool is_dead = false;
  TF_ASSERT_OK(
      rendez_->Recv(Key(BOB, kIncarnation, ALICE, "b"), args, &out, &is_dead));
  EXPECT_EQ(4096.0, V(out));
}

// Builds a graph with 1024 concurrent v = v + 1 updates racing on one
// variable, all ordered after the variable's initialization and before
// the output send.
void BuildConcurrentAddAssign(Graph* g) {
  auto one = test::graph::Constant(g, V(1.0));
  // A variable holds one float.
  auto var = test::graph::Var(g, DT_FLOAT, TensorShape({}));
  // Initialize the variable with 1.0.
  auto init = test::graph::Assign(g, var, one);
  // Output
  auto out = test::graph::Send(g, var, "out", ALICE, kIncarnation, BOB);
  // Have many concurrent computations. Each does v = v + 1.
  for (int i = 0; i < 1024; ++i) {
    auto add = test::graph::Add(g, var, one);
    g->AddControlEdge(init, add);  // Ensures run after init.
    auto assign = test::graph::Assign(g, var, add);
    g->AddControlEdge(assign, out);
  }
}

#ifndef THREAD_SANITIZER
TEST_F(ExecutorTest, ConcurrentAddAssign) {
  Graph* g = new Graph(OpRegistry::Global());
  BuildConcurrentAddAssign(g);
  Create(g);
  for (int iters = 0; iters < 16; ++iters) {
    Rendezvous* rendez = NewLocalRendezvous();
    TF_ASSERT_OK(Run(rendez));
    Rendezvous::Args args;
    Tensor out;
    bool is_dead;
    TF_ASSERT_OK(rendez->Recv(Key(ALICE, kIncarnation, BOB, "out"), args, &out,
                              &is_dead));
    VLOG(1) << "Get " << V(out);
    // The updates race, so any value in [1, 1025] is possible.
    EXPECT_LE(V(out), 1025.0);
    rendez->Unref();
  }
}
#endif

TEST_F(ExecutorTest, SimpleSwitchLive) {
  Graph* g = new Graph(OpRegistry::Global());
  auto in0 = test::graph::Recv(g, "a", "float", ALICE, 1, BOB);
  auto in1 = test::graph::Constant(g, VB(false));
  auto tmp = test::graph::Switch(g, in0, in1);
  test::graph::Send(g, tmp, "c", BOB, 1, ALICE);
  Create(g);
  Rendezvous::Args args;
  TF_ASSERT_OK(rendez_->Send(Key(ALICE, kIncarnation, BOB, "a"), args, V(1.0),
                             false));  // in0 = 1.0
  TF_ASSERT_OK(Run(rendez_));
  Tensor out = V(-1);
  bool is_dead = false;
  TF_ASSERT_OK(
      rendez_->Recv(Key(BOB, kIncarnation, ALICE, "c"), args, &out, &is_dead));
  EXPECT_EQ(1.0, V(out));  // out = 1.0
  EXPECT_FALSE(is_dead);
}

TEST_F(ExecutorTest, SimpleSwitchDead) {
  Graph* g = new Graph(OpRegistry::Global());
  auto in0 = test::graph::Recv(g, "a", "float", ALICE, 1, BOB);
  auto in1 = test::graph::Constant(g, VB(true));
  auto tmp = test::graph::Switch(g, in0, in1);
  test::graph::Send(g, tmp, "c", BOB, 1, ALICE);
  Create(g);
  Rendezvous::Args args;
  TF_ASSERT_OK(rendez_->Send(Key(ALICE, kIncarnation, BOB, "a"), args, V(1.0),
                             false));  // in0 = 1.0
  TF_ASSERT_OK(Run(rendez_));
  Tensor out = V(-1);
  bool is_dead = false;
  TF_ASSERT_OK(
      rendez_->Recv(Key(BOB, kIncarnation, ALICE, "c"), args, &out, &is_dead));
  EXPECT_TRUE(is_dead);
}

TEST_F(ExecutorTest, Abort) {
  // e = a + b + c + d
  Graph* g = new Graph(OpRegistry::Global());
  auto in0 = test::graph::Recv(g, "a", "float", ALICE, 1, BOB);
  auto in1 = test::graph::Recv(g, "b", "float", ALICE, 1, BOB);
  auto in2 = test::graph::Recv(g, "c", "float", ALICE, 1, BOB);
  auto in3 = test::graph::Recv(g, "d", "float", ALICE, 1, BOB);
  auto add0 = test::graph::Add(g, in0, in1);
  auto add1 = test::graph::Add(g, in2, in3);
  auto add2 = test::graph::Add(g, add0, add1);
  test::graph::Send(g, add2, "e", BOB, 1, ALICE);
  Create(g);

  // Needs 4 inputs (recv). One of them is aborted.
  rendez_->Ref();
  SchedClosure([this]() {
    Env::Default()->SleepForMicroseconds(100 * 1000);
    Status s = rendez_->Send(Key(ALICE, kIncarnation, BOB, "a"),
                             Rendezvous::Args(), V(1.0), false);
    rendez_->Unref();
  });
  rendez_->Ref();
  SchedClosure([this]() {
    Env::Default()->SleepForMicroseconds(100 * 1000);
    Status s = rendez_->Send(Key(ALICE, kIncarnation, BOB, "b"),
                             Rendezvous::Args(), V(1.0), false);
    rendez_->Unref();
  });
  rendez_->Ref();
  SchedClosure([this]() {
    Env::Default()->SleepForMicroseconds(100 * 1000);
    Status s = rendez_->Send(Key(ALICE, kIncarnation, BOB, "c"),
                             Rendezvous::Args(), V(1.0), false);
    rendez_->Unref();
  });
  rendez_->Ref();
  SchedClosure([this]() {
    Env::Default()->SleepForMicroseconds(100 * 1000);
    rendez_->StartAbort(errors::Aborted(""));
    rendez_->Unref();
  });
  EXPECT_TRUE(errors::IsAborted(Run(rendez_)));
  Tensor out = V(-1);
  bool is_dead = false;
  EXPECT_TRUE(errors::IsAborted(rendez_->Recv(
      Key(BOB, kIncarnation, ALICE, "c"), Rendezvous::Args(), &out, &is_dead)));
  // At this point there can still be pending (albeit Aborted) Send
  // closures holding Refs on rendez_. We need to wait for them, or
  // else there can be a memory leak at termination.
  while (!rendez_->RefCountIsOne())
    ;
}

TEST_F(ExecutorTest, RecvInvalidDtype) {
  Graph* g = new Graph(OpRegistry::Global());
  // An input vector of type float of size 1.
  auto one = test::graph::Recv(g, "one", "float", ALICE, 1, BOB);
  // A floating point variable vector of size 1.
  auto var = test::graph::Var(g, DT_FLOAT, TensorShape({1}));
  // Initialize the variable with input.
  auto init = test::graph::Assign(g, var, one);
  // Output
  auto* two = test::graph::Send(g, var, "two", BOB, 1, ALICE);
  g->AddControlEdge(init, two);  // Ensures run after init.
  Create(g);
  Rendezvous* rendez = NewLocalRendezvous();
  // Send a double instead of float.
  TF_ASSERT_OK(rendez->Send(Key(ALICE, 1, BOB, "one"), Rendezvous::Args(),
                            VD(1.0), false));
  // Fails due to invalid dtype.
  EXPECT_TRUE(errors::IsInternal(Run(rendez)));
  Tensor output;
  bool is_dead;
  EXPECT_TRUE(errors::IsInternal(rendez->Recv(
      Key(BOB, 1, ALICE, "two"), Rendezvous::Args(), &output, &is_dead)));
  rendez->Unref();
}

TEST_F(ExecutorTest, RecvInvalidRefDtype) {
  Graph* g = new Graph(OpRegistry::Global());
  // A var that always produces an invalid dtype.
  auto var = test::graph::InvalidRefType(g, DT_FLOAT, DT_DOUBLE);
  test::graph::Send(g, var, "out", BOB, 1, ALICE);
  Create(g);
  Rendezvous* rendez = NewLocalRendezvous();
  EXPECT_TRUE(errors::IsInternal(Run(rendez)));
  Tensor output;
  bool is_dead;
  EXPECT_TRUE(errors::IsInternal(rendez->Recv(
      Key(BOB, 1, ALICE, "out"), Rendezvous::Args(), &output, &is_dead)));
  rendez->Unref();
}

}  // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc
new file mode 100644
index 00000000000..f1bcbf39567
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/graph_mgr.cc
@@ -0,0 +1,368 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/distributed_runtime/graph_mgr.h" + +#include + +#include "tensorflow/core/common_runtime/constant_folding.h" +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/graph_optimizer.h" +#include "tensorflow/core/common_runtime/memory_types.h" +#include "tensorflow/core/distributed_runtime/process_util.h" +#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h" +#include "tensorflow/core/framework/cancellation.h" +#include "tensorflow/core/framework/config.pb.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/graph_partition.h" +#include "tensorflow/core/graph/validate.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/worker.pb.h" + +namespace tensorflow { + +GraphMgr::GraphMgr(const WorkerEnv* worker_env) + : 
 worker_env_(worker_env), table_(5) {} + +GraphMgr::~GraphMgr() { + for (auto p : table_) p.second->Unref(); +} + +GraphMgr::Item::~Item() { + for (const auto& unit : this->units) { + CHECK_NOTNULL(unit.device); + delete unit.root; + delete unit.lib; + unit.device->op_segment()->RemoveHold(this->session); + } + delete this->lib_def; +} + +// NOTE: node->device_name() is not set by GraphConstructor. We +expect that NodeDef in GraphDef given to workers fully specifies +device names. +static string SplitByDevice(const Node* node) { + return node->assigned_device_name(); +} + +// Validates "gdef" device specifications. +static Status ValidateGraphDefForDevices(const GraphDef& gdef) { + DeviceNameUtils::ParsedName parsed; + for (const auto& ndef : gdef.node()) { + if (!DeviceNameUtils::ParseFullName(ndef.device(), &parsed)) { + return errors::InvalidArgument("Missing device name in: ", + SummarizeNodeDef(ndef)); + } + } + return Status::OK(); +} + +// Creates executors given a graph definition "gdef" of a "session". +// If a node in "gdef" is shared by other graphs in "session", the +// same op kernel is reused. E.g., typically a params node is shared +// by multiple graphs in a session. +// +// If "gdef" is assigned to multiple devices, extra nodes (e.g., +// send/recv nodes) may be added. The extra nodes' names are generated +// by calling "new_name(old_name)". +// +// "executors" are filled with one executor per device if success and +// the caller takes the ownership of returned executors. +Status GraphMgr::InitItem(const string& session, const GraphDef& gdef, + const GraphOptions& graph_options, Item* item) { + item->session = session; + item->lib_def = new FunctionLibraryDefinition(gdef.library()); + + TF_RETURN_IF_ERROR(ValidateGraphDefForDevices(gdef)); + + if (gdef.versions().producer() >= 5) { + // Validate the graph: we assume that merging two valid graphs + // should maintain graph validity.
+ TF_RETURN_IF_ERROR(graph::ValidateGraphDef(gdef, *item->lib_def)); + } + + // Constructs the graph out of "gdef". + Graph graph(item->lib_def); + GraphConstructorOptions opts; + opts.allow_internal_ops = true; + opts.expect_device_spec = true; + TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, gdef, &graph)); + + // Splits "graph" into multiple subgraphs by device names. + std::unordered_map partitions; + PartitionOptions popts; + popts.node_to_loc = SplitByDevice; + popts.new_name = [this](const string& prefix) { + mutex_lock l(mu_); + return strings::StrCat(prefix, "_G", next_id_++); + }; + popts.get_incarnation = [this](const string& name) { + Device* device = nullptr; + Status s = worker_env_->device_mgr->LookupDevice(name, &device); + if (s.ok()) { + return device->attributes().incarnation(); + } else { + return PartitionOptions::kIllegalIncarnation; + } + }; + popts.control_flow_added = true; + popts.scheduling_for_recvs = graph_options.enable_recv_scheduling(); + TF_RETURN_IF_ERROR(Partition(popts, &graph, &partitions)); + if (popts.scheduling_for_recvs) { + TF_RETURN_IF_ERROR(AddControlEdges(popts, &partitions)); + } + + thread::ThreadPool* pool = worker_env_->compute_pool; + auto runner = [pool](std::function fn) { pool->Schedule(fn); }; + + LocalExecutorParams params; + + Status s; + item->units.reserve(partitions.size()); + const auto& optimizer_opts = graph_options.optimizer_options(); + GraphOptimizer optimizer(optimizer_opts); + for (auto&& p : partitions) { + const string& device_name = p.first; + GraphDef* def = &p.second; + item->units.resize(item->units.size() + 1); + ExecutionUnit* unit = &(item->units.back()); + + // Find the device. + s = worker_env_->device_mgr->LookupDevice(device_name, &unit->device); + if (!s.ok()) break; + + // Construct the subgraph. + Graph* subgraph = new Graph(item->lib_def); + // Give the device an opportunity to rewrite its subgraph. 
+ unit->device->MaybeRewriteGraph(gdef.library(), def); + s = ConvertGraphDefToGraph(opts, *def, subgraph); + if (!s.ok()) { + delete subgraph; + break; + } + // Top-level nodes in the graph uses the op segment to cache + // kernels. Therefore, as long as the executor is alive, we need + // to ensure the kernels cached for the session are alive. + auto opseg = unit->device->op_segment(); + opseg->AddHold(session); + + // Function library runtime. + unit->lib = NewFunctionLibraryRuntime( + unit->device, runner, def->versions().producer(), item->lib_def, + graph_options.optimizer_options()); + + // Construct the root executor for the subgraph. + params.device = unit->device; + auto lib = unit->lib; + params.function_library = lib; + params.create_kernel = [session, lib, opseg](const NodeDef& ndef, + OpKernel** kernel) { + // Caches the kernel only if the node is stateful. + if (!lib->IsStateful(ndef.op())) { + return lib->CreateKernel(ndef, kernel); + } + auto create_fn = [lib, &ndef](OpKernel** kernel) { + return lib->CreateKernel(ndef, kernel); + }; + // Kernels created for subgraph nodes need to be cached. On + // cache miss, create_fn() is invoked to create a kernel based + // on the function library here + global op registry. + return opseg->FindOrCreate(session, ndef.name(), kernel, create_fn); + }; + params.delete_kernel = [lib](OpKernel* kernel) { + // If the node is stateful, opseg owns it. Otherwise, delete it. 
+ if (kernel && !lib->IsStateful(kernel->type_string())) { + delete kernel; + } + }; + + optimizer.Optimize(lib, &subgraph); + s = ValidateMemoryTypes(DeviceType(unit->device->device_type()), subgraph); + if (!s.ok()) { + delete subgraph; + break; + } + s = NewLocalExecutor(params, subgraph, &unit->root); + if (!s.ok()) { + break; + } + } + return s; +} + +Status GraphMgr::Register(const string& session, const GraphDef& gdef, + const GraphOptions& graph_options, string* handle) { + Item* item = new Item; + Status s = InitItem(session, gdef, graph_options, item); + if (!s.ok()) { + item->Unref(); + return s; + } + + // Inserts one item into table_. + { + mutex_lock l(mu_); + *handle = strings::Printf("%016llx", ++next_id_); + item->handle = *handle; + CHECK(table_.insert({*handle, item}).second); + } + return Status::OK(); +} + +Status GraphMgr::Deregister(const string& handle) { + Item* item = nullptr; + // Removes one item from table_. + { + mutex_lock l(mu_); + auto iter = table_.find(handle); + if (iter == table_.end()) { + return errors::Aborted("Graph handle is not found: ", handle, + ". Possibly, this worker just restarted."); + } + item = iter->second; + table_.erase(iter); + } + item->Unref(); + return Status::OK(); +} + +Status GraphMgr::DeregisterAll() { + std::vector items; + // Removes all items from table_. 
+ { + mutex_lock l(mu_); + for (const auto& entry : table_) { + items.push_back(entry.second); + } + table_.clear(); + } + for (auto item : items) { + item->Unref(); + } + return Status::OK(); +} + +Status GraphMgr::Execute(const string& handle, const int64 step_id, + const ExecutorOpts& opts, + StepStatsCollector* collector, + CancellationManager* cancellation_manager, + const NamedTensors& in, NamedTensors* out) { + Notification n; + Status status; + ExecuteAsync(handle, step_id, opts, collector, cancellation_manager, in, out, + [&n, &status](const Status& s) { + status = s; + n.Notify(); + }); + n.WaitForNotification(); + return status; +} + +void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id, + const ExecutorOpts& opts, + StepStatsCollector* collector, + CancellationManager* cancellation_manager, + const NamedTensors& in, NamedTensors* out, + StatusCallback done) { + // Lookup an item. Holds one ref while executing. + Item* item = nullptr; + { + mutex_lock l(mu_); + auto iter = table_.find(handle); + if (iter != table_.end()) { + item = iter->second; + item->Ref(); + } + } + + if (item == nullptr) { + done(errors::Aborted("Graph handle is not found: ", handle)); + return; + } + + const int num_units = item->units.size(); + CHECK_GE(num_units, 1); + + Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id); + + // Sends values specified by the caller. + for (const auto& p : in) { + const string& key = p.first; + const Tensor& val = p.second; + const Status s = rendezvous->Send(key, Rendezvous::Args(), val, false); + if (!s.ok()) { + done(s); + item->Unref(); + rendezvous->Unref(); + return; + } + } + + // Starts parallel Executors. + // + // NOTE: Transfer one ref of rendezvous and one ref of item to + // RunAllDone. 
+ ExecutorBarrier* barrier = new ExecutorBarrier( + num_units, rendezvous, std::bind(&ME::RunAllDone, this, item, rendezvous, + out, done, std::placeholders::_1)); + Executor::Args args; + { + mutex_lock l(mu_); + args.step_id = ++next_id_; + } + args.rendezvous = rendezvous; + args.cancellation_manager = cancellation_manager; + args.stats_collector = collector; + VLOG(1) << "Step " << args.step_id << " is for handle " << handle + << ", graph-local step " << step_id; + thread::ThreadPool* pool = worker_env_->compute_pool; + args.runner = [pool](std::function fn) { pool->Schedule(fn); }; + for (const auto& unit : item->units) { + unit.root->RunAsync(args, barrier->Get()); + } +} + +void GraphMgr::RunAllDone(Item* item, Rendezvous* rendezvous, NamedTensors* out, + StatusCallback done, Status s) { + if (s.ok()) { + // Receives values requested by the caller. + for (auto& p : *out) { + const string& key = p.first; + Tensor* val = &p.second; + bool is_dead = false; + s = rendezvous->Recv(key, Rendezvous::Args(), val, &is_dead); + if (is_dead) { + s = errors::InvalidArgument("The tensor returned for ", key, + " was not valid."); + } + if (!s.ok()) break; + } + } + done(s); + rendezvous->Unref(); + item->Unref(); +} + +} // end namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/graph_mgr.h b/tensorflow/core/distributed_runtime/graph_mgr.h new file mode 100644 index 00000000000..4300dbe3053 --- /dev/null +++ b/tensorflow/core/distributed_runtime/graph_mgr.h @@ -0,0 +1,147 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_GRAPH_MGR_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_GRAPH_MGR_H_ + +#include +#include + +#include "tensorflow/core/common_runtime/executor.h" +#include "tensorflow/core/distributed_runtime/worker_env.h" +#include "tensorflow/core/framework/cancellation.h" +#include "tensorflow/core/framework/config.pb.h" +#include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +class ExecutorOpts; +class StepStatsCollector; + +// GraphMgr keeps track of a set of graphs that are registered with a +// TensorFlow worker. Each registered graph is identified by a handle +// that is generated by GraphMgr and returned to the caller. +// +// After a successful registration, the caller executes a graph using +// the graph handle. Each execution is distinguished from others by a +// caller generated global unique id "step_id". Multiple executions +// can use the same graph concurrently and independently as long as +// "step_id" used are different. +// +// Multiple threads can call GraphMgr methods concurrently. +// +// E.g., +// GraphMgr gmgr(worker_env); +// string handle; +// TF_CHECK_OK(gmgr.Register("session", { graph computes c = a + b }, +// &handle)); +// GraphMgr::NamedTensors in = { { "a", Tensor({1, 2}) }, +// { "b", Tensor({3, 4}) } }; +// GraphMgr::NamedTensors out = { { "c", Tensor() } }; +// TF_CHECK_OK(gmgr.Execute(handle, 0x0001, in, &out)); +// EXPECT_EQ(out["c"], Tensor({4, 6})); +class GraphMgr { + public: + explicit GraphMgr(const WorkerEnv* worker_env); + ~GraphMgr(); + + // Registers a graph. 
Fills in "handle" + Status Register(const string& session, const GraphDef& gdef, + const GraphOptions& graph_options, string* handle); + + // Executes one step of a registered graph "handle". + // + // If "out" is not nullptr, "out" specifies all keys the execution + // should receive upon finish. + typedef std::map NamedTensors; + typedef std::function StatusCallback; + void ExecuteAsync(const string& handle, const int64 step_id, + const ExecutorOpts& opts, StepStatsCollector* collector, + CancellationManager* cancellation_manager, + const NamedTensors& in, NamedTensors* out, + StatusCallback done); + + // Synchronous wrapper. + Status Execute(const string& handle, const int64 step_id, + const ExecutorOpts& opts, + StepStatsCollector* step_stats_collector, + CancellationManager* cancellation_manager, + const NamedTensors& in, NamedTensors* out); + + // Deregisters a graph. + Status Deregister(const string& handle); + + // Deregister all graphs. + Status DeregisterAll(); + + private: + typedef GraphMgr ME; + + struct ExecutionUnit { + Device* device = nullptr; + Executor* root = nullptr; + FunctionLibraryRuntime* lib = nullptr; + }; + + struct Item : public core::RefCounted { + // TODO(zhifengc): Keeps a copy of the original graph if the need arises. + // TODO(zhifengc): Stats, updated by multiple runs potentially. + // TODO(zhifengc): Dup-detection. Ensure step_id only run once. + ~Item() override; + + // Session handle. + string session; + + // Graph handle. + string handle; + + // The definition of the library is shared by all partitions. + FunctionLibraryDefinition* lib_def = nullptr; + + // A graph is partitioned over multiple devices. Each partition + // has a root executor which may call into the runtime library. + std::vector units; + }; + + // Not owned. + const WorkerEnv* worker_env_; + + // Owned. + mutex mu_; + int64 next_id_ GUARDED_BY(mu_) = 0; + + // Table mapping graph handles to registered graphs.
+ // +// TODO(zhifengc): If the client does not call Deregister, we'll +// lose memory over time. We should implement a timeout-based +// mechanism to gc these graphs. + std::unordered_map table_; + + void RunAllDone(Item* item, Rendezvous* rendezvous, NamedTensors* out, + StatusCallback done, Status run_status); + + Status InitItem(const string& session, const GraphDef& gdef, + const GraphOptions& graph_options, Item* item); + + TF_DISALLOW_COPY_AND_ASSIGN(GraphMgr); +}; + +} // end namespace tensorflow + +#endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_GRAPH_MGR_H_ diff --git a/tensorflow/core/distributed_runtime/master.cc b/tensorflow/core/distributed_runtime/master.cc new file mode 100644 index 00000000000..2e8d6a18783 --- /dev/null +++ b/tensorflow/core/distributed_runtime/master.cc @@ -0,0 +1,413 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Master implements the service MasterService. +// +// A Master maintains the state of live graph computation +// sessions, each session orchestrates both local and remote devices +// to carry out the graph computation. +// +// A Master knows ahead of time local devices available as +// client devices. +// +// A Master discovers remote devices on-demand and keeps track of +// statistics of those remote devices.
+// +// Each session analyses the graph, places nodes across available +// devices, and ultimately drives the graph computation by initiating +// RunGraph on the workers. + +#include "tensorflow/core/distributed_runtime/master.h" + +#include +#include + +#include "tensorflow/core/distributed_runtime/process_util.h" +#include "tensorflow/core/distributed_runtime/remote_device.h" +#include "tensorflow/core/distributed_runtime/worker_cache.h" +#include "tensorflow/core/distributed_runtime/worker_interface.h" +#include "tensorflow/core/framework/graph_def_util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/master.pb.h" +#include "tensorflow/core/protobuf/worker.pb.h" +#include "tensorflow/core/public/session_options.h" + +namespace tensorflow { + +Master::Master(MasterEnv* env, double session_gc_seconds) + : env_(env), + last_1000_steps_(1000), + step_count_(0), + session_gc_seconds_(session_gc_seconds) { + // Right now, a master service must be co-located with a device. + // Otherwise, fetches do not work. + CHECK(!env->local_devices.empty()); + + if (session_gc_seconds_ > 0.0) { + SchedClosure([this]() { GC(); }); + } +} + +Master::~Master() { + { + mutex_lock l(mu_); + shutdown_ = true; + shutdown_cv_.notify_all(); + } + gc_stopped_.WaitForNotification(); +} + +void Master::GC() { + Env* env = Env::Default(); + while (true) { + mutex_lock l(mu_); + const int kTimeoutMilliseconds = 10 * 1000; // 10 seconds. 
+ WaitForMilliseconds(&l, &shutdown_cv_, kTimeoutMilliseconds); + if (shutdown_) { + break; + } + std::vector handles; + const int64 num_micros = static_cast(session_gc_seconds_ * 1000000); + for (const auto& entry : sessions_) { + auto lat = entry.second->last_access_time_usec(); + if (env->NowMicros() - lat > num_micros) { + handles.push_back(entry.first); + auto* sess = entry.second; + SchedClosure([this, sess]() { + LOG(WARNING) << "GC session " << sess->handle() << " after " + << session_gc_seconds_ << " seconds. " + << "Note that if you are starting multiple replicas " + << "on a staggered delay, session_gc_seconds may need " + << "to be raised."; + sess->Close(); + }); + } + } + for (const auto& handle : handles) sessions_.erase(handle); + } + gc_stopped_.Notify(); +} + +class DeviceFinder { + public: + explicit DeviceFinder( + const protobuf::RepeatedPtrField& device_filters, MasterEnv* env) + : env_(env) { + auto process_filter = [this](const string& filter) { + DeviceNameUtils::ParsedName parsed; + if (DeviceNameUtils::ParseFullName(filter, &parsed)) { + filters_.push_back(parsed); + } else { + LOG(FATAL) << "Skipping invalid filter: " << filter; + } + }; + for (const string& filter : device_filters) { + process_filter(filter); + } + } + + ~DeviceFinder() { + for (Device* dev : found_) delete dev; + } + + void Start() { + // Enumerates all known workers' target. A target name is a + // prefix of a device name. E.g., /job:mnist/replica:0/task:10. + std::vector workers; + env_->worker_cache->ListWorkers(&workers); + std::vector targets; + if (filters_.empty()) { + swap(workers, targets); + } else { + for (const string& name : workers) { + if (MatchFilters(name)) { + targets.push_back(name); + } + } + } + { + mutex_lock l(mu_); + num_pending_ = targets.size(); + if (num_pending_ == 0) { + pending_zero_.notify_all(); + } + } + // Talk to all workers to get the list of available devices. 
+ using std::placeholders::_1; + using std::placeholders::_2; + for (size_t i = 0; i < targets.size(); ++i) { + NewRemoteDevices(env_->env, env_->worker_cache, targets[i], + std::bind(&ME::WhenFound, this, _1, _2)); + } + } + + void Wait() { + mutex_lock l(mu_); + while (num_pending_ != 0) { + pending_zero_.wait(l); + } + } + + // The caller takes the ownership of returned remote devices. + void GetRemoteDevices(const std::vector& local, + std::vector* remote) { + std::unordered_set names(local.size()); + for (Device* dev : local) names.insert(dev->name()); + mutex_lock l(mu_); + for (Device* dev : found_) { + const string& name = dev->name(); + if (names.insert(name).second && MatchFilters(name)) { + remote->push_back(dev); + } else { + delete dev; + } + } + found_.clear(); + } + + private: + typedef DeviceFinder ME; + const MasterEnv* env_; + std::vector filters_; + + mutex mu_; + int num_pending_ GUARDED_BY(mu_); + condition_variable pending_zero_; + std::vector found_ GUARDED_BY(mu_); + + void WhenFound(const Status& s, std::vector* devices) { + mutex_lock l(mu_); + if (!s.ok()) { + LOG(ERROR) << "Master init: " << s; + } else { + found_.insert(found_.end(), devices->begin(), devices->end()); + devices->clear(); + } + --num_pending_; + if (num_pending_ == 0) { + pending_zero_.notify_all(); + } + } + + // Returns true iff the set of devices allowed by 'x' intersects + // with the set of devices allowed by 'y'. + bool Intersects(const DeviceNameUtils::ParsedName& x, + const DeviceNameUtils::ParsedName& y) { + return (!x.has_job || !y.has_job || x.job == y.job) && + (!x.has_replica || !y.has_replica || x.replica == y.replica) && + (!x.has_task || !y.has_task || x.task == y.task) && + (!x.has_type || !y.has_type || x.type == y.type) && + (!x.has_id || !y.has_id || x.id == y.id); + } + + // Returns true iff 'name' matches one of the filters_. 
+ bool MatchFilters(const string& name) { + if (filters_.empty()) return true; + DeviceNameUtils::ParsedName x; + if (DeviceNameUtils::ParseFullName(name, &x)) { + for (const auto& filter : filters_) { + if (Intersects(x, filter)) return true; + } + } + return false; + } + + TF_DISALLOW_COPY_AND_ASSIGN(DeviceFinder); +}; + +void Master::CreateSession(const CreateSessionRequest* req, + CreateSessionResponse* resp, MyClosure done) { + SchedClosure([this, req, resp, done]() { + Status status = ValidateExternalGraphDefSyntax(req->graph_def()); + if (status.ok()) { + // Ping all the workers and build the list of devices that the + // session will use. + DeviceFinder finder(req->config().device_filters(), env_); + finder.Start(); + finder.Wait(); + std::vector remote_devices; + finder.GetRemoteDevices(env_->local_devices, &remote_devices); + SessionOptions options; + options.config = req->config(); + MasterSessionInterface* session = + env_->master_session_factory(options, env_, &remote_devices); + GraphDef* gdef = + const_cast(req)->mutable_graph_def(); + Status create_status = session->Create(gdef); + if (!create_status.ok()) { + done(create_status); + return; + } + resp->set_session_handle(session->handle()); + // Insert into the session map. 
+ { + mutex_lock l(mu_); + CHECK(sessions_.insert({session->handle(), session}).second); + } + } + done(status); + }); +} + +void Master::ExtendSession(const ExtendSessionRequest* req, + ExtendSessionResponse* resp, MyClosure done) { + mu_.lock(); + MasterSessionInterface* session = nullptr; + session = gtl::FindPtrOrNull(sessions_, req->session_handle()); + if (session == nullptr) { + mu_.unlock(); + done(errors::Aborted("Session ", req->session_handle(), " is not found.")); + return; + } + + SchedClosure([session, req, resp, done]() { + Status status = ValidateExternalGraphDefSyntax(req->graph_def()); + if (status.ok()) { + status = session->Extend(req, resp); + } + done(status); + }); + mu_.unlock(); +} + +void Master::RunStep(CallOptions* opts, const RunStepRequest* req, + RunStepResponse* resp, MyClosure done) { + mu_.lock(); + uint64 start_time = env_->env->NowMicros(); + MasterSessionInterface* session = + gtl::FindPtrOrNull(sessions_, req->session_handle()); + if (session == nullptr) { + mu_.unlock(); + done(errors::Aborted("Session ", req->session_handle(), " is not found.")); + return; + } + + SchedClosure([this, start_time, session, opts, req, resp, done]() { + Status status = session->Run(opts, req, resp); + uint64 done_time = env_->env->NowMicros(); + done(status); + mutex_lock l(mu_); + last_1000_steps_.AddValue((done_time - start_time) / 1e9); + ++step_count_; + }); + mu_.unlock(); +} + +void Master::CloseSession(const CloseSessionRequest* req, + CloseSessionResponse* resp, MyClosure done) { + MasterSessionInterface* session = nullptr; + { + mu_.lock(); + auto iter = sessions_.find(req->session_handle()); + if (iter == sessions_.end()) { + mu_.unlock(); + done(errors::Aborted( + "Session ", req->session_handle(), + " is not found. Possibly, this master has restarted.")); + return; + } + session = iter->second; + sessions_.erase(iter); + mu_.unlock(); + } + + // Session Close() blocks on thread shutdown. 
Therefore, we need to + // delete it in non-critical thread. + SchedClosure([session, done]() { + Status s = session->Close(); + done(s); + }); +} + +void Master::ListDevices(const ListDevicesRequest* req, + ListDevicesResponse* resp, MyClosure done) { + SchedClosure([this, req, resp, done]() { + DeviceFinder finder({}, env_); + finder.Start(); + finder.Wait(); + std::vector remote_devices; + finder.GetRemoteDevices(env_->local_devices, &remote_devices); + for (Device* dev : env_->local_devices) { + *(resp->add_local_device()) = dev->attributes(); + } + for (Device* dev : remote_devices) { + *(resp->add_remote_device()) = dev->attributes(); + delete dev; + } + done(Status::OK()); + }); +} + +void Master::CleanupWorkers(const ResetRequest& reset) { + std::vector worker_names; + env_->worker_cache->ListWorkers(&worker_names); + if (!worker_names.empty()) { + const int num_workers = worker_names.size(); + std::vector n(num_workers); + CleanupAllRequest req; + (*req.mutable_container()) = reset.container(); + std::vector resp(num_workers); + int c = 0; + for (int i = 0; i < num_workers; ++i) { + auto worker = env_->worker_cache->CreateWorker(worker_names[i]); + if (worker) { + worker->CleanupAllAsync(&req, &resp[i], [&n, worker, c](Status s) { + TF_CHECK_OK(s); + delete worker; + n[c].Notify(); + }); + } else { + n[c].Notify(); + } + ++c; + } + for (int i = 0; i < n.size(); ++i) { + n[i].WaitForNotification(); + } + } +} + +void Master::Reset(const ResetRequest* req, ResetResponse* resp, + MyClosure done) { + // Vector to hold the session pointers present in the sessions_ + // (string->Session*) map. 
+ std::vector sessions; + { + mutex_lock l(mu_); + for (const auto& entry : sessions_) { + sessions.push_back(entry.second); + } + sessions_.clear(); + } + + CleanupWorkers(*req); + + SchedClosure([sessions, done]() { + Status s; + for (MasterSessionInterface* session : sessions) { + s.Update(session->Close()); + } + done(s); + }); +} + +} // end namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/master.h b/tensorflow/core/distributed_runtime/master.h new file mode 100644 index 00000000000..16e2c1a8664 --- /dev/null +++ b/tensorflow/core/distributed_runtime/master.h @@ -0,0 +1,98 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_H_ + +#include + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/distributed_runtime/call_options.h" +#include "tensorflow/core/distributed_runtime/master_env.h" +#include "tensorflow/core/distributed_runtime/master_session_interface.h" +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/master.pb.h" +#include "tensorflow/core/util/util.h" + +namespace tensorflow { + +class Master { + public: + explicit Master(MasterEnv* env, double session_gc_seconds); + virtual ~Master(); + + // Convenient typedef for a closure passing a Status. + typedef std::function MyClosure; + + void CreateSession(const CreateSessionRequest* req, + CreateSessionResponse* resp, MyClosure done); + + void ExtendSession(const ExtendSessionRequest* req, + ExtendSessionResponse* resp, MyClosure done); + + void RunStep(CallOptions* opts, const RunStepRequest* req, + RunStepResponse* resp, MyClosure done); + + void CloseSession(const CloseSessionRequest* req, CloseSessionResponse* resp, + MyClosure done); + + void ListDevices(const ListDevicesRequest* req, ListDevicesResponse* resp, + MyClosure done); + + void Reset(const ResetRequest* req, ResetResponse* resp, MyClosure done); + + private: + typedef Master ME; + + // Not owned. + MasterEnv* env_ = nullptr; + + // Owned. + mutex mu_; + + // shutdown_ is set to true by the dtor. + condition_variable shutdown_cv_; + bool shutdown_ GUARDED_BY(mu_) = false; + Notification gc_stopped_; + + // Maps session handles to sessions. + std::unordered_map sessions_ GUARDED_BY(mu_); + + // Moving average of step times. 
+ MovingAverage last_1000_steps_ GUARDED_BY(mu_); + + // Cumulative number of steps executed. + int64 step_count_ GUARDED_BY(mu_); + + // If a session is not active for this many seconds, it will be + // closed automatically. + const double session_gc_seconds_; + + // Call CleanupAll on all workers. + void CleanupWorkers(const ResetRequest& reset); + + // Cleanup unused session. + void GC(); + + TF_DISALLOW_COPY_AND_ASSIGN(Master); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_H_ diff --git a/tensorflow/core/distributed_runtime/master_env.h b/tensorflow/core/distributed_runtime/master_env.h new file mode 100644 index 00000000000..513442b7e6d --- /dev/null +++ b/tensorflow/core/distributed_runtime/master_env.h @@ -0,0 +1,66 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_ENV_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_ENV_H_ + +#include +#include + +#include "tensorflow/core/public/session_options.h" + +namespace tensorflow { + +class Device; +class Env; +class MasterSessionInterface; +class OpRegistryInterface; +class WorkerCacheInterface; + +// The master environment class, which holds a bag of pointers to +// per-master state. +// +// MasterEnv does not own its member pointers. 
+struct MasterEnv {
+  Env* env = nullptr;
+
+  // Object from which WorkerInterface instances can be obtained.
+  WorkerCacheInterface* worker_cache = nullptr;
+
+  // The operation definitions to use. Must be filled before use.
+  const OpRegistryInterface* ops = nullptr;
+
+  // Local devices co-located with this master. Devices are not owned
+  // by the master service.
+  //
+  // REQUIRES: !local_devices.empty().
+  std::vector local_devices;
+
+  // Factory for creating master sessions, given session options and a
+  // vector of devices.
+  //
+  // The caller of the function takes ownership of the returned
+  // `MasterSessionInterface`, which may not be null. Ownership of the
+  // `MasterEnv*` is retained by the caller. The callee takes
+  // ownership of the `std::vector*` argument, but does not
+  // take ownership of the `Device*` objects in the vector.
+  std::function*)>
+      master_session_factory;
+};
+
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_ENV_H_ diff --git a/tensorflow/core/distributed_runtime/master_interface.h b/tensorflow/core/distributed_runtime/master_interface.h new file mode 100644 index 00000000000..602cfbd8a30 --- /dev/null +++ b/tensorflow/core/distributed_runtime/master_interface.h @@ -0,0 +1,52 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_INTERFACE_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_INTERFACE_H_ + +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/protobuf/master.pb.h" + +namespace tensorflow { + +// Pure virtual interface for communicating with the TensorFlow Master service. +// +// This interface is intended to support in-process master +// implementations that do not require an RPC roundtrip. +class MasterInterface { + public: + virtual ~MasterInterface() {} + virtual Status CreateSession(const CreateSessionRequest* request, + CreateSessionResponse* response) = 0; + + virtual Status ExtendSession(const ExtendSessionRequest* request, + ExtendSessionResponse* response) = 0; + + virtual Status RunStep(const RunStepRequest* request, + RunStepResponse* response) = 0; + + virtual Status CloseSession(const CloseSessionRequest* request, + CloseSessionResponse* response) = 0; + + virtual Status ListDevices(const ListDevicesRequest* request, + ListDevicesResponse* response) = 0; + + virtual Status Reset(const ResetRequest* request, + ResetResponse* response) = 0; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_INTERFACE_H_ diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc new file mode 100644 index 00000000000..9535f5db473 --- /dev/null +++ b/tensorflow/core/distributed_runtime/master_session.cc @@ -0,0 +1,942 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/distributed_runtime/master_session.h" + +#include +#include +#include + +#include "tensorflow/core/common_runtime/device_set.h" +#include "tensorflow/core/distributed_runtime/master_env.h" +#include "tensorflow/core/distributed_runtime/master_session_interface.h" +#include "tensorflow/core/distributed_runtime/process_util.h" +#include "tensorflow/core/distributed_runtime/simple_graph_execution_state.h" +#include "tensorflow/core/distributed_runtime/worker_cache.h" +#include "tensorflow/core/distributed_runtime/worker_interface.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/graph/graph_partition.h" +#include "tensorflow/core/graph/tensor_id.h" +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/tracing.h" +#include 
"tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/master.pb.h" +#include "tensorflow/core/public/session_options.h" + +namespace tensorflow { + +namespace { +// A little bit of per-step state. +struct PerStepState { + Microseconds start_micros = Microseconds(0); + Microseconds end_micros = Microseconds(0); + std::vector step_stats; // per partition +}; + +// A session encapsulates a graph computation (resource allocation, +// placement, execution, etc.). +class MasterSession : public MasterSessionInterface { + public: + // This session encapsulates the graph computation for a graph. + // + // The session places nodes on devices in "remote_devs" and executes + // operations on these devices. + // + // The caller takes ownership of all remote devices. + MasterSession(const SessionOptions& options, const MasterEnv* env, + std::vector* remote_devs); + + // Initialize the Session for "def". Must be called before Extend(), + // Run(), or Close(). + // + // The callee may clear "def". + Status Create(GraphDef* def) override; + + // Returns the session handle. + const string& handle() const override { return handle_; } + + // Returns the last access time (the number of micro-seconds since + // some fixed point in time) of this session. + uint64 last_access_time_usec() const override { + return last_access_time_usec_.load(); + } + + // Attempt to extend the graph according to the given "req". + // (See master.proto for details of valid extensions.) + // + // PRECONDITION: The current version of this session's graph + // is "req->current_graph_version". + // + // POSTCONDITION: The current version of this session's graph + // is "resp->new_graph_version". + // + // Extend() may block the caller thread for a long time. + Status Extend(const ExtendSessionRequest* req, + ExtendSessionResponse* resp) override; + + // Run one step. 
+  Status Run(CallOptions* opts, const RunStepRequest* req,
+             RunStepResponse* resp) override;
+
+  // Close this session and delete "*this". Returns OK if all known
+  // states are cleaned up successfully.
+  //
+  // Close() may block the caller thread for a long time.
+  Status Close() override;
+
+ private:
+  SessionOptions session_opts_;
+
+  // Not owned.
+  const MasterEnv* env_;
+
+  // The opaque session handle.
+  const string handle_;
+
+  // Owned.
+  std::vector remote_devs_;
+
+  // The device set used by this session.
+  DeviceSet devices_;
+
+  // TODO(zhifengc): Support Extend().
+  //
+  // 'func_def_lib_' is a copy of the initial graph def's library.
+  // 'flib_def_' is an index structure of 'func_def_lib_' keyed by
+  // function names.
+  FunctionDefLibrary func_def_lib_;
+  FunctionLibraryDefinition* flib_def_ = nullptr;
+
+  std::atomic_ulong last_access_time_usec_;
+
+  mutex mu_;
+  std::unique_ptr execution_state_;
+  int64 graph_version_;
+
+  int32 steps_since_last_scheduling_ GUARDED_BY(mu_) = 0;
+  int32 scheduling_period_steps_ GUARDED_BY(mu_) = 10;
+
+  // We keep a map from a signature of a run request to the
+  // ReffedClientGraph that can execute it. We keep up to one old copy
+  // of each ReffedClientGraph around because if it gets deallocated
+  // before a new substitute has been created, Variables can go out of
+  // scope and lose their state.
+  class ReffedClientGraph;
+  typedef std::unordered_map RCGMap;
+  RCGMap runs_ GUARDED_BY(mu_);
+  RCGMap obsolete_ GUARDED_BY(mu_);
+
+  // Active RunStep calls.
+  condition_variable num_running_is_zero_;
+  int32 num_running_ GUARDED_BY(mu_) = 0;
+
+  std::unordered_map subgraph_execution_counts_ GUARDED_BY(mu_);
+
+  // We need to ensure that certain nodes added (e.g., send and recv
+  // nodes) are unique across all sub-graphs within this session.
+  int64 next_node_id_ GUARDED_BY(mu_) = 0;
+
+  // Private dtor. The client must call Close(). 
+ virtual ~MasterSession(); + + Status StartStep(const RunStepRequest& req, BuildGraphOptions* opts, + int64* count, ReffedClientGraph** graph); + void ClearRunsTable(std::vector* to_unref, + RCGMap* rcg_map) EXCLUSIVE_LOCKS_REQUIRED(mu_); + Status DoRunWithLocalExecution(CallOptions* opts, const RunStepRequest* req, + RunStepResponse* resp); + void UpdateLastAccessTime(); + + TF_DISALLOW_COPY_AND_ASSIGN(MasterSession); +}; + +// Session wraps ClientGraph in a reference counted object. This way, +// Session can clear up the cache mapping Run requests to compiled +// graphs while the compiled graph is still being used. +// +// TODO(zhifengc): Cleanup this class. It's becoming messy. +class MasterSession::ReffedClientGraph : public core::RefCounted { + public: + ReffedClientGraph(const string& handle, const BuildGraphOptions& bopts, + ClientGraph* cg, const GraphOptions& graph_opts) + : session_handle_(handle), + client_graph_(cg), + bopts_(bopts), + graph_opts_(graph_opts) { + VLOG(1) << "Created ReffedClientGraph for node with " + << client_graph_->graph.num_node_ids(); + + const string key = + strings::StrCat("{", str_util::Join(bopts.feed_endpoints, ","), "},{", + str_util::Join(bopts.target_nodes, ","), "},{", + str_util::Join(bopts.fetch_endpoints, ","), "}"); + // TODO(mrry): Publish information about the graph (such as + // timelines, the pruned graph, statistics, etc.). + } + + ~ReffedClientGraph() override { + delete client_graph_; + DeregisterPartitions(); + } + + const ClientGraph* client_graph() { return client_graph_; } + + // Local execution methods. + + // Partitions the graph into subgraphs and registers them on + // workers. + Status RegisterPartitions(const MasterEnv* env, const PartitionOptions& popts, + const FunctionDefLibrary& func_def_lib); + + // Runs one step of all partitions. 
+  Status RunPartitions(const MasterEnv* env, int64 step_id,
+                       int64 execution_count,
+                       SimpleGraphExecutionState* execution_state,
+                       PerStepState* pss, CallOptions* opts,
+                       const RunStepRequest& req, RunStepResponse* resp);
+
+  // Calls workers to cleanup states for the step "step_id". Waits
+  // till all cleanup rpcs complete.
+  Status CleanupPartitions(int64 step_id);
+
+  // TODO(mrry): Runtime statistics collection.
+
+ private:
+  const string session_handle_;
+  ClientGraph* const client_graph_ = nullptr;
+  std::unordered_set nodes_needing_input_mapping_;
+  BuildGraphOptions bopts_;
+  const GraphOptions graph_opts_;
+
+  // Graph partitioned into per-location subgraphs.
+  struct Part {
+    // Worker name.
+    string name;
+
+    // Graph definition.
+    GraphDef gdef;
+
+    // Maps feed names to rendezvous keys. Empty most of the time.
+    std::unordered_map feed_key;
+
+    // Maps rendezvous keys to fetch names. Empty most of the time.
+    std::unordered_map key_fetch;
+
+    // The interface to the worker. Owned.
+    WorkerInterface* worker = nullptr;
+
+    // After registration with the worker, graph_handle identifies
+    // this partition on the worker.
+    string graph_handle;
+
+    Part() : feed_key(3), key_fetch(3) {}
+  };
+
+  // partitions_ is immutable after RegisterPartitions() call
+  // finishes. RunPartitions() can access partitions_ safely without
+  // acquiring locks.
+  std::vector partitions_;
+
+  mutable mutex mu_;
+
+  // Partition initialization and registration only needs to happen
+  // once. init_started_ && !init_done_ indicates the initialization
+  // is ongoing.
+  bool init_started_ GUARDED_BY(mu_) = false;
+  Notification init_done_;
+
+  // init_result_ remembers the initialization error if any.
+  Status init_result_ GUARDED_BY(mu_);
+
+  // Send/Recv nodes that are the result of client-added
+  // feeds and fetches must be tracked so that the tensors
+  // can be added to the local rendezvous. 
+ static void TrackFeedsAndFetches(Part* part, const PartitionOptions& popts); + + // The actual graph partitioning and registration implementation. + Status DoRegisterPartitions(const MasterEnv* env, + const PartitionOptions& popts, + const FunctionDefLibrary& func_def_lib); + + // Deregisters the partitions on the workers. Called in the + // destructor and does not wait for the rpc completion. + void DeregisterPartitions(); + + TF_DISALLOW_COPY_AND_ASSIGN(ReffedClientGraph); +}; + +Status MasterSession::ReffedClientGraph::RegisterPartitions( + const MasterEnv* env, const PartitionOptions& popts, + const FunctionDefLibrary& func_def_lib) { + { // Ensure register once. + mu_.lock(); + if (!init_started_) { + init_started_ = true; + mu_.unlock(); + Status s = DoRegisterPartitions(env, popts, func_def_lib); + mu_.lock(); + init_result_ = s; + init_done_.Notify(); + } else { + mu_.unlock(); + init_done_.WaitForNotification(); + mu_.lock(); + } + Status result = init_result_; + mu_.unlock(); + return result; + } +} + +static string SplitByWorker(const Node* node) { + string task; + string device; + CHECK(DeviceNameUtils::SplitDeviceName(node->assigned_device_name(), &task, + &device)) + << "node: " << node->name() << " dev: " << node->assigned_device_name(); + return task; +} + +void MasterSession::ReffedClientGraph::TrackFeedsAndFetches( + Part* part, const PartitionOptions& popts) { + for (int i = 0; i < part->gdef.node_size(); ++i) { + NodeDef* ndef = part->gdef.mutable_node(i); + const bool is_recv = ndef->op() == "_Recv"; + const bool is_send = ndef->op() == "_Send"; + + if (is_recv || is_send) { + string name; + TF_CHECK_OK(GetNodeAttr(*ndef, "tensor_name", &name)); + string send_device; + TF_CHECK_OK(GetNodeAttr(*ndef, "send_device", &send_device)); + string recv_device; + TF_CHECK_OK(GetNodeAttr(*ndef, "recv_device", &recv_device)); + uint64 send_device_incarnation; + TF_CHECK_OK( + GetNodeAttr(*ndef, "send_device_incarnation", + 
reinterpret_cast(&send_device_incarnation))); + const string& key = + Rendezvous::CreateKey(send_device, send_device_incarnation, + recv_device, name, FrameAndIter(0, 0)); + + // Only send/recv nodes that were added as feeds and fetches + // (client-terminated) should be tracked. Other send/recv nodes + // are for transferring data between partitions / memory spaces. + bool client_terminated; + TF_CHECK_OK(GetNodeAttr(*ndef, "client_terminated", &client_terminated)); + if (client_terminated) { + if (is_recv) { + part->feed_key.insert({name, key}); + } else { + part->key_fetch.insert({key, name}); + } + } + } + } +} + +Status MasterSession::ReffedClientGraph::DoRegisterPartitions( + const MasterEnv* env, const PartitionOptions& popts, + const FunctionDefLibrary& func_def_lib) { + // Partition the graph. + Status s; + std::unordered_map graph_partitions; + s = Partition(popts, &client_graph_->graph, &graph_partitions); + if (!s.ok()) return s; + partitions_.reserve(graph_partitions.size()); + for (auto& name_def : graph_partitions) { + partitions_.resize(partitions_.size() + 1); + Part* part = &partitions_.back(); + part->name = name_def.first; + part->gdef.Swap(&name_def.second); + // For simplicity, we ship the library completely to every worker. 
+ *(part->gdef.mutable_library()) = func_def_lib; + TrackFeedsAndFetches(part, popts); + part->worker = env->worker_cache->CreateWorker(part->name); + if (part->worker == nullptr) { + s = errors::NotFound("worker ", part->name); + break; + } + } + if (!s.ok()) { + for (Part& part : partitions_) { + delete part.worker; + } + return s; + } + struct Call { + RegisterGraphRequest req; + RegisterGraphResponse resp; + Status status; + Notification done; + }; + const int num = partitions_.size(); + gtl::InlinedVector calls(num); + for (int i = 0; i < num; ++i) { + const Part& part = partitions_[i]; + Call* c = &calls[i]; + c->req.set_session_handle(session_handle_); + *c->req.mutable_graph_def() = part.gdef; + *c->req.mutable_graph_options() = graph_opts_; + VLOG(2) << "Register " << part.gdef.DebugString(); + auto cb = [c](const Status& s) { + c->status = s; + c->done.Notify(); + }; + part.worker->RegisterGraphAsync(&c->req, &c->resp, cb); + } + for (int i = num - 1; i >= 0; --i) { + Call* c = &calls[i]; + c->done.WaitForNotification(); + s.Update(c->status); + partitions_[i].graph_handle = c->resp.graph_handle(); + } + return s; +} + +static bool CopyIfNeeded(TensorProto* in, TensorProto* out) { + if (in->tensor_content().empty()) { + // If the tensor is not encoded in tensor_content or contains 0 + // elements, we can return it to the client directly. + out->Swap(in); + } else { + Tensor t(in->dtype()); + if (!t.FromProto(cpu_allocator(), *in)) return false; + t.AsProtoField(out); + } + return true; +} + +// Helper class to manage "num" parallel RunGraph calls. +class RunManyGraphs { + public: + explicit RunManyGraphs(int num) : calls_(num), num_pending_(num) {} + + ~RunManyGraphs() {} + + // Returns the index-th call. + struct Call { + CallOptions opts; + RunGraphRequest req; + RunGraphResponse resp; + }; + Call* get(int index) { return &calls_[index]; } + + // When the index-th call is done, updates the overall status. 
+ void WhenDone(int index, const Status& s) { + TRACEPRINTF("Partition %d %s", index, s.ToString().c_str()); + { + mutex_lock l(mu_); + if (!s.ok()) { + UpdateStatusLocked(s); + } + --num_pending_; + cv_pending_.notify_all(); + } + } + + void StartCancel() { + mutex_lock l(mu_); + UpdateStatusLocked(errors::Cancelled("RunManyGraphs")); + } + + void Wait() { + mutex_lock l(mu_); + while (num_pending_ > 0) { + cv_pending_.wait(l); + } + } + + Status status() const { + mutex_lock l(mu_); + return status_; + } + + private: + gtl::InlinedVector calls_; + + // TODO(jeff,sanjay): Replace bookkeeping state here with a + // BlockingCounter abstraction that we define in + // tensorflow/core/lib/core. + mutable mutex mu_; + condition_variable cv_pending_; + int num_pending_; + Status status_ GUARDED_BY(mu_); + + void UpdateStatusLocked(const Status& s) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (status_.ok()) { + status_ = s; + for (Call& call : calls_) { + call.opts.StartCancel(); + } + } + } + + TF_DISALLOW_COPY_AND_ASSIGN(RunManyGraphs); +}; + +Status MasterSession::ReffedClientGraph::RunPartitions( + const MasterEnv* env, int64 step_id, int64 execution_count, + SimpleGraphExecutionState* execution_state, PerStepState* pss, + CallOptions* call_opts, const RunStepRequest& req, RunStepResponse* resp) { + VLOG(2) << "RunPartitions step_id " << step_id << " execution_count " + << execution_count; + // Builds an index for feeds provided by the client. + std::unordered_map + feeds(3); + + for (const auto& feed : req.feed()) { + if (!feeds.insert({feed.name(), &feed.tensor()}).second) { + return errors::InvalidArgument("Duplicated feeds: ", feed.name()); + } + } + + // Prepares a number of calls to workers. One call per partition. 
+ ExecutorOpts exec_opts; + const int num = partitions_.size(); + RunManyGraphs calls(num); + + for (int i = 0; i < num; ++i) { + const Part& part = partitions_[i]; + RunManyGraphs::Call* c = calls.get(i); + c->req.set_graph_handle(part.graph_handle); + c->req.set_step_id(step_id); + *c->req.mutable_exec_opts() = exec_opts; + // If any feeds are provided, send the feed values together + // in the RunGraph request. + for (const auto& feed_key : part.feed_key) { + const string& feed = feed_key.first; + const string& key = feed_key.second; + const TensorProto* val = feeds[feed]; + if (val == nullptr) { + return errors::InvalidArgument("No feed is provided for feed=", feed, + ", key=", key); + } + auto* send = c->req.add_send(); + send->set_key(key); + *(send->mutable_val()) = *val; // TODO(mrry): make it faster if needed. + } + for (const auto& key_fetch : part.key_fetch) { + const string& key = key_fetch.first; + c->req.add_recv_key(key); + } + } + + // Issues RunGraph calls. + for (int i = 0; i < num; ++i) { + const Part& part = partitions_[i]; + RunManyGraphs::Call* call = calls.get(i); + TRACEPRINTF("Partition %d %s", i, part.name.c_str()); + part.worker->RunGraphAsync( + &call->opts, &call->req, &call->resp, + std::bind(&RunManyGraphs::WhenDone, &calls, i, std::placeholders::_1)); + } + + // Waits for the RunGraph calls. + call_opts->SetCancelCallback([&calls]() { calls.StartCancel(); }); + calls.Wait(); + call_opts->ClearCancelCallback(); + + // Collects fetches. 
+ Status status = calls.status(); + if (status.ok()) { + for (int i = 0; i < num; ++i) { + const Part& part = partitions_[i]; + for (auto& recv : *(calls.get(i)->resp.mutable_recv())) { + auto* ret = resp->add_tensor(); + auto iter = part.key_fetch.find(recv.key()); + if (iter == part.key_fetch.end()) { + status.Update(errors::Internal("Unexpected fetch key: ", recv.key())); + break; + } + const string& fetch = iter->second; + ret->set_name(fetch); + if (!CopyIfNeeded(recv.mutable_val(), ret->mutable_tensor())) { + status.Update( + errors::Internal("Unexpected unparseable tensor: ", recv.key())); + break; + } + } + if (calls.get(i)->resp.has_step_stats()) { + pss->step_stats[i].Swap(calls.get(i)->resp.mutable_step_stats()); + } + } + } + return status; +} + +Status MasterSession::ReffedClientGraph::CleanupPartitions(int64 step_id) { + struct Call { + CleanupGraphRequest req; + CleanupGraphResponse resp; + Notification done; + Status status; + }; + const int num = partitions_.size(); + gtl::InlinedVector calls(num); + for (int i = 0; i < num; ++i) { + const Part& part = partitions_[i]; + Call* c = &calls[i]; + c->req.set_step_id(step_id); + part.worker->CleanupGraphAsync(&c->req, &c->resp, [c](const Status& s) { + c->status = s; + c->done.Notify(); + }); + } + Status s; + for (int i = num - 1; i >= 0; --i) { + Call* c = &calls[i]; + c->done.WaitForNotification(); + s.Update(c->status); + } + return s; +} + +// Makes async calls to workers without waiting deregistering subgraphs. +void MasterSession::ReffedClientGraph::DeregisterPartitions() { + struct Call { + DeregisterGraphRequest req; + DeregisterGraphResponse resp; + }; + for (Part& part : partitions_) { + Call* c = new Call; + c->req.set_graph_handle(part.graph_handle); + WorkerInterface* w = part.worker; + auto cb = [c, w](const Status& s) { + if (!s.ok()) { + // This error is potentially benign, so we don't log at the + // error level. 
+ LOG(INFO) << "DeregisterGraph error: " << s; + } + delete c; + delete w; + }; + w->DeregisterGraphAsync(&c->req, &c->resp, cb); + } +} + +void BuildBuildGraphOptions(const RunStepRequest& req, + BuildGraphOptions* opts) { + for (const auto& feed : req.feed()) { + opts->feed_endpoints.push_back(feed.name()); + } + for (const auto& fetch : req.fetch()) { + // TODO(touts): handle ref: + opts->fetch_endpoints.push_back(fetch); + } + for (const auto& target : req.target()) { + opts->target_nodes.push_back(target); + } + + std::sort(opts->feed_endpoints.begin(), opts->feed_endpoints.end()); + std::sort(opts->target_nodes.begin(), opts->target_nodes.end()); + std::sort(opts->fetch_endpoints.begin(), opts->fetch_endpoints.end()); +} + +uint64 HashBuildGraphOptions(const BuildGraphOptions& opts) { + uint64 h = 0x2b992ddfa23249d6ull; + for (const string& name : opts.feed_endpoints) { + h = Hash64(name.c_str(), name.size(), h); + } + for (const string& name : opts.target_nodes) { + h = Hash64(name.c_str(), name.size(), h); + } + for (const string& name : opts.fetch_endpoints) { + h = Hash64(name.c_str(), name.size(), h); + } + return h; +} + +string BuildGraphOptionsString(const BuildGraphOptions& opts) { + string buf; + for (const string& name : opts.feed_endpoints) { + strings::StrAppend(&buf, " FdE: ", name); + } + strings::StrAppend(&buf, "\n"); + for (const string& name : opts.target_nodes) { + strings::StrAppend(&buf, " TN: ", name); + } + strings::StrAppend(&buf, "\n"); + for (const string& name : opts.fetch_endpoints) { + strings::StrAppend(&buf, " FeE: ", name); + } + strings::StrAppend(&buf, "\n"); + return buf; +} + +MasterSession::MasterSession(const SessionOptions& opt, const MasterEnv* env, + std::vector* remote_devs) + : session_opts_(opt), + env_(env), + handle_(strings::FpToString(random::New64())), + graph_version_(0), + runs_(5) { + UpdateLastAccessTime(); + + swap(remote_devs_, *remote_devs); + VLOG(1) << "Session " << handle_ << " #local " << 
env->local_devices.size() + << " #remote " << remote_devs_.size(); + for (Device* d : remote_devs_) { + devices_.AddDevice(d); + } + int num_local_devices = 0; + for (Device* d : env->local_devices) { + devices_.AddDevice(d); + if (num_local_devices == 0) { + // Uses the first local device as the client device. + devices_.set_client_device(d); + } + num_local_devices++; + } +} + +MasterSession::~MasterSession() { + for (const auto& iter : runs_) iter.second->Unref(); + for (const auto& iter : obsolete_) iter.second->Unref(); + delete flib_def_; + for (Device* dev : remote_devs_) delete dev; +} + +void MasterSession::UpdateLastAccessTime() { + last_access_time_usec_.store(Env::Default()->NowMicros()); +} + +Status MasterSession::Create(GraphDef* graph_def) { + // Keeps a copy of graph_def->library() and flib_def_ serves the + // OpRegistryInterface used by the SimpleGraphExecutionState to construct the + // pre-partitioned graphs during DoRunWithLocalExecution(). + func_def_lib_.Swap(graph_def->mutable_library()); + flib_def_ = new FunctionLibraryDefinition(func_def_lib_); + + SimpleGraphExecutionStateOptions options; + options.device_set = &devices_; + options.session_options = &session_opts_; + execution_state_.reset(new SimpleGraphExecutionState(flib_def_, options)); + TF_RETURN_IF_ERROR(execution_state_->Create(graph_def)); + + return Status::OK(); +} + +Status MasterSession::Extend(const ExtendSessionRequest* req, + ExtendSessionResponse* resp) { + UpdateLastAccessTime(); + std::unique_ptr old_execution_state; + { + mutex_lock l(mu_); + // TODO(mrry): Redesign the locking with reader/writer locks to prevent + // starvation due to concurrent steps being issued. This is not + // immediately important because we expect Extend to be used in + // development/interactive exploration, and not during high-throughput + // training. 
+ while (num_running_ != 0) { + num_running_is_zero_.wait(l); + } + + if (graph_version_ != req->current_graph_version()) { + return errors::Aborted("Current version is ", graph_version_, + " but caller expected ", + req->current_graph_version(), "."); + } + + CHECK(execution_state_); + SimpleGraphExecutionState* extended_execution_state = nullptr; + Status s = + execution_state_->Extend(req->graph_def(), &extended_execution_state); + if (s.ok()) { + CHECK(extended_execution_state); + old_execution_state = + std::move(execution_state_); // Will be released outside the lock + execution_state_.reset(extended_execution_state); + ++graph_version_; + resp->set_new_graph_version(graph_version_); + } + + return s; + } +} + +Status MasterSession::StartStep(const RunStepRequest& req, + BuildGraphOptions* opts, int64* count, + ReffedClientGraph** rcg) { + BuildBuildGraphOptions(req, opts); + const uint64 hash = HashBuildGraphOptions(*opts); + ReffedClientGraph* to_unref = nullptr; + { + mutex_lock l(mu_); + // Keep track of how many times this subgraph has been executed in + // this session. + int64* c = &subgraph_execution_counts_[hash]; + *count = (*c)++; + auto iter = runs_.find(hash); + if (iter == runs_.end()) { + // We have not seen this subgraph before. Build the subgraph and + // cache it. 
+ VLOG(1) << "Unseen hash " << hash << " for " + << BuildGraphOptionsString(*opts); + ClientGraph* client_graph = nullptr; + TF_RETURN_IF_ERROR(execution_state_->BuildGraph(*opts, &client_graph)); + auto entry = new ReffedClientGraph(handle_, *opts, client_graph, + session_opts_.config.graph_options()); + iter = runs_.insert({hash, entry}).first; + auto obs_iter = obsolete_.find(hash); + if (obs_iter != obsolete_.end()) { + to_unref = obs_iter->second; + obsolete_.erase(obs_iter); + } + VLOG(1) << "Preparing to execute new graph"; + } + *rcg = iter->second; + (*rcg)->Ref(); + } + if (to_unref) to_unref->Unref(); + return Status::OK(); +} + +void MasterSession::ClearRunsTable(std::vector* to_unref, + RCGMap* rcg_map) { + VLOG(1) << "Discarding all reffed graphs"; + for (auto p : *rcg_map) { + ReffedClientGraph* rcg = p.second; + if (to_unref) { + to_unref->push_back(rcg); + } else { + rcg->Unref(); + } + } + rcg_map->clear(); +} + +Status MasterSession::Run(CallOptions* opts, const RunStepRequest* req, + RunStepResponse* resp) { + UpdateLastAccessTime(); + { + mutex_lock l(mu_); + ++num_running_; + } + Status status = DoRunWithLocalExecution(opts, req, resp); + { + mutex_lock l(mu_); + --num_running_; + if (num_running_ == 0) { + num_running_is_zero_.notify_all(); + } + } + return status; +} + +Status MasterSession::DoRunWithLocalExecution(CallOptions* opts, + const RunStepRequest* req, + RunStepResponse* resp) { + VLOG(2) << "DoRunWithLocalExecution " + << "req: " << req->DebugString(); + PerStepState pss; + pss.start_micros = Env::Default()->NowMicros(); + + // Prepare. + BuildGraphOptions bgopts; + ReffedClientGraph* rcg = nullptr; + int64 count = 0; + TF_RETURN_IF_ERROR(StartStep(*req, &bgopts, &count, &rcg)); + + // Unref "rcg" when out of scope. + core::ScopedUnref unref(rcg); + + // Registers subgraphs if haven't done so. 
+ PartitionOptions popts; + popts.node_to_loc = SplitByWorker; + popts.new_name = [this](const string& prefix) { + mutex_lock l(mu_); + return strings::StrCat(prefix, "_S", next_node_id_++); + }; + popts.get_incarnation = [this](const string& name) { + Device* d = devices_.FindDeviceByName(name); + if (d == nullptr) { + return PartitionOptions::kIllegalIncarnation; + } else { + return d->attributes().incarnation(); + } + }; + popts.control_flow_added = false; + // TODO(mrry): Enable DT_BFLOAT16 casting. + // TODO(mrry): Enable recv scheduling. + TF_RETURN_IF_ERROR(rcg->RegisterPartitions(env_, popts, func_def_lib_)); + + // Keeps the highest 8 bits 0x01: we reserve some bits of the + // step_id for future use. + const uint64 step_id = (random::New64() & ((1uLL << 56) - 1)) | (1uLL << 56); + TRACEPRINTF("stepid %llu", step_id); + + TF_RETURN_IF_ERROR(rcg->RunPartitions( + env_, step_id, count, execution_state_.get(), &pss, opts, *req, resp)); + + pss.end_micros = Env::Default()->NowMicros(); + + // Schedule post-processing and cleanup to be done async. + rcg->Ref(); + // TODO(tucker): We're doing the stats processing prior to returning + // the response to the client. Ensure it's safe to do so, then schedule + // in a closure. 
+ SchedClosure([this, rcg, step_id]() { + Status s = rcg->CleanupPartitions(step_id); + if (!s.ok()) { + LOG(ERROR) << "Cleanup partition error: " << s; + } + rcg->Unref(); + }); + + return Status::OK(); +} + +Status MasterSession::Close() { + std::vector to_unref; + { + mutex_lock l(mu_); + while (num_running_ != 0) { + num_running_is_zero_.wait(l); + } + ClearRunsTable(&to_unref, &runs_); + ClearRunsTable(&to_unref, &obsolete_); + } + for (ReffedClientGraph* rcg : to_unref) rcg->Unref(); + delete this; + return Status::OK(); +} + +} // end namespace + +namespace internal { + +MasterSessionInterface* NewMasterSession(const SessionOptions& options, + const MasterEnv* env, + std::vector* remote_devs) { + return new MasterSession(options, env, remote_devs); +} + +} // end namespace internal +} // end namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/master_session.h b/tensorflow/core/distributed_runtime/master_session.h new file mode 100644 index 00000000000..dc24c5c6711 --- /dev/null +++ b/tensorflow/core/distributed_runtime/master_session.h @@ -0,0 +1,38 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_H_ + +#include + +#include "tensorflow/core/public/session_options.h" + +namespace tensorflow { + +class Device; +class MasterEnv; +class MasterSessionInterface; + +namespace internal { + +MasterSessionInterface* NewMasterSession(const SessionOptions& options, + const MasterEnv* env, + std::vector* remote_devs); + +} // namespace internal +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_H_ diff --git a/tensorflow/core/distributed_runtime/master_session_interface.h b/tensorflow/core/distributed_runtime/master_session_interface.h new file mode 100644 index 00000000000..9d6516bfc59 --- /dev/null +++ b/tensorflow/core/distributed_runtime/master_session_interface.h @@ -0,0 +1,76 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_INTERFACE_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_INTERFACE_H_ + +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" + +class ThreadPool; + +namespace tensorflow { + +class CallOptions; +class GraphDef; +class RunStepRequest; +class RunStepResponse; +class ExtendSessionRequest; +class ExtendSessionResponse; + +// A "master session" encapsulates a distributed graph computation +// (resource allocation, placement, execution, etc.). +class MasterSessionInterface { + public: + // Initializes the Session with "def". Must be called before Extend(), + // Run(), or Close(). + // + // The callee may clear "def". + virtual Status Create(GraphDef* def) = 0; + + // Returns the session handle. + virtual const string& handle() const = 0; + + // Returns the last access time (the number of micro-seconds since + // some fixed point in time) of this session. + virtual uint64 last_access_time_usec() const = 0; + + // Attempt to extend the graph according to the given "req". + // (See master.proto for details of valid extensions.) + // + // PRECONDITION: The current version of this session's graph + // is "req->current_version". + // + // POSTCONDITION: The current version of this session's graph + // is "req->new_version". + // + // Extend() may block the caller thread for a long time. + virtual Status Extend(const ExtendSessionRequest* req, + ExtendSessionResponse* resp) = 0; + + // Run one step. + virtual Status Run(CallOptions* opts, const RunStepRequest* req, + RunStepResponse* resp) = 0; + + // Close this session and delete "*this". Returns OK if all known + // states are cleanup successfully. + // + // Close() may block the caller thread for a long time. 
+ virtual Status Close() = 0; +}; + +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_INTERFACE_H_ diff --git a/tensorflow/core/distributed_runtime/master_test.cc b/tensorflow/core/distributed_runtime/master_test.cc new file mode 100644 index 00000000000..a0a37081004 --- /dev/null +++ b/tensorflow/core/distributed_runtime/master_test.cc @@ -0,0 +1,423 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/distributed_runtime/master.h" + +#include +#include + +#include "external/grpc/include/grpc++/grpc++.h" + +#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_testlib.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/graph/testlib.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/test.h" +#include 
"tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/master.pb.h" +#include "tensorflow/core/protobuf/master_service.grpc.pb.h" + +namespace tensorflow { + +class MasterTest : public ::testing::Test { + protected: + MasterTest() { + std::vector targets; + SessionOptions options; + (*options.config.mutable_device_count())["CPU"] = 1; + (*options.config.mutable_device_count())["GPU"] = 0; + TF_CHECK_OK(test::TestCluster::MakeTestCluster(options, 2, &cluster_)); + master_ = grpc::MasterService::NewStub( + NewHostPortGrpcChannel(cluster_->targets()[0])); + } + + std::unique_ptr cluster_; + std::unique_ptr master_; + + // Helpers for MasterService.{CreateSession,RunStep,CloseSession} + // rpc calls. + + Status CreateSession(const GraphDef& def, string* handle, + int64* initial_version) { + ::grpc::ClientContext ctx; + CreateSessionRequest req; + *(req.mutable_graph_def()) = def; + // Invokes placement frequently. + req.mutable_config()->set_placement_period(1); + CreateSessionResponse resp; + const Status s = FromGrpcStatus(master_->CreateSession(&ctx, req, &resp)); + if (s.ok()) { + *handle = resp.session_handle(); + *initial_version = resp.graph_version(); + } + return s; + } + + Status ExtendSession(const string& handle, const GraphDef& def, + int64 current_version, int64* new_version) { + ::grpc::ClientContext ctx; + ExtendSessionRequest req; + req.set_session_handle(handle); + *(req.mutable_graph_def()) = def; + req.set_current_graph_version(current_version); + ExtendSessionResponse resp; + const Status s = FromGrpcStatus(master_->ExtendSession(&ctx, req, &resp)); + if (s.ok()) { + *new_version = resp.new_graph_version(); + } + return s; + } + + Status RunStep(const string& handle, + const std::vector >& feed, + const std::map& fetch) { + ::grpc::ClientContext ctx; + RunStepRequest req; + req.set_session_handle(handle); + for (const auto& p : feed) { + const string& feed_name = p.first; + const Tensor* feed_tensor = p.second; + auto f = 
req.add_feed(); + f->set_name(feed_name); + feed_tensor->AsProtoTensorContent(f->mutable_tensor()); + } + for (const auto& p : fetch) { + const string& fetch_name = p.first; + req.add_fetch(fetch_name); + } + RunStepResponse resp; + const Status s = FromGrpcStatus(master_->RunStep(&ctx, req, &resp)); + if (s.ok()) { + for (const auto& fetch_resp : resp.tensor()) { + auto it = fetch.find(fetch_resp.name()); + CHECK(it != fetch.end()); + CHECK(it->second->FromProto(fetch_resp.tensor())); + } + } + return s; + } + + Status CloseSession(const string& handle) { + ::grpc::ClientContext ctx; + CloseSessionRequest req; + req.set_session_handle(handle); + CloseSessionResponse resp; + return FromGrpcStatus(master_->CloseSession(&ctx, req, &resp)); + } + + Status Reset() { + ::grpc::ClientContext ctx; + ResetRequest req; + ResetResponse resp; + return FromGrpcStatus(master_->Reset(&ctx, req, &resp)); + } +}; + +TEST_F(MasterTest, CreateClose) { + GraphDef def; // Empty. + string handle; + int64 initial_version; + TF_ASSERT_OK(CreateSession(def, &handle, &initial_version)); + EXPECT_TRUE(errors::IsAborted(CloseSession("randombits"))); + EXPECT_TRUE(CloseSession(handle).ok()); +} + +TEST_F(MasterTest, ListDevices) { + ::grpc::ClientContext ctx; + ListDevicesRequest req; + ListDevicesResponse resp; + const Status s = FromGrpcStatus(master_->ListDevices(&ctx, req, &resp)); + TF_EXPECT_OK(s); + EXPECT_EQ(1, resp.local_device_size()); + EXPECT_EQ("CPU", resp.local_device(0).device_type()); +} + +TEST_F(MasterTest, Reset) { + GraphDef def; // Empty. + string s1, s2; + int64 initial_version1, initial_version2; + TF_ASSERT_OK(CreateSession(def, &s1, &initial_version1)); + TF_ASSERT_OK(CreateSession(def, &s2, &initial_version2)); + EXPECT_TRUE(Reset().ok()); + EXPECT_TRUE(errors::IsAborted(CloseSession(s1))); + EXPECT_TRUE(errors::IsAborted(CloseSession(s2))); +} + +TEST_F(MasterTest, Extend) { + GraphDef def_0; // Empty. 
+ string handle; + int64 initial_version; + TF_ASSERT_OK(CreateSession(def_0, &handle, &initial_version)); + + Tensor A_expected(DT_FLOAT, TensorShape({2, 2})); + test::FillValues(&A_expected, {3.0, 2.0, -1.0, 0.0}); + + Tensor x_expected(DT_FLOAT, TensorShape({2, 1})); + test::FillValues(&x_expected, {2.0, 2.0}); + + Graph graph_1(OpRegistry::Global()); + test::graph::Constant(&graph_1, A_expected, "A"); + GraphDef def_1; + test::graph::ToGraphDef(&graph_1, &def_1); + int64 version_1; + TF_ASSERT_OK(ExtendSession(handle, def_1, initial_version, &version_1)); + EXPECT_GT(version_1, initial_version); + Tensor A(DT_FLOAT, TensorShape({2, 2})); + TF_ASSERT_OK(RunStep(handle, {}, {{"A:0", &A}})); + test::ExpectTensorEqual(A, A_expected); + + Graph graph_2(OpRegistry::Global()); + test::graph::Constant(&graph_2, x_expected, "x"); + GraphDef def_2; + test::graph::ToGraphDef(&graph_2, &def_2); + int64 version_2; + EXPECT_TRUE(errors::IsAborted( + ExtendSession("randombits", def_2, version_1, &version_2))); + TF_ASSERT_OK(ExtendSession(handle, def_2, version_1, &version_2)); + EXPECT_GT(version_2, version_1); + + Tensor x(DT_FLOAT, TensorShape({2, 1})); + TF_ASSERT_OK(RunStep(handle, {}, {{"A:0", &A}, {"x:0", &x}})); + test::ExpectTensorEqual(A, A_expected); + test::ExpectTensorEqual(x, x_expected); + + TF_ASSERT_OK(CloseSession(handle)); +} + +TEST_F(MasterTest, ExtendUpdateStatefulFails) { + GraphDef def_0; // Empty. 
+ string handle; + int64 initial_version; + TF_ASSERT_OK(CreateSession(def_0, &handle, &initial_version)); + + Graph graph_1(OpRegistry::Global()); + test::graph::Var(&graph_1, DT_FLOAT, TensorShape({512})); + GraphDef def_1; + test::graph::ToGraphDef(&graph_1, &def_1); + + int64 version_1, version_2; + TF_ASSERT_OK(ExtendSession(handle, def_1, initial_version, &version_1)); + EXPECT_GT(version_1, initial_version); + EXPECT_TRUE(errors::IsInvalidArgument( + ExtendSession(handle, def_1, version_1, &version_2))); + TF_ASSERT_OK(CloseSession(handle)); +} + +TEST_F(MasterTest, ExtendTwiceFails) { + GraphDef def_0; // Empty. + string handle; + int64 initial_version; + TF_ASSERT_OK(CreateSession(def_0, &handle, &initial_version)); + + Graph graph_1(OpRegistry::Global()); + test::graph::Var(&graph_1, DT_FLOAT, TensorShape({512})); + GraphDef def_1; + test::graph::ToGraphDef(&graph_1, &def_1); + + int64 version_1; + TF_ASSERT_OK(ExtendSession(handle, def_1, initial_version, &version_1)); + EXPECT_GT(version_1, initial_version); + EXPECT_TRUE(errors::IsAborted( + ExtendSession(handle, def_1, initial_version, &version_1))); + TF_ASSERT_OK(CloseSession(handle)); +} + +TEST_F(MasterTest, ConcurrentExtendOnlyOneSucceeds) { + GraphDef def_0; // Empty. 
+ string handle; + int64 initial_version; + TF_ASSERT_OK(CreateSession(def_0, &handle, &initial_version)); + + Graph graph_1(OpRegistry::Global()); + test::graph::Var(&graph_1, DT_FLOAT, TensorShape({512})); + GraphDef def_1; + test::graph::ToGraphDef(&graph_1, &def_1); + + Notification n; + mutex mu; + int succeeded = 0; + int failed = 0; + auto extend_fn = [this, handle, def_1, initial_version, &n, &mu, &succeeded, + &failed]() { + n.WaitForNotification(); + int64 new_version; + Status s = ExtendSession(handle, def_1, initial_version, &new_version); + EXPECT_TRUE(s.ok() || errors::IsAborted(s)); + { + mutex_lock l(mu); + if (s.ok()) { + ++succeeded; + } else { + ++failed; + } + } + }; + + // Run 100 concurrent Extend calls and expect only one to succeed. + { + thread::ThreadPool thread_pool(Env::Default(), "extend_pool", 100); + for (int i = 0; i < 100; ++i) { + thread_pool.Schedule(extend_fn); + } + n.Notify(); + } + + EXPECT_EQ(failed, 99); + EXPECT_EQ(succeeded, 1); + TF_ASSERT_OK(CloseSession(handle)); +} + +TEST_F(MasterTest, ConcurrentExtendAndRun) { + Graph graph_0(OpRegistry::Global()); + Tensor a_tensor(DT_FLOAT, TensorShape({2, 2})); + test::FillValues(&a_tensor, {3, 2, -1, 0}); + test::graph::Constant(&graph_0, a_tensor, "A"); + GraphDef def_0; + test::graph::ToGraphDef(&graph_0, &def_0); + + string handle; + int64 initial_version; + TF_ASSERT_OK(CreateSession(def_0, &handle, &initial_version)); + + Graph graph_1(OpRegistry::Global()); + Tensor b_tensor(DT_FLOAT, TensorShape({2, 2})); + test::FillValues(&b_tensor, {1, 0, 0, 1}); + test::graph::Constant(&graph_1, b_tensor, "B"); + GraphDef def_1; + test::graph::ToGraphDef(&graph_1, &def_1); + + Notification extend_done; + Notification extend_can_start; + + auto get_a_fn = [this, handle, &extend_done]() { + Tensor A(DT_FLOAT, TensorShape({2, 2})); + while (!extend_done.HasBeenNotified()) { + TF_ASSERT_OK(RunStep(handle, {}, {{"A:0", &A}})); + } + // Run at least once after the Extend has completed. 
+ TF_ASSERT_OK(RunStep(handle, {}, {{"A:0", &A}})); + }; + + auto get_a_and_b_fn = [this, handle, &extend_done, &extend_can_start]() { + Tensor A(DT_FLOAT, TensorShape({2, 2})); + Tensor B(DT_FLOAT, TensorShape({2, 2})); + + // Run at least once before the Extend has completed. + EXPECT_TRUE( + errors::IsNotFound(RunStep(handle, {}, {{"A:0", &A}, {"B:0", &B}}))); + extend_can_start.Notify(); + + // Concurrent with the Extend, we will either fail (as above), or + // succeed (as below). + while (!extend_done.HasBeenNotified()) { + Status s = RunStep(handle, {}, {{"A:0", &A}, {"B:0", &B}}); + EXPECT_TRUE(errors::IsNotFound(s) || s.ok()); + } + + // Run at least once after the Extend has completed. + TF_ASSERT_OK(RunStep(handle, {}, {{"A:0", &A}, {"B:0", &B}})); + }; + + auto extend_fn = [this, handle, def_1, initial_version, &extend_done, + &extend_can_start]() { + extend_can_start.WaitForNotification(); + int64 version_1; + TF_ASSERT_OK(ExtendSession(handle, def_1, initial_version, &version_1)); + extend_done.Notify(); + }; + + { + thread::ThreadPool thread_pool(Env::Default(), "extend_pool", 3); + thread_pool.Schedule(get_a_fn); + thread_pool.Schedule(get_a_and_b_fn); + thread_pool.Schedule(extend_fn); + } + + TF_ASSERT_OK(CloseSession(handle)); +} + +TEST_F(MasterTest, EigenProblem) { + // A = [3 2; -1 0]; x = rand(2, 1); + // for i=1:100; x = A * x; end + // We'll try to compute the largest eigenvalue for A. + Graph graph(OpRegistry::Global()); + Tensor a_tensor(DT_FLOAT, TensorShape({2, 2})); + // Store rows [3, 2] and [-1, 0] in row major format. + test::FillValues(&a_tensor, {3, 2, -1, 0}); + Node* a_node = test::graph::Constant(&graph, a_tensor); + + // x is from the feed. 
+ Tensor x_tensor(DT_FLOAT, TensorShape({2, 1})); + test::FillValues(&x_tensor, {0, 0}); + Node* x_node = test::graph::Constant(&graph, x_tensor); + + // y = A * x + Node* y_node = test::graph::Matmul(&graph, a_node, x_node, false, false); + + GraphDef def; + test::graph::ToGraphDef(&graph, &def); + + string handle; + int64 initial_version; + TF_CHECK_OK(CreateSession(def, &handle, &initial_version)); + + // Temps supporting the computation of the convergence condition. + const Eigen::array sum_along_dim(0); + const Eigen::array matrix_transpose({1, 0}); + Tensor x(DT_FLOAT, TensorShape({2, 1})); + Tensor y(DT_FLOAT, TensorShape({2, 1})); + Eigen::Tensor y_square_sum; + Eigen::Tensor y_normalized(2, 1); + y_normalized.setRandom(); + Eigen::Tensor error_square_sum; + float lambda; + + // The computation loop. + bool converged = false; + while (!converged) { + // Run one step of the graph. + auto x_matrix = x.matrix(); + x_matrix = y_normalized; + TF_EXPECT_OK( + RunStep(handle, {{x_node->name(), &x}}, {{y_node->name() + ":0", &y}})); + auto y_matrix = y.matrix(); + + // Client code computes the convergence condition. 
+ { + lambda = y_matrix(0, 0) / x_matrix(0, 0); + y_square_sum = y.matrix().square().sum(sum_along_dim); + const float norm = static_cast(sqrt(y_square_sum(0))); + y_normalized = y_matrix * (1 / norm); + error_square_sum = (x_matrix - y_normalized).square().sum(sum_along_dim); + VLOG(1) << "x = [" << x_matrix.shuffle(matrix_transpose) << "] y = [" + << y_matrix.shuffle(matrix_transpose) << "] lambda = " << lambda; + converged = sqrt(error_square_sum(0)) < 1e-10; + } + } + EXPECT_NEAR(lambda, 2.0, 0.01); + TF_EXPECT_OK(CloseSession(handle)); +} + +} // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/process_util.cc b/tensorflow/core/distributed_runtime/process_util.cc new file mode 100644 index 00000000000..8f97382cf80 --- /dev/null +++ b/tensorflow/core/distributed_runtime/process_util.cc @@ -0,0 +1,69 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/distributed_runtime/process_util.h" + +#include + +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/host_info.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/tracing.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +namespace { + +static thread::ThreadPool* InitComputePool(const SessionOptions& options) { + int32 inter_op_parallelism_threads = + options.config.inter_op_parallelism_threads(); + if (inter_op_parallelism_threads == 0) { + // Default to using the number of cores available in the process. + inter_op_parallelism_threads = port::NumSchedulableCPUs(); + } + + return new thread::ThreadPool(Env::Default(), "Compute", + inter_op_parallelism_threads); +} + +} // namespace + +thread::ThreadPool* ComputePool(const SessionOptions& options) { + static thread::ThreadPool* compute_pool = InitComputePool(options); + return compute_pool; +} + +void SchedClosure(std::function closure) { + if (port::Tracing::IsActive()) { + const uint64 id = port::Tracing::UniqueId(); + port::Tracing::RecordEvent(port::Tracing::EventCategory::kScheduleClosure, + id); + std::function wrapper = [closure, id]() { + port::Tracing::ScopedActivity region( + port::Tracing::EventCategory::kRunClosure, id); + closure(); + }; + Env::Default()->SchedClosure(wrapper); + } else { + Env::Default()->SchedClosure(closure); + } +} + +void SchedNonBlockingClosureAfter(int micros, std::function closure) { + Env::Default()->SchedClosureAfter(micros, closure); +} + +} // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/process_util.h b/tensorflow/core/distributed_runtime/process_util.h new file mode 100644 index 00000000000..fb20e88b1ea --- /dev/null +++ b/tensorflow/core/distributed_runtime/process_util.h @@ -0,0 +1,39 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_PROCESS_UTIL_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_PROCESS_UTIL_H_ + +#include + +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/public/session_options.h" + +namespace tensorflow { + +// Returns a process-wide ThreadPool for scheduling compute operations +// using 'options'. Caller does not take ownership over threadpool. +thread::ThreadPool* ComputePool(const SessionOptions& options); + +// Schedule "closure" in the default thread queue. +void SchedClosure(std::function closure); + +// Schedule "closure" after the given number of microseconds in the +// fixed-size ThreadPool used for non-blocking compute tasks. +void SchedNonBlockingClosureAfter(int micros, std::function closure); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_PROCESS_UTIL_H_ diff --git a/tensorflow/core/distributed_runtime/remote_device.cc b/tensorflow/core/distributed_runtime/remote_device.cc new file mode 100644 index 00000000000..387b9e4492b --- /dev/null +++ b/tensorflow/core/distributed_runtime/remote_device.cc @@ -0,0 +1,91 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/distributed_runtime/remote_device.h" + +#include +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/distributed_runtime/process_util.h" +#include "tensorflow/core/distributed_runtime/worker_cache.h" +#include "tensorflow/core/distributed_runtime/worker_interface.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/protobuf/worker.pb.h" + +namespace tensorflow { + +using std::placeholders::_1; + +// TODO(zhifengc): We need to consolidate (full/partial) device name +// parsing into one place. +// +// Parses and returns the local device part (e.g., cpu:0, gpu:4). 
+string GetLocalDeviceName(StringPiece fullname) { + auto pos = fullname.rfind('/'); + CHECK_NE(pos, StringPiece::npos); + fullname.remove_prefix(pos + 1); + return fullname.ToString(); +} + +class RemoteDevice : public Device { + public: + RemoteDevice(Env* env, const DeviceAttributes& da, WorkerInterface* wi) + : Device(env, da, nullptr), + local_dev_name_(GetLocalDeviceName(da.name())), + wi_(wi) {} + + ~RemoteDevice() override { delete wi_; } + Status Sync() override { return Status::OK(); } + Allocator* GetAllocator(AllocatorAttributes attr) override { return nullptr; } + + private: + const string local_dev_name_; + WorkerInterface* wi_; + + TF_DISALLOW_COPY_AND_ASSIGN(RemoteDevice); +}; + +void NewRemoteDevices(Env* env, WorkerCacheInterface* worker_cache, + const string& worker_name, NewRemoteDevicesDone done) { + WorkerInterface* wi = worker_cache->CreateWorker(worker_name); + if (wi == nullptr) { + std::vector empty; + done(errors::NotFound("Device ", worker_name, " is not found."), &empty); + return; + } + struct Call { + GetStatusRequest req; + GetStatusResponse resp; + }; + Call* call = new Call; + auto cb = [env, worker_cache, worker_name, done, wi, call](const Status& s) { + std::vector remote_devices; + if (s.ok()) { + remote_devices.reserve(call->resp.device_attributes_size()); + for (const DeviceAttributes& da : call->resp.device_attributes()) { + auto d = + new RemoteDevice(env, da, worker_cache->CreateWorker(worker_name)); + remote_devices.push_back(d); + } + } + done(s, &remote_devices); + delete wi; + delete call; + }; + wi->GetStatusAsync(&call->req, &call->resp, cb); +} + +} // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/remote_device.h b/tensorflow/core/distributed_runtime/remote_device.h new file mode 100644 index 00000000000..aeefeda048b --- /dev/null +++ b/tensorflow/core/distributed_runtime/remote_device.h @@ -0,0 +1,48 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_REMOTE_DEVICE_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_REMOTE_DEVICE_H_ + +#include +#include +#include +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +class Device; +class Env; +class WorkerCacheInterface; + +// NewRemoteDevices discovers available devices on the +// 'remote_worker'. The implementation uses 'channel_cache' to +// discover how to communicate with the 'remote_worker' (via gRPC, for +// example). +// +// NewRemoteDevices does not block. +// +// On success, the 'done' callback is given the OK status and a vector +// of Device*. The caller should take ownership of these devices. +// +// Otherwise, the 'done' callback is given an error status and the +// vector is empty. +typedef std::function*)> + NewRemoteDevicesDone; +void NewRemoteDevices(Env* env, WorkerCacheInterface* worker_cache, + const string& remote_worker, NewRemoteDevicesDone done); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_REMOTE_DEVICE_H_ diff --git a/tensorflow/core/distributed_runtime/remote_device_test.cc b/tensorflow/core/distributed_runtime/remote_device_test.cc new file mode 100644 index 00000000000..c575a764718 --- /dev/null +++ b/tensorflow/core/distributed_runtime/remote_device_test.cc @@ -0,0 +1,89 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/distributed_runtime/remote_device.h" + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_testlib.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h" +#include "tensorflow/core/distributed_runtime/worker_interface.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/regexp.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +const char* const kSession = "remote_session"; + +class RemoteDeviceTest : public ::testing::Test { + protected: + string remote_name_; + std::unique_ptr worker_cache_; + std::unique_ptr wi_; + std::vector devices_; + std::unique_ptr cluster_; + + RemoteDeviceTest() { + SessionOptions options; + (*options.config.mutable_device_count())["CPU"] = 2; + TF_CHECK_OK(test::TestCluster::MakeTestCluster(options, 1, &cluster_)); + const string& hostport = cluster_->targets()[0]; + string host; + int port; + CHECK(RE2::FullMatch(hostport, "(.+):(\\d+)", &host, &port)); + 
GrpcChannelSpec spec; + spec.AddHostPortsJob("localhost", {hostport}, 1); + worker_cache_.reset(NewGrpcWorkerCache(NewGrpcChannelCache(spec))); + remote_name_ = strings::StrCat("/job:", host, "/replica:0/task:0"); + wi_.reset(worker_cache_->CreateWorker(remote_name_)); + } + + void SetUp() override { + Notification n; + NewRemoteDevices(Env::Default(), worker_cache_.get(), remote_name_, + [&n, this](const Status& s, std::vector* found) { + TF_CHECK_OK(s); + devices_ = *found; + n.Notify(); + }); + n.WaitForNotification(); + EXPECT_EQ(devices_.size(), 2); + std::sort(devices_.begin(), devices_.end(), [](Device* a, Device* b) { + return a->name().compare(b->name()) < 0; + }); + } + + void TearDown() override { + for (auto d : devices_) delete d; + } +}; + +TEST_F(RemoteDeviceTest, GetStatus) { + // We know what the testlib's fake server does. + EXPECT_EQ(devices_[0]->name(), strings::StrCat(remote_name_, "/cpu:0")); + EXPECT_EQ(devices_[0]->attributes().device_type(), + DeviceType(DEVICE_CPU).type()); + EXPECT_EQ(devices_[0]->attributes().memory_limit(), 256 << 20); + EXPECT_EQ(devices_[1]->name(), strings::StrCat(remote_name_, "/cpu:1")); + EXPECT_EQ(devices_[1]->attributes().memory_limit(), 256 << 20); +} + +} // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h b/tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h new file mode 100644 index 00000000000..6a71bb04b40 --- /dev/null +++ b/tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h @@ -0,0 +1,79 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RENDEZVOUS_MGR_INTERFACE_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RENDEZVOUS_MGR_INTERFACE_H_ + +#include + +#include "tensorflow/core/distributed_runtime/worker_env.h" +#include "tensorflow/core/framework/rendezvous.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// RendezvousMgr keeps track of a set of local rendezvous instances. +// All tensors sent by this worker are buffered in a RendezvousMgr +// until the tensor is received. Each global unique "step_id" +// corresponds to one local rendezvous instance managed by a +// RendezvousMgr. +// +// E.g., +// Rendezvous* rendez = worker_env->rendezvous_mgr->Find(0x8935); +// fork execution of an graph executor using "rendez" on thread 1; +// fork execution of another graph executor using "rendez" on thread 2; +// ... +// join threads 1 and 2; +// +// In the example above, execution in thread 1 and 2 communicates with +// each other by send/recv operations through the "rend". +// +// Tensors sent and recved through rendezvous managed by this +// RendezvousMgr must have keys generated by Rendezvous::CreateKey. +class RendezvousMgrInterface { + public: + RendezvousMgrInterface() {} + virtual ~RendezvousMgrInterface() {} + + // Returns Rendezvous supporting send and recv among workers in the + // "step_id". The caller takes ownership of one reference on the + // returned Rendezvous instance. 
+ virtual Rendezvous* Find(int64 step_id) = 0; + + // Finds the local rendezvous instance for the "step_id". Runs + // "done" when the tensor for "key" is produced or an error occurs. + // + // This method is used by the rpc handler of RecvTensor. + virtual void RecvLocalAsync(int64 step_id, const string& key, + Rendezvous::DoneCallback done) = 0; + + // Synchronous wrapper for RecvLocalAsync. + virtual Status RecvLocal(int64 step_id, const string& key, Tensor* val, + bool* is_dead) = 0; + + // Removes rendezvous for "step_id". + // + // TODO(zhifengc): Have a background thread in worker that + // periodically calls CleanupAll(). + virtual void Cleanup(int64 step_id) = 0; + + // Removes all rendezvous. + virtual void CleanupAll() = 0; +}; + +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RENDEZVOUS_MGR_INTERFACE_H_ diff --git a/tensorflow/core/distributed_runtime/rpc/BUILD b/tensorflow/core/distributed_runtime/rpc/BUILD new file mode 100644 index 00000000000..3166c942592 --- /dev/null +++ b/tensorflow/core/distributed_runtime/rpc/BUILD @@ -0,0 +1,341 @@ +# Description: +# RPC communication interfaces and implementations for TensorFlow. 
+ +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +filegroup( + name = "c_srcs", + data = glob([ + "**/*.cc", + "**/*.h", + ]), +) + +load( + "//tensorflow:tensorflow.bzl", + "tf_cuda_library", + "tf_cc_tests", +) + +# For platform specific build config +load( + "//tensorflow/core:platform/default/build_config.bzl", + "tf_kernel_tests_linkstatic", +) +load( + "//tensorflow/core:platform/default/build_config_root.bzl", + "tf_cuda_tests_tags", +) + +package(default_visibility = [ + "//tensorflow:internal", +]) + +cc_library( + name = "grpc_util", + srcs = [], + hdrs = ["grpc_util.h"], + deps = [ + "@grpc//:grpc++_unsecure", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "grpc_client_cq_tag", + srcs = [], + hdrs = ["grpc_client_cq_tag.h"], + deps = [ + "@grpc//:grpc++_unsecure", + ":grpc_util", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "grpc_remote_worker", + srcs = ["grpc_remote_worker.cc"], + hdrs = ["grpc_remote_worker.h"], + deps = [ + "@grpc//:grpc++_unsecure", + ":grpc_client_cq_tag", + ":grpc_util", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:worker_proto_cc", + "//tensorflow/core:worker_service_proto_cc", + "//tensorflow/core/distributed_runtime:process_util", + "//tensorflow/core/distributed_runtime:worker_cache_logger", + "//tensorflow/core/distributed_runtime:worker_interface", + ], +) + +cc_library( + name = "grpc_channel", + srcs = ["grpc_channel.cc"], + hdrs = ["grpc_channel.h"], + deps = [ + "@grpc//:grpc++_unsecure", + ":grpc_util", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "grpc_call", + srcs = [], + hdrs = ["grpc_call.h"], + deps = [ + "@grpc//:grpc++_unsecure", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "async_service_interface", + srcs = [], + hdrs = 
["async_service_interface.h"], + deps = [], +) + +cc_library( + name = "grpc_worker_cache", + srcs = ["grpc_worker_cache.cc"], + hdrs = ["grpc_worker_cache.h"], + deps = [ + ":grpc_channel", + ":grpc_client_cq_tag", + ":grpc_remote_worker", + "//tensorflow/core:lib", + "//tensorflow/core/distributed_runtime:worker_cache", + "//tensorflow/core/distributed_runtime:worker_cache_logger", + "//tensorflow/core/distributed_runtime:worker_cache_partial", + "//tensorflow/core/distributed_runtime:worker_interface", + ], +) + +cc_library( + name = "grpc_worker_service", + srcs = ["grpc_worker_service.cc"], + hdrs = ["grpc_worker_service.h"], + deps = [ + "@grpc//:grpc++_unsecure", + ":async_service_interface", + ":grpc_call", + ":grpc_util", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:gpu_runtime", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:tensorflow_opensource", + "//tensorflow/core:worker_proto_cc", + "//tensorflow/core:worker_service_proto_cc", + "//tensorflow/core/distributed_runtime:graph_mgr", + "//tensorflow/core/distributed_runtime:process_util", + "//tensorflow/core/distributed_runtime:rendezvous_mgr_interface", + "//tensorflow/core/distributed_runtime:worker_cache", + "//tensorflow/core/distributed_runtime:worker_env", + "//tensorflow/core/distributed_runtime:worker_interface", + ], +) + +cc_library( + name = "grpc_remote_master", + srcs = ["grpc_remote_master.cc"], + hdrs = ["grpc_remote_master.h"], + deps = [ + "@grpc//:grpc++_unsecure", + ":grpc_util", + "//tensorflow/core:lib", + "//tensorflow/core:master_proto_cc", + "//tensorflow/core:master_service_proto_cc", + "//tensorflow/core/distributed_runtime:master_interface", + ], + alwayslink = 1, +) + +cc_library( + name = "grpc_master_service", + srcs = ["grpc_master_service.cc"], + hdrs = ["grpc_master_service.h"], + deps = [ + "@grpc//:grpc++_unsecure", + ":async_service_interface", + ":grpc_call", + ":grpc_util", + 
"//tensorflow/core:lib", + "//tensorflow/core:master_proto_cc", + "//tensorflow/core:master_service_proto_cc", + "//tensorflow/core/distributed_runtime:master", + "//tensorflow/core/distributed_runtime:master_interface", + ], + alwayslink = 1, +) + +cc_library( + name = "rpc_rendezvous_mgr", + srcs = ["rpc_rendezvous_mgr.cc"], + hdrs = ["rpc_rendezvous_mgr.h"], + deps = [ + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:tensorflow_opensource", + "//tensorflow/core/distributed_runtime:base_rendezvous_mgr", + "//tensorflow/core/distributed_runtime:process_util", + "//tensorflow/core/distributed_runtime:worker_cache", + "//tensorflow/core/distributed_runtime:worker_env", + "//tensorflow/core/distributed_runtime:worker_interface", + ], +) + +cc_library( + name = "grpc_server_lib", + srcs = [ + "grpc_server_lib.cc", + ], + hdrs = ["grpc_server_lib.h"], + deps = [ + "@grpc//:grpc++_unsecure", + ":async_service_interface", + ":grpc_channel", + ":grpc_master_service", + ":grpc_worker_cache", + ":grpc_worker_service", + ":rpc_rendezvous_mgr", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core/distributed_runtime:graph_mgr", + "//tensorflow/core/distributed_runtime:master_env", + "//tensorflow/core/distributed_runtime:master_session", + "//tensorflow/core/distributed_runtime:process_util", + "//tensorflow/core/distributed_runtime:worker_env", + ], +) + +cc_binary( + name = "grpc_tensorflow_server", + srcs = [ + "grpc_tensorflow_server.cc", + ], + deps = [ + "@grpc//:grpc++_unsecure", + ":grpc_server_lib", + "//tensorflow/core:all_kernels", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + ], +) + +tf_cuda_library( + name = "grpc_testlib_ops", + testonly = 1, + srcs = ["grpc_testlib_ops.cc"], + linkstatic = 1, 
# Seems to be needed since alwayslink is broken in bazel + deps = [ + "//tensorflow/core:all_kernels", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], + alwayslink = 1, +) + +cc_binary( + name = "grpc_testlib_server", + testonly = 1, + srcs = [ + "grpc_testlib_server.cc", + ], + deps = [ + "@grpc//:grpc++_unsecure", + ":grpc_server_lib", + ":grpc_testlib_ops", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + ], +) + +tf_cuda_library( + name = "grpc_testlib", + testonly = 1, + srcs = ["grpc_testlib.cc"], + hdrs = ["grpc_testlib.h"], + data = [ + ":grpc_testlib_server", + ], + deps = [ + ":grpc_session", + ":grpc_testlib_ops", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensorflow_opensource", + "//tensorflow/core:test", + ], + alwayslink = 1, +) + +cc_library( + name = "grpc_session", + srcs = ["grpc_session.cc"], + hdrs = ["grpc_session.h"], + deps = [ + ":grpc_channel", + ":grpc_remote_master", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:master_proto_cc", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensorflow", + "//tensorflow/core/distributed_runtime:master_interface", + ], + alwayslink = 1, +) + +tf_cc_tests( + linkstatic = tf_kernel_tests_linkstatic(), + tags = tf_cuda_tests_tags(), + tests = [ + "grpc_channel_test.cc", + "grpc_session_test.cc", + "rpc_rendezvous_mgr_test.cc", + ], + deps = [ + ":grpc_channel", + ":grpc_session", + ":grpc_testlib", + ":rpc_rendezvous_mgr", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:master_proto_cc", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + 
"//tensorflow/core/distributed_runtime:process_util", + ], +) diff --git a/tensorflow/core/distributed_runtime/rpc/async_service_interface.h b/tensorflow/core/distributed_runtime/rpc/async_service_interface.h new file mode 100644 index 00000000000..2f453b048e8 --- /dev/null +++ b/tensorflow/core/distributed_runtime/rpc/async_service_interface.h @@ -0,0 +1,37 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_ASYNC_SERVICE_INTERFACE_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_ASYNC_SERVICE_INTERFACE_H_ + +namespace tensorflow { + +// Represents an abstract asynchronous service that handles incoming +// RPCs with a polling loop. +class AsyncServiceInterface { + public: + virtual ~AsyncServiceInterface() {} + + // A blocking method that should be called to handle incoming RPCs. + // This method will block until the service is shutdown, which + // depends on the implementation of the service. + virtual void HandleRPCsLoop() = 0; + + // TODO(mrry): Add a clean shutdown method? 
+}; + +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_ASYNC_SERVICE_INTERFACE_H_ diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_call.h b/tensorflow/core/distributed_runtime/rpc/grpc_call.h new file mode 100644 index 00000000000..11f139ca03c --- /dev/null +++ b/tensorflow/core/distributed_runtime/rpc/grpc_call.h @@ -0,0 +1,227 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CALL_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CALL_H_ + +#include "tensorflow/core/platform/macros.h" + +#include "external/grpc/include/grpc++/grpc++.h" +#include "external/grpc/include/grpc++/server_builder.h" + +namespace tensorflow { + +// CALL STRUCTURES +// =============== +// +// Each pending (incoming) request corresponds to a call object that +// encapsulates the state of the call. Templates and +// pointers-to-member functions are used to avoid boilerplate and +// redundant closure creation. The class hierarchy is as follows: +// +// * `UntypedCall`: The base class represents a call that +// could be associated with any of the methods on a service of type +// `Service`. Also defines a `Tag` nested class that can be used as +// the tag in a `grpc::CompletionQueue`. 
Each class that +// instantiates `Service` should have a completion queue polling +// loop that knows about `UntypedCall::Tag` objects, and +// invokes their `OnCompleted()` method to continue processing. +// +// * `Call`: This class extends +// `UntypedCall` and is additionally parameterized by the +// gRPC-generated asynchronous service class, and the request and +// response message types. It defines the state associated with a +// call (whose type depends on the message types), and stores a +// pointer to a `Service::HandleFoo()` handler method. Each +// `Service::HandleFoo()` method knows about the corresponding +// `Call` type, in order to access its state, and invoke its +// `SendResponse()` method. +// +// The lifecycle of a call object is as follows. +// +// 1. A `Service` creates a `Call` for a particular method and +// enqueues it in its completion queue (via an +// `UntypedCall::Tag`). +// +// 2. When the tag is returned from `cq_->Next()`, the +// `UntypedCall::RequestReceived()` method is invoked and takes +// ownership of the call object. This indirectly invokes the +// appropriate handler method on `Service`. +// +// 3. After the response has been written (perhaps in another thread), +// the `Call::SendResponse()` method is invoked. It transfers +// ownership of the call object back to the completion queue (via +// an `UntypedCall::Tag`). +// +// 4. When the response has been sent, the tag is returned from +// `cq_->Next()`, and the call object is deleted. + +// Represents a pending request with unknown message types. +template +class UntypedCall : public core::RefCounted { + public: + virtual ~UntypedCall() {} + + // The implementation of this method should use `service` to handle + // an incoming request, and (perhaps asynchronously) send the + // response. + // + // One reference on `this` is transferred to the callee, and the + // callee is responsible for releasing it (typically via + // `Call::SendResponse()`). 
+ // + // `ok` is true if the request was received in a "regular event", + // otherwise false. + virtual void RequestReceived(Service* service, bool ok) = 0; + + // This method will be called when the response has been sent by + // `service` and the call is no longer used. + // + // `ok` is true if the response sending completed as a "regular + // event", otherwise it is false. + void ResponseSent(Service* service, bool ok) {} + + // This method will be called either (i) when the server is notified + // that the request has been cancelled, or (ii) when the request completes + // normally. The implementation should distinguish these cases by querying + // the `grpc::ServerContext` associated with the request. + virtual void RequestCancelled(Service* service, bool ok) = 0; + + // Associates a tag in a `::grpc::CompletionQueue` with a callback + // for an incoming RPC. A Tag owns a reference on the corresponding + // Call object. + class Tag { + public: + using Callback = void (UntypedCall::*)(Service*, bool); + + // Creates a new `Tag` for the given `UntypedCall`. When the + // request associated with this tag is complete, `callback` will + // be called. + Tag(UntypedCall* call, Callback callback) + : call_(call), callback_(callback) { + call_->Ref(); + } + + ~Tag() { call_->Unref(); } + + // Calls the callback associated with this tag. + // + // The callback takes ownership of `this->call_`. + void OnCompleted(Service* service, bool ok) { + (call_->*callback_)(service, ok); + } + + private: + UntypedCall* call_; // `this` owns one reference. + Callback callback_; + }; +}; + +// Represents a pending call with known request and response message +// types, and a known request-handling method. +template +class Call : public UntypedCall { + public: + // Represents the generic signature of a generated + // `GrpcService::RequestFoo()` method, where `Foo` is the name of an + // RPC method. 
+ using EnqueueFunction = void (GrpcService::*)( + ::grpc::ServerContext*, RequestMessage*, + ::grpc::ServerAsyncResponseWriter*, + ::grpc::CompletionQueue*, ::grpc::ServerCompletionQueue*, void*); + + // Represents the generic signature of a `Service::HandleFoo()` + // method, where `Foo` is the name of an RPC method. + using HandleRequestFunction = void (Service::*)( + Call*); + + Call(HandleRequestFunction handle_request_function) + : handle_request_function_(handle_request_function), responder_(&ctx_) {} + + virtual ~Call() {} + + void RequestReceived(Service* service, bool ok) override { + if (ok) { + this->Ref(); + (service->*handle_request_function_)(this); + } + } + + void SendResponse(::grpc::Status status) { + responder_.Finish(response, status, + new typename UntypedCall::Tag( + this, &UntypedCall::ResponseSent)); + this->Unref(); + } + + void RequestCancelled(Service* service, bool ok) override { + if (ctx_.IsCancelled()) { + mutex_lock l(mu_); + if (cancel_callback_) { + cancel_callback_(); + } + } + } + + // Registers `callback` as the function that should be called if and when this + // call is cancelled by the client. + void SetCancelCallback(std::function callback) { + mutex_lock l(mu_); + cancel_callback_ = callback; + } + + // Clears any cancellation callback that has been registered for this call. + void ClearCancelCallback() { + mutex_lock l(mu_); + cancel_callback_ = nullptr; + } + + // Enqueues a new request for the given service on the given + // completion queue, using the given `enqueue_function`. + // + // The request will be handled with the given + // `handle_request_function`. 
+ static void EnqueueRequest(GrpcService* grpc_service, + ::grpc::ServerCompletionQueue* cq, + EnqueueFunction enqueue_function, + HandleRequestFunction handle_request_function) { + auto call = new Call( + handle_request_function); + + call->ctx_.AsyncNotifyWhenDone(new typename UntypedCall::Tag( + call, &UntypedCall::RequestCancelled)); + + (grpc_service->*enqueue_function)( + &call->ctx_, &call->request, &call->responder_, cq, cq, + new typename UntypedCall::Tag( + call, &UntypedCall::RequestReceived)); + call->Unref(); + } + + RequestMessage request; + ResponseMessage response; + + private: + HandleRequestFunction handle_request_function_; + ::grpc::ServerContext ctx_; + ::grpc::ServerAsyncResponseWriter responder_; + mutex mu_; + std::function cancel_callback_ GUARDED_BY(mu_); +}; + +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CALL_H_ diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc b/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc new file mode 100644 index 00000000000..f9492114b69 --- /dev/null +++ b/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc @@ -0,0 +1,314 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" + +#include + +#include "external/grpc/include/grpc++/create_channel.h" + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/regexp.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +namespace { +RE2* kTargetRE = new RE2("^/job:([^/]+)/replica:([0-9]+)/task:([0-9]+)$"); +RE2* kHostPortRE = new RE2("([^:/]+):(\\d+)"); +RE2* kSparseHostPortRE = new RE2("(\\d+):([^:/]+):(\\d+)"); + +string MakeAddress(const string& job, int replica, int task) { + return strings::StrCat("/job:", job, "/replica:", replica, "/task:", task); +} + +} // namespace + +SharedGrpcChannelPtr NewHostPortGrpcChannel(const string& target) { + // TODO(mrry): Implement secure channels. 
+ return ::grpc::CreateChannel(target, ::grpc::InsecureChannelCredentials()); +} + +Status GrpcChannelSpec::AddHostPortsJob(const string& job_id, + const std::vector& host_ports, + int tasks_per_replica) { + if (!job_ids_.insert(job_id).second) { + return errors::InvalidArgument( + "Duplicate job ID in cluster specification: ", job_id); + } + HostPortsJob job; + job.job_id = job_id; + for (const string& host_port : host_ports) { + string host; + int port; + if (!RE2::FullMatch(host_port, *kHostPortRE, &host, &port)) { + return errors::InvalidArgument("Could not interpret \"", host_port, + "\" as a host-port pair."); + } + } + job.host_ports = host_ports; + job.tasks_per_replica = tasks_per_replica; + host_ports_jobs_.push_back(job); + return Status::OK(); +} + +GrpcChannelCache* NewGrpcChannelCache(const GrpcChannelSpec& spec) { + const int num_jobs = spec.host_ports_jobs().size(); + if (!num_jobs) { + LOG(ERROR) << "Empty channel spec."; + return nullptr; + } + std::vector caches; + caches.reserve(num_jobs); + for (const GrpcChannelSpec::HostPortsJob& job : spec.host_ports_jobs()) { + caches.push_back(NewHostPortsGrpcChannelCache(job.job_id, job.host_ports, + job.tasks_per_replica)); + } + return caches.size() == 1 ? caches[0] : NewMultiGrpcChannelCache(caches); +} + +// GrpcChannelCache that caches results to FindWorkerChannel() calls. +class CachingGrpcChannelCache : public GrpcChannelCache { + public: + CachingGrpcChannelCache() {} + + ~CachingGrpcChannelCache() override {} + + SharedGrpcChannelPtr FindWorkerChannel(const string& target) override { + SharedGrpcChannelPtr ch = nullptr; + { + mutex_lock l(mu_); // could use reader lock + ch = gtl::FindPtrOrNull(channels_, target); + if (ch) { + return ch; + } + } + ch = FindChannelOnce(target); + if (ch) { + mutex_lock l(mu_); + channels_.insert({target, ch}); + } + return ch; + } + + protected: + // Find the ClientChannel for "target". 
Only called when no channel was + // found in the channels_ cache for "target". A non nullptr result will be + // cached in channels_. + virtual SharedGrpcChannelPtr FindChannelOnce(const string& target) = 0; + + private: + // TODO(zhifengc): Eviction when the map becomes too big. + mutex mu_; + std::unordered_map channels_ GUARDED_BY(mu_); +}; + +// A ChannelCache that is the union of multiple ChannelCaches. +// Takes ownership of the caches passed to the constructor. +class MultiGrpcChannelCache : public CachingGrpcChannelCache { + public: + explicit MultiGrpcChannelCache(const std::vector& caches) + : CachingGrpcChannelCache(), caches_(caches) {} + + ~MultiGrpcChannelCache() override { + for (GrpcChannelCache* cache : caches_) { + delete cache; + } + } + + void ListWorkers(std::vector* workers) override { + for (GrpcChannelCache* cache : caches_) { + cache->ListWorkers(workers); + } + } + + string TranslateTask(const string& target) override { + mutex_lock l(mu_); // could use reader lock + GrpcChannelCache* cache = gtl::FindPtrOrNull(target_caches_, target); + if (cache == nullptr) { + for (GrpcChannelCache* c : caches_) { + string r = c->TranslateTask(target); + if (!r.empty()) { + target_caches_.insert({target, c}); + cache = c; + break; + } + } + } + CHECK(cache) << "Could not find GrpcChannelCache holding channel for " + << target; + return cache->TranslateTask(target); + } + + protected: + SharedGrpcChannelPtr FindChannelOnce(const string& target) override { + for (GrpcChannelCache* cache : caches_) { + SharedGrpcChannelPtr ch(cache->FindWorkerChannel(target)); + if (ch) { + mutex_lock l(mu_); + target_caches_.insert({target, cache}); + return ch; + } + } + return nullptr; + } + + private: + // List of channels used by this MultiGrpcChannelCache. + const std::vector caches_; + + mutex mu_; + // Cache of channels keyed by the target they are handling. + // The same GrpcChannelCache can appear multiple times in the cache. 
+ std::unordered_map target_caches_ GUARDED_BY(mu_); +}; + +GrpcChannelCache* NewMultiGrpcChannelCache( + const std::vector& caches) { + return new MultiGrpcChannelCache(caches); +} + +class HostPortsGrpcChannelCache : public CachingGrpcChannelCache { + public: + HostPortsGrpcChannelCache(const string& job_id, + const std::vector& host_ports, + int tasks_per_replica) + : job_id_(job_id), + host_ports_(BuildDenseHostPortsList(host_ports, tasks_per_replica)), + tasks_per_replica_(tasks_per_replica) { + LOG(INFO) << "Initialize HostPortsGrpcChannelCache for job " << job_id + << " -> {" << str_util::Join(host_ports, ", ") << "}"; + } + ~HostPortsGrpcChannelCache() override {} + + void ListWorkers(std::vector* workers) override { + int num_host_ports = 0; + for (size_t i = 0; i < host_ports_.size(); ++i) { + if (!host_ports_[i].empty()) { + ++num_host_ports; + } + } + workers->reserve(workers->size() + num_host_ports); + for (size_t i = 0; i < host_ports_.size(); ++i) { + if (!host_ports_[i].empty()) { + workers->emplace_back(MakeAddress(job_id_, i / tasks_per_replica_, + i % tasks_per_replica_)); + } + } + } + + string TranslateTask(const string& target) override { + RegexpStringPiece job; + int32 replica; + int32 task; + if (!RE2::FullMatch(target, *kTargetRE, &job, &replica, &task)) { + LOG(WARNING) << "Invalid target: " << target; + return ""; + } + if (job != job_id_) { + return ""; + } + if (task >= tasks_per_replica_) { + LOG(WARNING) << "Task out of bounds for job " << job_id_ << ": " << task; + return ""; + } + const size_t i = replica * tasks_per_replica_ + task; + if (i >= host_ports_.size()) { + LOG(WARNING) << "Replica/task out of bounds for job " << job_id_ << ": " + << target; + return ""; + } + if (host_ports_[i].empty()) { + LOG(WARNING) << "Replica/task not in sparse index:host:port list for job " + << job_id_ << ": " << target; + return ""; + } + return host_ports_[i]; + } + + protected: + static std::vector BuildDenseHostPortsList( + const 
std::vector& host_ports, int tasks_per_replica) { + std::map sparse_host_ports; + for (const string& host_port : host_ports) { + int i = -1; + string host; + int port = -1; + if (RE2::FullMatch(host_port, *kSparseHostPortRE, &i, &host, &port)) { + CHECK_LE(0, i); + CHECK_LE(0, port); + CHECK(sparse_host_ports.find(i) == sparse_host_ports.end()) + << "Duplicate index " << i << ": {" + << str_util::Join(host_ports, ", ") << "}"; + sparse_host_ports[i] = strings::StrCat(host, ":", port); + } else { + CHECK(RE2::FullMatch(host_port, *kHostPortRE, &host, &port)) + << host_port + << " does not look like a host:port or an index:host:port"; + } + } + + if (sparse_host_ports.empty()) { + // The input is a dense list; return it directly. + return host_ports; + } + + // The input is a sparse list. Convert it to a dense list. + CHECK_EQ(host_ports.size(), sparse_host_ports.size()) + << "Mix of host:port and index:host:port: {" + << str_util::Join(host_ports, ", ") << "}"; + int num_tasks = sparse_host_ports.rbegin()->first + 1; + if (num_tasks % tasks_per_replica != 0) { + num_tasks = (num_tasks / tasks_per_replica + 1) * tasks_per_replica; + } + std::vector dense_host_ports; + dense_host_ports.resize(num_tasks); + for (const auto& p : sparse_host_ports) { + dense_host_ports[p.first] = p.second; + } + return dense_host_ports; + } + + SharedGrpcChannelPtr FindChannelOnce(const string& target) override { + const string host_port = TranslateTask(target); + if (host_port.empty()) { + LOG(WARNING) << "Could not find channel for target: " << target; + return nullptr; + } + return NewHostPortGrpcChannel(host_port); + } + + private: + const string job_id_; + const std::vector host_ports_; + const int tasks_per_replica_; + TF_DISALLOW_COPY_AND_ASSIGN(HostPortsGrpcChannelCache); +}; + +GrpcChannelCache* NewHostPortsGrpcChannelCache( + const string& job_id, const std::vector& host_ports, + int tasks_per_replica) { + return new HostPortsGrpcChannelCache(job_id, host_ports, 
tasks_per_replica); +} + +} // end namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_channel.h b/tensorflow/core/distributed_runtime/rpc/grpc_channel.h new file mode 100644 index 00000000000..f3667a567e3 --- /dev/null +++ b/tensorflow/core/distributed_runtime/rpc/grpc_channel.h @@ -0,0 +1,98 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CHANNEL_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CHANNEL_H_ + +#include +#include +#include +#include + +#include "external/grpc/include/grpc++/grpc++.h" + +#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" + +namespace tensorflow { + +// Consolidated parameter structure to ease use of generic interfaces. 
+//
+// Each job_id requires:
+// - a list of host:port (or sparse list of index:host:port)
+// - the number of tasks per replica
+class GrpcChannelSpec {
+ public:
+  struct HostPortsJob {
+    string job_id;
+    std::vector<string> host_ports;
+    int tasks_per_replica;
+  };
+
+  // Registers `job_id` with the given (dense host:port or sparse
+  // index:host:port) `host_ports` list.
+  // NOTE(review): `job_ids_` below suggests duplicate job_ids are rejected
+  // with a non-OK Status — confirm in grpc_channel.cc.
+  Status AddHostPortsJob(const string& job_id,
+                         const std::vector<string>& host_ports,
+                         int tasks_per_replica);
+
+  const std::vector<HostPortsJob>& host_ports_jobs() const {
+    return host_ports_jobs_;
+  }
+
+ private:
+  std::vector<HostPortsJob> host_ports_jobs_;
+  // Tracks which job_ids have been added, to detect duplicates.
+  std::set<string> job_ids_;
+};
+
+class GrpcChannelCache {
+ public:
+  virtual ~GrpcChannelCache() {}
+
+  // Populates *workers with names of all workers which this object
+  // was created to handle.  Worker names are in the format
+  //  /job:<job identifier>/task:<task id>
+  // e.g. /job:mnist/task:2
+  virtual void ListWorkers(std::vector<string>* workers) = 0;
+
+  // If found, returns a gRPC channel that is connected to the remote
+  // worker named by 'target'. 'target' is of the following
+  // format: /job:<job identifier>/task:<task id>
+  // E.g., /job:mnist/task:2
+  virtual SharedGrpcChannelPtr FindWorkerChannel(const string& target) = 0;
+
+  // Translates a string in the form `/job:X/task:Z` into a host_port.
+  virtual string TranslateTask(const string& task) = 0;
+};
+
+GrpcChannelCache* NewGrpcChannelCache(const GrpcChannelSpec& p);
+
+// Below here are internal-only functions.
+
+SharedGrpcChannelPtr NewHostPortGrpcChannel(const string& target);
+
+// Returns a ChannelCache that uses a set of known host:port pairs. E.g., say,
+// job_id = 'mnist', 'host_ports' = {"h0:0", "h1:1", ..., "h11:11", "h12:12"},
+// tasks_per_replica = 8, /job:mnist/replica:1/task:3 is mapped to host:port
+// "h11:11" (11 = 8 * 1 + 3).
+//
+// The caller takes ownership of the returned object.
+GrpcChannelCache* NewHostPortsGrpcChannelCache(
+    const string& job_id, const std::vector<string>& host_ports,
+    int tasks_per_replica);
+
+// Returns a ChannelCache that is the union of a number of other ChannelCaches.
+GrpcChannelCache* NewMultiGrpcChannelCache( + const std::vector& caches); + +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CHANNEL_H_ diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_channel_test.cc b/tensorflow/core/distributed_runtime/rpc/grpc_channel_test.cc new file mode 100644 index 00000000000..a951dc2fcfb --- /dev/null +++ b/tensorflow/core/distributed_runtime/rpc/grpc_channel_test.cc @@ -0,0 +1,137 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" + +#include +#include + +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { +#define IsSameAddrSp DeviceNameUtils::IsSameAddressSpace + +TEST(GrpcChannelTest, IsSameAddressSpace) { + // Same. + EXPECT_TRUE(IsSameAddrSp("/job:mnist/replica:10/task:10/cpu:0", + "/job:mnist/replica:10/task:10/cpu:1")); + EXPECT_TRUE(IsSameAddrSp("/job:mnist/replica:10/task:10/cpu:0", + "/job:mnist/replica:10/task:10/gpu:2")); + EXPECT_TRUE(IsSameAddrSp("/job:mnist/replica:10/task:10", + "/job:mnist/replica:10/task:10/gpu:2")); + EXPECT_TRUE(IsSameAddrSp("/job:mnist/replica:10/task:10/cpu:1", + "/job:mnist/replica:10/task:10")); + + // Different. 
+ EXPECT_FALSE(IsSameAddrSp("/job:mnist/replica:10/task:9/cpu:0", + "/job:mnist/replica:10/task:10/cpu:0")); + EXPECT_FALSE(IsSameAddrSp("/job:mnist/replica:9/task:10/cpu:0", + "/job:mnist/replica:10/task:10/cpu:0")); + EXPECT_FALSE(IsSameAddrSp("/job:MNIST/replica:10/task:10/cpu:0", + "/job:mnist/replica:10/task:10/cpu:0")); + + // Invalid names. + EXPECT_FALSE(IsSameAddrSp("random_invalid_target", "random_invalid_target")); + EXPECT_FALSE(IsSameAddrSp("/job:/replica:10/task:10/cpu:0", + "/job:/replica:10/task:10/cpu:1")); + EXPECT_FALSE(IsSameAddrSp("/job:mnist/replica:xx/task:10/cpu:0", + "/job:mnist/replica:xx/task:10/cpu:1")); + EXPECT_FALSE(IsSameAddrSp("/job:mnist/replica:10/task:yy/cpu:0", + "/job:mnist/replica:10/task:yy/cpu:1")); +} + +TEST(GrpcChannelTest, HostPorts) { + std::unique_ptr cc(NewHostPortsGrpcChannelCache( + "mnist", {"a:1", "b:2", "c:3", "d:4", "e:5", "f:6"}, 2)); + EXPECT_EQ(nullptr, cc->FindWorkerChannel("invalid_target")); + EXPECT_EQ(nullptr, cc->FindWorkerChannel("/job:other/replica:0/task:0")); + EXPECT_EQ(nullptr, cc->FindWorkerChannel("/job:mnist/replica:0/task:2")); + EXPECT_EQ(nullptr, cc->FindWorkerChannel("/job:mnist/replica:3/task:0")); + + { + // NOTE(mrry): The gRPC channel doesn't expose the target, so we + // can't compare it for equality. 
+ auto a_1_1 = cc->FindWorkerChannel("/job:mnist/replica:0/task:0"); + auto a_1_2 = cc->FindWorkerChannel("/job:mnist/replica:0/task:0"); + + auto d_4_1 = cc->FindWorkerChannel("/job:mnist/replica:1/task:1"); + auto d_4_2 = cc->FindWorkerChannel("/job:mnist/replica:1/task:1"); + + auto e_5_1 = cc->FindWorkerChannel("/job:mnist/replica:2/task:0"); + auto e_5_2 = cc->FindWorkerChannel("/job:mnist/replica:2/task:0"); + + EXPECT_EQ(a_1_1.get(), a_1_2.get()); + EXPECT_EQ(d_4_1.get(), d_4_2.get()); + EXPECT_EQ(e_5_1.get(), e_5_2.get()); + + EXPECT_NE(a_1_1.get(), d_4_2.get()); + EXPECT_NE(a_1_1.get(), e_5_2.get()); + EXPECT_NE(d_4_1.get(), e_5_2.get()); + } + + std::vector workers; + cc->ListWorkers(&workers); + EXPECT_EQ(std::vector({"/job:mnist/replica:0/task:0", + "/job:mnist/replica:0/task:1", + "/job:mnist/replica:1/task:0", + "/job:mnist/replica:1/task:1", + "/job:mnist/replica:2/task:0", + "/job:mnist/replica:2/task:1"}), + workers); +} + +TEST(GrpcChannelTest, SparseHostPorts) { + std::unique_ptr cc( + NewHostPortsGrpcChannelCache("mnist", {"0:a:1", "3:d:4", "4:e:5"}, 2)); + EXPECT_EQ(nullptr, cc->FindWorkerChannel("invalid_target")); + EXPECT_EQ(nullptr, cc->FindWorkerChannel("/job:other/replica:0/task:0")); + EXPECT_EQ(nullptr, cc->FindWorkerChannel("/job:mnist/replica:0/task:1")); + EXPECT_EQ(nullptr, cc->FindWorkerChannel("/job:mnist/replica:0/task:2")); + EXPECT_EQ(nullptr, cc->FindWorkerChannel("/job:mnist/replica:1/task:0")); + EXPECT_EQ(nullptr, cc->FindWorkerChannel("/job:mnist/replica:2/task:1")); + EXPECT_EQ(nullptr, cc->FindWorkerChannel("/job:mnist/replica:3/task:0")); + + { + // NOTE(mrry): The gRPC channel doesn't expose the target, so we + // can't compare it for equality. 
+ auto a_1_1 = cc->FindWorkerChannel("/job:mnist/replica:0/task:0"); + auto a_1_2 = cc->FindWorkerChannel("/job:mnist/replica:0/task:0"); + + auto d_4_1 = cc->FindWorkerChannel("/job:mnist/replica:1/task:1"); + auto d_4_2 = cc->FindWorkerChannel("/job:mnist/replica:1/task:1"); + + auto e_5_1 = cc->FindWorkerChannel("/job:mnist/replica:2/task:0"); + auto e_5_2 = cc->FindWorkerChannel("/job:mnist/replica:2/task:0"); + + EXPECT_EQ(a_1_1.get(), a_1_2.get()); + EXPECT_EQ(d_4_1.get(), d_4_2.get()); + EXPECT_EQ(e_5_1.get(), e_5_2.get()); + + EXPECT_NE(a_1_1.get(), d_4_2.get()); + EXPECT_NE(a_1_1.get(), e_5_2.get()); + EXPECT_NE(d_4_1.get(), e_5_2.get()); + } + + std::vector workers; + cc->ListWorkers(&workers); + EXPECT_EQ(std::vector({"/job:mnist/replica:0/task:0", + "/job:mnist/replica:1/task:1", + "/job:mnist/replica:2/task:0"}), + workers); +} + +} // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h b/tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h new file mode 100644 index 00000000000..300481303b9 --- /dev/null +++ b/tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h @@ -0,0 +1,56 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CLIENT_CQ_TAG_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CLIENT_CQ_TAG_H_
+
+#include "external/grpc/include/grpc++/grpc++.h"
+
+#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/macros.h"
+
+namespace tensorflow {
+
+// Represents a pending asynchronous client call as a tag that can be
+// stored in a `grpc::CompletionQueue`.
+//
+// Usage: the RPC issuer allocates a tag, passes `tag->status()` and the
+// tag pointer to the async call (e.g. ClientAsyncResponseReader::Finish),
+// and the completion-queue polling loop invokes OnCompleted() when the
+// call finishes.
+class GrpcClientCQTag {
+ public:
+  // Takes ownership of `context` (deleted in the destructor). `cb` is
+  // invoked with the call's final status each time OnCompleted() runs.
+  GrpcClientCQTag(::grpc::ClientContext* context, StatusCallback cb)
+      : context_(context), cb_(cb) {}
+  ~GrpcClientCQTag() { delete context_; }
+
+  // Called when the tag is dequeued from the completion queue. `ok` is
+  // gRPC's completion-queue success flag; the RPC status itself has been
+  // written into `status_` by gRPC via the pointer returned from status().
+  // NOTE(review): when `ok` is false, `status_` may not have been filled
+  // in by gRPC — confirm against the gRPC CompletionQueue contract.
+  void OnCompleted(bool ok) {
+    if (!ok) {
+      VLOG(2) << "Call returned with non-ok status: "
+              << status_.error_message();
+    }
+    cb_(FromGrpcStatus(status_));
+  }
+
+  ::grpc::ClientContext* context() { return context_; }
+  // Out-parameter slot that the async call writes the final status into.
+  ::grpc::Status* status() { return &status_; }
+
+ private:
+  ::grpc::ClientContext* context_;  // Owned.
+  ::grpc::Status status_;
+  StatusCallback cb_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(GrpcClientCQTag);
+};
+
+}  // namespace tensorflow
+
+#endif  // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CLIENT_CQ_TAG_H_
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc
new file mode 100644
index 00000000000..b8d50c5695d
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc
@@ -0,0 +1,181 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// GrpcMasterService implements the RPC service MasterSerivce. +// +// A GrpcMasterService maintains the state of live graph computation +// sessions, each session orchestrates both local and remote devices +// to carry out the graph computation. +// +// A GrpcMasterService knows ahead of time local devices available as +// client devices. +// +// A GrpcMasterService discovers remote devices in the background and +// keeps track of statistics of those remote devices. +// +// Each session analyses the graph, places nodes across available +// devices, and ultimately drives the graph computation by initiating +// RunGraph on workers. 
+#include "tensorflow/core/distributed_runtime/rpc/grpc_master_service.h" + +#include "external/grpc/include/grpc++/server_builder.h" + +#include "tensorflow/core/distributed_runtime/master.h" +#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_call.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/protobuf/master.pb.h" +#include "tensorflow/core/protobuf/master_service.grpc.pb.h" + +namespace tensorflow { + +class GrpcMasterService : public AsyncServiceInterface { + public: + GrpcMasterService(MasterEnv* env, ::grpc::ServerBuilder* builder) + : master_impl_(new Master(env, 0.0)) { + builder->RegisterService(&master_service_); + cq_ = builder->AddCompletionQueue().release(); + } + + ~GrpcMasterService() { + delete cq_; + delete master_impl_; + } + +// This macro creates a new request for the given RPC method name +// (e.g., `ENQUEUE_REQUEST(RunStep);`), and enqueues it on +// `this->cq_`. +// +// This macro is invoked one or more times for each RPC method to +// ensure that there are sufficient completion queue entries to +// handle incoming requests without blocking. +// +// The implementation of the request handler for each RPC method +// must ensure that it calls ENQUEUE_REQUEST() for that RPC method, +// to keep accepting new requests. 
+#define ENQUEUE_REQUEST(method) \ + do { \ + Call:: \ + EnqueueRequest(&master_service_, cq_, \ + &grpc::MasterService::AsyncService::Request##method, \ + &GrpcMasterService::method##Handler); \ + } while (0) + + void HandleRPCsLoop() { + ENQUEUE_REQUEST(CreateSession); + ENQUEUE_REQUEST(ExtendSession); + for (int i = 0; i < 100; ++i) { + ENQUEUE_REQUEST(RunStep); + } + ENQUEUE_REQUEST(CloseSession); + ENQUEUE_REQUEST(ListDevices); + ENQUEUE_REQUEST(Reset); + + void* tag; + bool ok; + while (cq_->Next(&tag, &ok)) { + CHECK(ok); + UntypedCall::Tag* callback_tag = + static_cast::Tag*>(tag); + callback_tag->OnCompleted(this, ok); + delete callback_tag; + } + } + + private: + Master* master_impl_; // Owned. + ::grpc::ServerCompletionQueue* cq_; // Owned. + grpc::MasterService::AsyncService master_service_; + + template + using MasterCall = Call; + + // RPC handler for creating a session. + void CreateSessionHandler( + MasterCall* call) { + master_impl_->CreateSession(&call->request, &call->response, + [call](const Status& status) { + call->SendResponse(ToGrpcStatus(status)); + }); + ENQUEUE_REQUEST(CreateSession); + } + + // RPC handler for extending a session. + void ExtendSessionHandler( + MasterCall* call) { + master_impl_->ExtendSession(&call->request, &call->response, + [call](const Status& status) { + call->SendResponse(ToGrpcStatus(status)); + }); + ENQUEUE_REQUEST(ExtendSession); + } + + // RPC handler for running one step in a session. + void RunStepHandler(MasterCall* call) { + CallOptions* call_opts = new CallOptions; + call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); }); + master_impl_->RunStep(call_opts, &call->request, &call->response, + [call, call_opts](const Status& status) { + call->ClearCancelCallback(); + delete call_opts; + call->SendResponse(ToGrpcStatus(status)); + }); + ENQUEUE_REQUEST(RunStep); + } + + // RPC handler for deleting a session. 
+ void CloseSessionHandler( + MasterCall* call) { + master_impl_->CloseSession(&call->request, &call->response, + [call](const Status& status) { + call->SendResponse(ToGrpcStatus(status)); + }); + ENQUEUE_REQUEST(CloseSession); + } + + // RPC handler for listing devices. + void ListDevicesHandler( + MasterCall* call) { + master_impl_->ListDevices(&call->request, &call->response, + [call](const Status& status) { + call->SendResponse(ToGrpcStatus(status)); + }); + ENQUEUE_REQUEST(ListDevices); + } + + // RPC handler for resetting all sessions. + void ResetHandler(MasterCall* call) { + master_impl_->Reset(&call->request, &call->response, + [call](const Status& status) { + call->SendResponse(ToGrpcStatus(status)); + }); + ENQUEUE_REQUEST(Reset); + } +#undef ENQUEUE_REQUEST + + TF_DISALLOW_COPY_AND_ASSIGN(GrpcMasterService); +}; + +AsyncServiceInterface* NewGrpcMasterService(MasterEnv* env, + ::grpc::ServerBuilder* builder) { + CHECK(!env->local_devices.empty()); + return new GrpcMasterService(env, builder); +} + +} // end namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_master_service.h b/tensorflow/core/distributed_runtime/rpc/grpc_master_service.h new file mode 100644 index 00000000000..d23a3f3ed32 --- /dev/null +++ b/tensorflow/core/distributed_runtime/rpc/grpc_master_service.h @@ -0,0 +1,33 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_MASTER_SERVICE_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_MASTER_SERVICE_H_ + +namespace grpc { +class ServerBuilder; +} // namespace grpc + +namespace tensorflow { + +class AsyncServiceInterface; +class MasterEnv; + +AsyncServiceInterface* NewGrpcMasterService(MasterEnv* env, + ::grpc::ServerBuilder* builder); + +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_MASTER_SERVICE_H_ diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc b/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc new file mode 100644 index 00000000000..e358aed31f2 --- /dev/null +++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc @@ -0,0 +1,79 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/distributed_runtime/rpc/grpc_remote_master.h" + +#include "tensorflow/core/distributed_runtime/master_interface.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/protobuf/master.pb.h" +#include "tensorflow/core/protobuf/master_service.grpc.pb.h" + +namespace tensorflow { + +// GrpcRemoteMaster is an implementation of the MasterInterface +// that uses gRPC to talk to the Master service. +class GrpcRemoteMaster : public MasterInterface { + public: + explicit GrpcRemoteMaster(SharedGrpcChannelPtr client_channel) + : stub_(grpc::MasterService::NewStub(client_channel)) {} + + ~GrpcRemoteMaster() override {} + + Status CreateSession(const CreateSessionRequest* request, + CreateSessionResponse* response) override { + ::grpc::ClientContext ctx; + return FromGrpcStatus(stub_->CreateSession(&ctx, *request, response)); + } + + Status ExtendSession(const ExtendSessionRequest* request, + ExtendSessionResponse* response) override { + ::grpc::ClientContext ctx; + return FromGrpcStatus(stub_->ExtendSession(&ctx, *request, response)); + } + + Status RunStep(const RunStepRequest* request, + RunStepResponse* response) override { + ::grpc::ClientContext ctx; + return FromGrpcStatus(stub_->RunStep(&ctx, *request, response)); + } + + Status CloseSession(const CloseSessionRequest* request, + CloseSessionResponse* response) override { + ::grpc::ClientContext ctx; + return FromGrpcStatus(stub_->CloseSession(&ctx, *request, response)); + } + + Status ListDevices(const ListDevicesRequest* request, + ListDevicesResponse* response) override { + ::grpc::ClientContext ctx; + return FromGrpcStatus(stub_->ListDevices(&ctx, *request, response)); + } + + Status Reset(const ResetRequest* request, ResetResponse* response) override { + ::grpc::ClientContext 
ctx; + return FromGrpcStatus(stub_->Reset(&ctx, *request, response)); + } + + private: + std::unique_ptr stub_; +}; + +MasterInterface* NewGrpcMaster(SharedGrpcChannelPtr channel) { + return new GrpcRemoteMaster(channel); +} + +} // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.h b/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.h new file mode 100644 index 00000000000..461e4ca0bdc --- /dev/null +++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.h @@ -0,0 +1,27 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_MASTER_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_MASTER_H_ + +#include "tensorflow/core/distributed_runtime/master_interface.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" + +namespace tensorflow { +// Returns a MasterInterface wrapped around the gRPC channel `channel`. 
+MasterInterface* NewGrpcMaster(SharedGrpcChannelPtr channel); +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_MASTER_H_ diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc new file mode 100644 index 00000000000..0040631aac4 --- /dev/null +++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc @@ -0,0 +1,203 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h" + +#include "external/grpc/include/grpc++/grpc++.h" + +#include "tensorflow/core/distributed_runtime/process_util.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h" +#include "tensorflow/core/distributed_runtime/worker_cache_logger.h" +#include "tensorflow/core/distributed_runtime/worker_interface.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/tracing.h" +#include "tensorflow/core/protobuf/worker.pb.h" +#include "tensorflow/core/protobuf/worker_service.grpc.pb.h" + +namespace tensorflow { + +class GrpcRemoteWorker : public WorkerInterface { + public: + explicit GrpcRemoteWorker(SharedGrpcChannelPtr channel, + ::grpc::CompletionQueue* completion_queue, + WorkerCacheLogger* logger) + : stub_(grpc::WorkerService::NewStub(channel)), + cq_(completion_queue), + logger_(logger) {} + + ~GrpcRemoteWorker() override {} + + void GetStatusAsync(const GetStatusRequest* request, + GetStatusResponse* response, + StatusCallback done) override { + IssueRequest(request, response, &grpc::WorkerService::Stub::AsyncGetStatus, + done); + } + + void RegisterGraphAsync(const RegisterGraphRequest* request, + RegisterGraphResponse* response, + StatusCallback done) override { + IssueRequest(request, response, + &grpc::WorkerService::Stub::AsyncRegisterGraph, done); + } + + void DeregisterGraphAsync(const DeregisterGraphRequest* request, + DeregisterGraphResponse* response, + StatusCallback done) override { + IssueRequest(request, response, + &grpc::WorkerService::Stub::AsyncDeregisterGraph, done); + } + + void RunGraphAsync(CallOptions* call_opts, const RunGraphRequest* request, + RunGraphResponse* response, StatusCallback done) 
override { + IssueRequest(request, response, &grpc::WorkerService::Stub::AsyncRunGraph, + done, call_opts); + } + + void CleanupGraphAsync(const CleanupGraphRequest* request, + CleanupGraphResponse* response, + StatusCallback done) override { + IssueRequest(request, response, + &grpc::WorkerService::Stub::AsyncCleanupGraph, done); + } + + void CleanupAllAsync(const CleanupAllRequest* request, + CleanupAllResponse* response, + StatusCallback done) override { + IssueRequest(request, response, &grpc::WorkerService::Stub::AsyncCleanupAll, + done); + } + + void RecvTensorAsync(CallOptions* call_opts, const RecvTensorRequest* request, + RecvTensorResponse* response, + TensorBufAllocator allocator, + StatusCallback done) override { + VLOG(1) << "RecvTensorAsync req: " << request->DebugString(); + int64 start_usec = Env::Default()->NowMicros(); + // Don't propagate dma_ok over gRPC. + RecvTensorRequest* req_copy = nullptr; + if (request->dma_ok()) { + req_copy = new RecvTensorRequest; + *req_copy = *request; + req_copy->set_dma_ok(false); + } + // Type-specialized logging for this method. + StatusCallback logging_callback = [this, request, req_copy, response, done, + start_usec](Status s) { + if (logger_->LoggingActive()) { + int64 end_usec = Env::Default()->NowMicros(); + int64 step_id = request->step_id(); + int64 bytes = response->tensor().ByteSize(); + int64 send_start_usec = start_usec; + // If a send start time was reported by the other side, use + // that instead. Maybe we should mark the display if we're using + // our local time instead of the remote start time? + if (response->send_start_micros()) { + // send_start_micros is the timestamp taken when the remote + // machine began to send the RecvTensor response. + // Due to clock skew between source and dest machines, it is + // possible that send_start_micros can be larger than end_usec or + // less than start_usec. 
+ // To respect causality, we enforce the invariants that the RecvTensor + // response can not have been sent before the RecvTensor request, and + // must have been sent before it was received. + send_start_usec = std::max(start_usec, response->send_start_micros()); + send_start_usec = std::min(send_start_usec, end_usec - 1); + } + const string& key = request->rendezvous_key(); + std::vector key_parts = str_util::Split(key, ';'); + if (key_parts.size() != 5) { + LOG(WARNING) << "Bad key: " << key; + } else { + logger_->RecordRecvTensor(step_id, send_start_usec, end_usec, + key_parts[3], // tensor name + key_parts[0], // src_device + key_parts[2], // dst_device + bytes); + } + } + VLOG(2) << "done callback, req: " << request->DebugString() + << " response " << response->DebugString(); + delete req_copy; + done(s); + }; + + IssueRequest(req_copy ? req_copy : request, response, + &grpc::WorkerService::Stub::AsyncRecvTensor, logging_callback, + call_opts); + } + + void LoggingAsync(const LoggingRequest* request, LoggingResponse* response, + StatusCallback done) override { + IssueRequest(request, response, &grpc::WorkerService::Stub::AsyncLogging, + done); + } + + void TracingAsync(const TracingRequest* request, TracingResponse* response, + StatusCallback done) override { + IssueRequest(request, response, &grpc::WorkerService::Stub::AsyncTracing, + done); + } + + private: + template + using AsyncMethod = + std::unique_ptr<::grpc::ClientAsyncResponseReader> ( + grpc::WorkerService::Stub::*)(::grpc::ClientContext*, + const RequestMessage&, + ::grpc::CompletionQueue*); + + // Utility method for issuing a generic asynchronous request. The + // given callback, `done`, will be called when the RPC completes. 
+ template + void IssueRequest(const RequestMessage* request, ResponseMessage* response, + AsyncMethod async_method, + StatusCallback done, CallOptions* call_opts = nullptr) { + ::grpc::ClientContext* context = new ::grpc::ClientContext; + if (call_opts) { + call_opts->SetCancelCallback([context]() { context->TryCancel(); }); + } + auto rpc = (stub_.get()->*async_method)(context, *request, cq_).release(); + GrpcClientCQTag* tag = + new GrpcClientCQTag(context, [rpc, done, call_opts](Status s) { + if (call_opts) { + call_opts->ClearCancelCallback(); + } + delete rpc; + done(s); + }); + rpc->Finish(response, tag->status(), tag); + } + + std::unique_ptr stub_; + ::grpc::CompletionQueue* cq_; + + // Support for logging. + WorkerCacheLogger* logger_; + bool retry_unavailable_; + + TF_DISALLOW_COPY_AND_ASSIGN(GrpcRemoteWorker); +}; + +WorkerInterface* NewGrpcRemoteWorker(SharedGrpcChannelPtr channel, + ::grpc::CompletionQueue* completion_queue, + WorkerCacheLogger* logger) { + return new GrpcRemoteWorker(channel, completion_queue, logger); +} + +} // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h new file mode 100644 index 00000000000..dfb72bdde24 --- /dev/null +++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h @@ -0,0 +1,38 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_WORKER_H_ +#define THIRD_PARTY_TENSORFLOW_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_WORKER_H_ + +#include + +#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" + +namespace grpc { +class CompletionQueue; +} + +namespace tensorflow { + +class WorkerCacheLogger; +class WorkerInterface; + +WorkerInterface* NewGrpcRemoteWorker(SharedGrpcChannelPtr channel, + ::grpc::CompletionQueue* completion_queue, + WorkerCacheLogger* logger); + +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_WORKER_H_ diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc new file mode 100644 index 00000000000..ddac7fd2cd7 --- /dev/null +++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc @@ -0,0 +1,116 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" + +#include + +#include "external/grpc/include/grpc++/grpc++.h" +#include "external/grpc/include/grpc++/security/credentials.h" +#include "external/grpc/include/grpc++/server_builder.h" + +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/distributed_runtime/graph_mgr.h" +#include "tensorflow/core/distributed_runtime/master_env.h" +#include "tensorflow/core/distributed_runtime/master_session.h" +#include "tensorflow/core/distributed_runtime/process_util.h" +#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_master_service.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h" +#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h" +#include "tensorflow/core/distributed_runtime/worker_env.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/public/session_options.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace tensorflow { + +void StartTensorFlowServer(const GrpcServerOptions& options) { + // The Thread destructor waits until all the thread terminates is + // done (i.e. forever). + std::unique_ptr launcher_thread(Env::Default()->StartThread( + ThreadOptions(), "TF_service_launcher", [options]() { + // Configure the MasterEnv and WorkerEnv, which provide service-global + // context for the master and worker services, respectively. 
+ + // The master and worker share the same worker cache (for RPC + // connections to other workers) and devices (so that the master + // may run some ops locally as a "client" device). The master + // requires a device to serve as a "client device", so that remote + // devices can copy the feeds from the master. + MasterEnv master_env; + WorkerEnv worker_env; + master_env.env = Env::Default(); + worker_env.env = Env::Default(); + + // Configure shared devices between master and worker. + string name_prefix = + strings::StrCat("/job:", options.job_name, "/replica:0", "/task:", + options.task_index); + DeviceFactory::AddDevices(options.default_session_options, name_prefix, + &master_env.local_devices); + worker_env.device_mgr = new DeviceMgr(master_env.local_devices); + string unused; + CHECK(DeviceNameUtils::SplitDeviceName( + master_env.local_devices[0]->name(), &worker_env.worker_name, + &unused)); + + GrpcChannelCache* channel_cache = + NewGrpcChannelCache(options.channel_spec); + int port; + const std::vector host_port = + str_util::Split(channel_cache->TranslateTask(name_prefix), ':'); + CHECK(str_util::NumericParse32(host_port[1], &port)); + + worker_env.worker_cache = NewGrpcWorkerCache(channel_cache); + + // Finish setting up master environment. + master_env.ops = OpRegistry::Global(); + master_env.worker_cache = worker_env.worker_cache; + master_env.master_session_factory = internal::NewMasterSession; + + // Finish setting up worker environment. + worker_env.graph_mgr = new GraphMgr(&worker_env); + worker_env.rendezvous_mgr = new RpcRendezvousMgr(&worker_env); + worker_env.compute_pool = ComputePool(options.default_session_options); + + // Build the gRPC server that will host both the master and the + // worker services. 
+ ::grpc::ServerBuilder builder; + builder.AddListeningPort(strings::StrCat("0.0.0.0:", port), + ::grpc::InsecureServerCredentials()); + auto master_service = NewGrpcMasterService(&master_env, &builder); + auto worker_service = NewGrpcWorkerService(&worker_env, &builder); + auto server_ = builder.BuildAndStart(); + + // Start threads to handle the incoming RPCs for the master and worker. + // NOTE(mrry): The Thread destructor waits until the thread terminates + // (i.e. forever in this case). + std::unique_ptr master_thread(Env::Default()->StartThread( + ThreadOptions(), "TF_master_service", + [master_service]() { master_service->HandleRPCsLoop(); })); + std::unique_ptr worker_thread(Env::Default()->StartThread( + ThreadOptions(), "TF_worker_service", + [worker_service]() { worker_service->HandleRPCsLoop(); })); + })); +} + +} // end namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h new file mode 100644 index 00000000000..59abb31a15d --- /dev/null +++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h @@ -0,0 +1,53 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_ + +#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/public/session_options.h" + +namespace tensorflow { + +// Defines the configuration for a single task (typically a process) +// that is part of a gRPC-based TensorFlow cluster. +struct GrpcServerOptions { + // This identity of the job to which this task belongs. The names + // of the devices in this task will be prefixed with + // "/job:/task:" + string job_name; + int32 task_index = 0; + + // A channel specification, which defines (i) the set of jobs that + // comprise the cluster, and (ii) within each job, the endpoints + // exposed by each task. NOTE: This spec also defines the endpoint + // on which this task will listen. + GrpcChannelSpec channel_spec; + + // SessionOptions that will be used as defaults when configuring + // sessions in this task. `default_session_options.target` is + // ignored. + SessionOptions default_session_options; +}; + +// Starts a gRPC-based TensorFlow server with the given options. +// This function will not return. +void StartTensorFlowServer(const GrpcServerOptions& options); + +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_ diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session.cc b/tensorflow/core/distributed_runtime/rpc/grpc_session.cc new file mode 100644 index 00000000000..6924fc55377 --- /dev/null +++ b/tensorflow/core/distributed_runtime/rpc/grpc_session.cc @@ -0,0 +1,233 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/distributed_runtime/rpc/grpc_session.h" + +#include + +#include "tensorflow/core/common_runtime/session_factory.h" +#include "tensorflow/core/distributed_runtime/master_interface.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_remote_master.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/protobuf/master.pb.h" + +namespace tensorflow { + +const size_t kSchemePrefix = sizeof("grpc://") - 1; + +GrpcSession::GrpcSession(const SessionOptions& options) + : options_(options), + master_(NewGrpcMaster( + NewHostPortGrpcChannel(options.target.substr(kSchemePrefix)))), + current_graph_version_(-1) {} + +GrpcSession::~GrpcSession() {} + +namespace { +// Re-encodes constant represented in tensor proto into +// tensor_content, which is slightly better (less copies and lower peak +// memory usage) when used with rpc subsystems. 
+void ReEncodeConsts(GraphDef* gdef) { + for (NodeDef& ndef : *(gdef->mutable_node())) { + if (ndef.op() == "Const") { + TensorProto* proto = nullptr; + for (auto& attr : *ndef.mutable_attr()) { + if (attr.first == "value") { + proto = attr.second.mutable_tensor(); + } + } + if (proto != nullptr && proto->tensor_content().empty() && + proto->ByteSize() > 64) { + // If the constant is encoded with repeated proto fields and + // it is moderate large, we re-encode it in tensor_content as + // a Cord. This is mildly helpful for reducing the peak memory + // usage on the server side where GraphDef/NodeDef are copied + // quite often. + Tensor parsed(proto->dtype()); + if (parsed.FromProto(*proto)) { + parsed.AsProtoTensorContent(proto); + } + } + } + } +} +} // namespace + +Status GrpcSession::Create(const GraphDef& graph) { + if (!handle_.empty()) { + return errors::InvalidArgument("A session is alive."); + } + CreateSessionRequest req; + *req.mutable_config() = options_.config; + *req.mutable_graph_def() = graph; + ReEncodeConsts(req.mutable_graph_def()); + CreateSessionResponse resp; + Status s = master_->CreateSession(&req, &resp); + if (s.ok()) { + mutex_lock l(mu_); + swap(handle_, *(resp.mutable_session_handle())); + current_graph_version_ = resp.graph_version(); + } + return s; +} + +Status GrpcSession::Extend(const GraphDef& graph) { + if (handle_.empty()) { + // Session was unitialized, so simply initialize the session with 'graph'. 
+ return Create(graph); + } + mutex_lock l(mu_); + ExtendSessionRequest req; + req.set_session_handle(handle_); + *req.mutable_graph_def() = graph; + req.set_current_graph_version(current_graph_version_); + ExtendSessionResponse resp; + Status s = master_->ExtendSession(&req, &resp); + if (s.ok()) { + current_graph_version_ = resp.new_graph_version(); + } + return s; +} + +Status GrpcSession::Run(const std::vector>& inputs, + const std::vector& output_names, + const std::vector& target_nodes, + std::vector* outputs) { + // Convert to proto + RunStepRequest req; + RunStepResponse resp; + + for (const auto& it : inputs) { + Tensor input_tensor = it.second; + auto feed = req.add_feed(); + feed->set_name(it.first); + TensorProto* proto = feed->mutable_tensor(); + input_tensor.AsProtoTensorContent(proto); + } + + // Build an index from fetch tensor name to offset. + std::unordered_map output_name_to_offset; + for (const string& output_name : output_names) { + req.add_fetch(output_name); + output_name_to_offset.insert( + std::make_pair(output_name, output_name_to_offset.size())); + } + for (const string& target : target_nodes) { + req.add_target(target); + } + + TF_RETURN_IF_ERROR(RunProto(&req, &resp)); + + if (!output_names.empty()) { + outputs->resize(output_names.size()); + } + + // Convert response back to Tensors in the correct order. 
+ for (const NamedTensorProto& tensor : resp.tensor()) { + auto fetch_it = output_name_to_offset.find(tensor.name()); + if (fetch_it == output_name_to_offset.end()) { + return errors::Internal("Received response for unrequested fetch: ", + tensor.name()); + } + + Tensor output; + if (!output.FromProto(tensor.tensor())) { + return errors::InvalidArgument("Could not parse returned proto for ", + tensor.name()); + } + + (*outputs)[fetch_it->second] = output; + } + + return Status::OK(); +} + +Status GrpcSession::RunProto(RunStepRequest* req, RunStepResponse* resp) { + if (handle_.empty()) { + return errors::InvalidArgument("A session is not created yet...."); + } + + req->set_session_handle(handle_); + return master_->RunStep(req, resp); +} + +Status GrpcSession::PRunSetup(const std::vector& input_names, + const std::vector& output_names, + const std::vector& target_nodes, + string* handle) { + return errors::Internal("Partial run is not supported for remote session."); +} + +Status GrpcSession::PRun(const string& handle, + const std::vector>& inputs, + const std::vector& output_names, + std::vector* outputs) { + return errors::Internal("Partial run is not supported for remote session."); +} + +Status GrpcSession::Close() { + if (handle_.empty()) { + return errors::InvalidArgument("A session is not created yet...."); + } + CloseSessionRequest req; + req.set_session_handle(handle_); + handle_.clear(); + CloseSessionResponse resp; + return master_->CloseSession(&req, &resp); +} + +std::vector GrpcSession::ListDevices() { + std::vector devices; + + ListDevicesRequest req; + ListDevicesResponse resp; + Status s = master_->ListDevices(&req, &resp); + if (!s.ok()) { + LOG(ERROR) << "Could not list devices: " << s; + return devices; + } + + for (const auto& device_attr : resp.local_device()) { + devices.push_back(device_attr); + } + for (const auto& device_attr : resp.remote_device()) { + devices.push_back(device_attr); + } + + return devices; +} + +class GrpcSessionFactory 
: public SessionFactory { + public: + bool AcceptsOptions(const SessionOptions& options) override { + return StringPiece(options.target).starts_with("grpc://"); + } + + Session* NewSession(const SessionOptions& options) override { + return new GrpcSession(options); + } +}; + +class GrpcSessionRegistrar { + public: + GrpcSessionRegistrar() { + SessionFactory::Register("GRPC_SESSION", new GrpcSessionFactory()); + } +}; +static GrpcSessionRegistrar registrar; + +} // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session.h b/tensorflow/core/distributed_runtime/rpc/grpc_session.h new file mode 100644 index 00000000000..9bc6034ba61 --- /dev/null +++ b/tensorflow/core/distributed_runtime/rpc/grpc_session.h @@ -0,0 +1,97 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SESSION_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SESSION_H_ + +#include +#include + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/protobuf/master.pb.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/core/public/session_options.h" + +namespace tensorflow { + +class MasterInterface; + +// A Session instance lets the caller drive a TensorFlow graph +// computation on potentially remote sets of devices. This is a thin +// wrapper around tensorflow::grpc::MasterService. +// +// Multiple threads must synchronize their accesses to a single +// session. +class GrpcSession : public Session { + public: + // Do not use; just present for easier swig wrapping. + explicit GrpcSession(const SessionOptions& options); + + ~GrpcSession() override; + + // Creates a session with the "target". The session carries out + // the graph computation defined by "graph", and will have version + // number "initial_version". + Status Create(const GraphDef& graph) override; + + Status Run(const std::vector >& inputs, + const std::vector& output_names, + const std::vector& target_nodes, + std::vector* outputs) override; + + Status Extend(const GraphDef& graph) override; + Status Close() override; + + // NOTE: This API is still experimental and may change. 
+ ::tensorflow::Status PRunSetup(const std::vector& input_names, + const std::vector& output_names, + const std::vector& target_nodes, + string* handle) override; + + // NOTE: This API is still experimental and may change. + ::tensorflow::Status PRun( + const string& handle, + const std::vector >& inputs, + const std::vector& output_names, + std::vector* outputs) override; + + std::vector ListDevices(); + + private: + SessionOptions options_; + std::unique_ptr master_; + mutex mu_; + + // handle_ returned by the master to identify this session. + string handle_; + + // The current version of the graph. + int64 current_graph_version_ GUARDED_BY(mu_); + + Status RunProto(RunStepRequest* req, RunStepResponse* resp); + + TF_DISALLOW_COPY_AND_ASSIGN(GrpcSession); +}; + +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SESSION_H_ diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc b/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc new file mode 100644 index 00000000000..86a9b07c2cb --- /dev/null +++ b/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc @@ -0,0 +1,750 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/distributed_runtime/rpc/grpc_session.h" + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_testlib.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/graph/default_device.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/testlib.h" +#include "tensorflow/core/lib/core/error_codes.pb.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/protobuf/master.pb.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/core/util/port.h" + +namespace tensorflow { + +static SessionOptions Devices(int num_cpus, int num_gpus) { + SessionOptions result; + (*result.config.mutable_device_count())["CPU"] = num_cpus; + (*result.config.mutable_device_count())["GPU"] = num_gpus; + return result; +} + +void CreateGraphDef(GraphDef* graph_def, string node_names[3]) { + Graph graph(OpRegistry::Global()); + + Tensor a_tensor(DT_FLOAT, TensorShape({1, 2})); + test::FillValues(&a_tensor, {1, 2}); + Node* a = test::graph::Constant(&graph, a_tensor); + node_names[0] = a->name(); + + Tensor b_tensor(DT_FLOAT, TensorShape({2, 1})); + test::FillValues(&b_tensor, {2, 1}); + Node* b = test::graph::Constant(&graph, b_tensor); + node_names[1] = b->name(); + + Node* c = test::graph::Matmul(&graph, a, b, false, false); + node_names[2] = c->name(); + + test::graph::ToGraphDef(&graph, graph_def); +} + +// Asserts that "val" is a single float tensor. The only float is +// "expected_val". 
+static void IsSingleFloatValue(const Tensor& val, float expected_val) { + ASSERT_EQ(val.dtype(), DT_FLOAT); + ASSERT_EQ(val.NumElements(), 1); + ASSERT_EQ(val.flat()(0), expected_val); +} + +static SessionOptions Options(const string& target, int placement_period) { + SessionOptions options; + // NOTE(mrry): GrpcSession requires a grpc:// scheme prefix in the target + // string. + options.target = strings::StrCat("grpc://", target); + options.config.set_placement_period(placement_period); + return options; +} + +static Session* NewRemote(const SessionOptions& options) { + return CHECK_NOTNULL(NewSession(options)); +} + +TEST(GrpcSessionTest, BasicNonProtoAPI) { + GraphDef graph; + string node_names[3]; + // c = a * b + CreateGraphDef(&graph, node_names); + + std::unique_ptr cluster; + TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster)); + + std::unique_ptr session( + NewRemote(Options(cluster->targets()[0], 1))); + ASSERT_TRUE(session != nullptr); + + for (int iters = 0; iters < 25; ++iters) { + TF_CHECK_OK(session->Create(graph)); + { + std::vector> inputs; + TF_CHECK_OK(session->Run(inputs, {}, {}, {})); + } + { + // Just run to target node + std::vector> inputs; + std::vector targets = {node_names[2]}; + TF_CHECK_OK(session->Run(inputs, {}, targets, nullptr)); + } + { + // Run to a target node and a real tensor + std::vector> inputs; + std::vector names = {node_names[2] + ":0"}; + std::vector targets = {node_names[1]}; + std::vector outputs; + TF_CHECK_OK(session->Run(inputs, names, targets, &outputs)); + ASSERT_TRUE(outputs[0].IsInitialized()); + ASSERT_EQ(4.0, outputs[0].flat()(0)); + } + + TF_CHECK_OK(session->Close()); + } +} + +TEST(GrpcSessionTest, BasicNonProtoAPIConsistentOrder) { + GraphDef graph; + string node_names[3]; + // c = a * b + CreateGraphDef(&graph, node_names); + + std::unique_ptr cluster; + TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster)); + + std::unique_ptr session( + 
NewRemote(Options(cluster->targets()[0], 1))); + ASSERT_TRUE(session != nullptr); + ASSERT_TRUE(session->Create(graph).ok()); + + // Test that the order of the output names matches the order of the + // returned Tensors. + std::vector> inputs; + std::vector names = {node_names[2] + ":0", node_names[0] + ":0", + node_names[1] + ":0"}; + + std::vector target_ops = {node_names[1]}; + std::vector outputs; + ASSERT_TRUE(session->Run(inputs, names, target_ops, &outputs).ok()); + ASSERT_TRUE(outputs[0].IsInitialized()); + ASSERT_EQ(4.0, outputs[0].flat()(0)); + ASSERT_TRUE(outputs[1].IsInitialized()); + ASSERT_EQ(1.0, outputs[1].flat()(0)); + ASSERT_TRUE(outputs[2].IsInitialized()); + ASSERT_EQ(2.0, outputs[2].flat()(0)); + ASSERT_TRUE(session->Close().ok()); +} + +TEST(GrpcSessionTest, NonLocalWithFilters) { + GraphDef graph; + string node_names[3]; + // c = a * b + CreateGraphDef(&graph, node_names); + + std::unique_ptr cluster; + TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster)); + + SessionOptions options; + options.target = strings::StrCat("grpc://", cluster->targets()[0]); + options.config.add_device_filters(cluster->devices()[0].name()); + + std::unique_ptr session(NewRemote(options)); + ASSERT_TRUE(session != nullptr); + + { + GraphDef graph_copy(graph); + graph::SetDefaultDevice(cluster->devices()[0].name(), &graph_copy); + TF_CHECK_OK(session->Create(graph_copy)); + TF_CHECK_OK(session->Run({}, {}, {}, nullptr)); + TF_CHECK_OK(session->Close()); + } + { + GraphDef graph_copy(graph); + graph::SetDefaultDevice(cluster->devices()[1].name(), &graph_copy); + TF_CHECK_OK(session->Create(graph_copy)); + auto status = session->Run({}, {}, {}, nullptr); + EXPECT_EQ(tensorflow::error::INVALID_ARGUMENT, status.code()); + TF_CHECK_OK(session->Close()); + } +} + +// A = [3 2; -1 0]; x = rand(2, 1); We want to compute the largest +// eigenvalue for A, which is 2.0. 
Iteratively, we do +// repeat x = y / y.norm(); y = A * x; end +// At the end, we expect "lambda" converges to 2.0. +void FindMaxEigen(const string& target) { + Graph graph(OpRegistry::Global()); + + Tensor a_tensor(DT_FLOAT, TensorShape({2, 2})); + // Store rows [3, 2] and [-1, 0] in row major format. + test::FillValues(&a_tensor, {3, 2, -1, 0}); + Node* a = test::graph::Constant(&graph, a_tensor); + + // x is from the feed. + Tensor x_tensor(DT_FLOAT, TensorShape({2, 1})); + test::FillValues(&x_tensor, {0, 0}); + Node* x = test::graph::Constant(&graph, x_tensor); + + // y = A * x + Node* y = test::graph::Matmul(&graph, a, x, false, false); + + // y2 = y.^2 + Node* y2 = test::graph::Unary(&graph, "Square", y); + + // const tensor for reduction + Tensor rdim_tensor(DT_INT32, TensorShape({})); + rdim_tensor.scalar()() = 0; + Node* rdim = test::graph::Constant(&graph, rdim_tensor); + + // y2_sum = sum(y2) + Node* y2_sum = test::graph::Reduce(&graph, "Sum", y2, rdim); + + // y_norm = sqrt(y2_sum) + Node* y_norm = test::graph::Unary(&graph, "Sqrt", y2_sum); + + // y_normalized = y ./ y_norm + Node* y_normalized = test::graph::Binary(&graph, "Div", y, y_norm); + + GraphDef def; + test::graph::ToGraphDef(&graph, &def); + + std::unique_ptr session(NewRemote(Options(target, 1))); + ASSERT_TRUE(session != nullptr); + TF_CHECK_OK(session->Create(def)); + + // Setup feeds and fetches. + float lambda; + Tensor feed_value(DT_FLOAT, TensorShape({2, 1})); + feed_value.matrix()(0, 0) = -3.1415; + feed_value.matrix()(1, 0) = +2.7183; + + for (int i = 0; i < 25; ++i) { + std::vector outputs; + TF_CHECK_OK(session->Run({{x->name(), feed_value}}, + {y->name(), y_normalized->name()}, {}, &outputs)); + const Tensor& y = outputs[0]; + const Tensor& y_normalized = outputs[1]; + // Print out lambda, x, and y. 
+ CHECK_EQ(2, feed_value.NumElements()); + CHECK_EQ(2, y.NumElements()); + lambda = y.flat()(0) / feed_value.flat()(0); + printf("%06d lambda = %8.6f x = [%8.6f %8.6f] y = [%8.6f %8.6f]\n", i, + lambda, feed_value.flat()(0), feed_value.flat()(1), + y.flat()(0), y.flat()(1)); + // Copies y_normalized to *x. + feed_value = y_normalized; + } + EXPECT_NEAR(2.0, lambda, 1e-6); +} + +TEST(FindMaxEigenTest, RemoteDevice) { + std::unique_ptr cluster; + test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster); + FindMaxEigen(cluster->targets()[0]); +} + +void SetDevice(GraphDef* graph, const string& name, const string& dev) { + for (int i = 0; i < graph->node_size(); ++i) { + if (graph->node(i).name() == name) { + graph->mutable_node(i)->set_device(dev); + return; + } + } + LOG(FATAL) << "Name '" << name << "' not found."; +} + +TEST(GrpcSessionTest, MultiDevices) { + std::unique_ptr cluster; + TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster)); + + Graph graph(OpRegistry::Global()); + const int kSize = 1048576; + + // c = a * b = 2 * 3 * kSize + Tensor a_tensor(DT_FLOAT, TensorShape({1, kSize})); + Tensor b_tensor(DT_FLOAT, TensorShape({kSize, 1})); + for (int i = 0; i < kSize; ++i) { + a_tensor.flat()(i) = 2; + b_tensor.flat()(i) = 3; + } + Node* a = test::graph::Constant(&graph, a_tensor); + Node* b = test::graph::Constant(&graph, b_tensor); + Node* c = test::graph::Matmul(&graph, a, b, false, false); + + GraphDef def; + test::graph::ToGraphDef(&graph, &def); + + // In this test, we force each node (a, b, c) on every possible device. + // We test all possible cases. 
+ for (const auto& a_dev : cluster->devices()) { + for (const auto& b_dev : cluster->devices()) { + for (const auto& c_dev : cluster->devices()) { + LOG(INFO) << "a: " << a_dev.name() << " b: " << b_dev.name() + << " c: " << c_dev.name(); + + SetDevice(&def, a->name(), a_dev.name()); + SetDevice(&def, b->name(), b_dev.name()); + SetDevice(&def, c->name(), c_dev.name()); + + std::unique_ptr session( + NewRemote(Options(cluster->targets()[0], 1000))); + ASSERT_TRUE(session != nullptr); + TF_CHECK_OK(session->Create(def)); + { + std::vector outputs; + TF_CHECK_OK(session->Run({}, {c->name()}, {}, &outputs)); + ASSERT_EQ(1, outputs.size()); + IsSingleFloatValue(outputs[0], 6.0 * kSize); + } + TF_CHECK_OK(session->Close()); + } + } + } +} + +TEST(GrpcSessionTest, MultiDevices_String) { + std::unique_ptr cluster; + TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 1), 2, &cluster)); + std::unique_ptr session( + NewRemote(Options(cluster->targets()[0], 1000))); + ASSERT_TRUE(session != nullptr); + + // b = a + Graph graph(OpRegistry::Global()); + Tensor a_tensor(DT_STRING, TensorShape({2, 2})); + for (int i = 0; i < 4; ++i) { + a_tensor.flat()(i) = "hello, world"; + } + Node* a = test::graph::Constant(&graph, a_tensor); + Node* b = test::graph::Identity(&graph, a); + + GraphDef def; + test::graph::ToGraphDef(&graph, &def); + + // In this test, we force each node (a, b) on every possible device. + // We test all possible cases. 
+ for (const auto& a_dev : cluster->devices()) { + for (const auto& b_dev : cluster->devices()) { + LOG(INFO) << "a: " << a_dev.name() << " b: " << b_dev.name(); + SetDevice(&def, a->name(), a_dev.name()); + SetDevice(&def, b->name(), b_dev.name()); + + TF_CHECK_OK(session->Create(def)); + { + std::vector outputs; + Status s = session->Run({}, {b->name()}, {}, &outputs); + if (s.ok()) { + ASSERT_EQ(1, outputs.size()); + ASSERT_EQ(outputs[0].dtype(), DT_STRING); + ASSERT_EQ(outputs[0].NumElements(), 4); + for (int i = 0; i < outputs[0].NumElements(); ++i) { + EXPECT_EQ(outputs[0].flat()(i), "hello, world"); + } + } else { + LOG(ERROR) << "Error: " << s; + ASSERT_TRUE((a_dev.device_type() == DEVICE_GPU) || + (b_dev.device_type() == DEVICE_GPU)); + ASSERT_FALSE(s.ok()); + } + } + TF_CHECK_OK(session->Close()); + } + } +} + +TEST(GrpcSessionTest, SendRecv_Node_Naming) { + std::unique_ptr cluster; + TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 3, &cluster)); + std::unique_ptr session( + NewRemote(Options(cluster->targets()[0], 1))); + ASSERT_TRUE(session != nullptr); + + // This test case needs at least 3 devices. + CHECK_GE(cluster->devices().size(), 3); + const DeviceAttributes& src = cluster->devices()[0]; + const DeviceAttributes& dst0 = cluster->devices()[1]; + const DeviceAttributes& dst1 = cluster->devices()[2]; + LOG(INFO) << "src = " << src.name() << " dst0 = " << dst0.name() + << " dst1 = " << dst1.name(); + + // Within the same session, we compute two subgraphs: + // 1) a on 'src' sends to b on 'dst0'; + // 2) a on 'src' sends to c on 'dst1'. + Graph graph(OpRegistry::Global()); + Tensor a_tensor(DT_FLOAT, TensorShape({1, 1})); + a_tensor.flat()(0) = 100; + Node* a = test::graph::Constant(&graph, a_tensor); + Node* b = test::graph::Identity(&graph, a); + Node* c = test::graph::Identity(&graph, a); + + GraphDef def; + test::graph::ToGraphDef(&graph, &def); + + // The base graph have a, b, c, assigned to devices explicitly. 
+ SetDevice(&def, a->name(), src.name()); + SetDevice(&def, b->name(), dst0.name()); + SetDevice(&def, c->name(), dst1.name()); + TF_CHECK_OK(session->Create(def)); + + // Run subgraph a -> b, and fetch b. + { + std::vector outputs; + TF_CHECK_OK(session->Run({}, {b->name()}, {}, &outputs)); + ASSERT_EQ(1, outputs.size()); + IsSingleFloatValue(outputs[0], 100); + } + + // Run subgraph a -> c, and fetch c. + { + std::vector outputs; + TF_CHECK_OK(session->Run({}, {c->name()}, {}, &outputs)); + ASSERT_EQ(1, outputs.size()); + IsSingleFloatValue(outputs[0], 100); + } + + TF_CHECK_OK(session->Close()); +} + +TEST(GrpcSessionTest, Error) { + std::unique_ptr cluster; + TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster)); + const string& master = cluster->targets()[0]; + const string& dev_a = cluster->devices()[0].name(); + const string& dev_b = cluster->devices()[1].name(); + LOG(INFO) << "master " << master << "dev_a " << dev_a << "dev_b " << dev_b; + GraphDef gdef; + std::vector fetches; + { + Graph g(OpRegistry::Global()); + + // a2 = a + error(a) + // + // Subgraph for "a" fails. The master will cancel the subgraph for + // "b" and then returns the Session::Run. + auto a = test::graph::Constant(&g, Tensor()); + a->set_assigned_device_name(dev_a); + auto a_err = test::graph::Error(&g, a, "fantasia!"); + a_err->set_assigned_device_name(dev_a); + auto a2 = test::graph::Add(&g, a, a_err); + a2->set_assigned_device_name(dev_a); + fetches.push_back(a2->name()); + + // b2 = b + delay(b) + // + // Subgraph for "b" sleeps at the node "b_delay". When the sleep + // finishes, the subgraph "b" will continue execution till it + // notices that it is cancelled. Meanwhile, subgraph's executor + // and its related state (registered ops) should still be alive. 
+ auto b = test::graph::Constant(&g, Tensor()); + b->set_assigned_device_name(dev_b); + auto b_delay = test::graph::Delay(&g, b, Microseconds(1000000)); + b_delay->set_assigned_device_name(dev_b); + auto b2 = test::graph::Add(&g, b, b_delay); + b2->set_assigned_device_name(dev_b); + fetches.push_back(b2->name()); + test::graph::ToGraphDef(&g, &gdef); + } + std::unique_ptr session(NewRemote(Options(master, 1))); + ASSERT_TRUE(session != nullptr); + + TF_CHECK_OK(session->Create(gdef)); + { + Status status = session->Run({}, fetches, {}, nullptr); + EXPECT_FALSE(status.ok()); + EXPECT_NE(status.ToString().find("fantasia!"), string::npos); + } + // session->Close() shall clean up all states related to the session-> + // E.g., deregisters subgraph with workers, etc. + TF_CHECK_OK(session->Close()); + + // Sleep a bit so that most of asynchronous works finishes before + // the test process finishes. + Env::Default()->SleepForMicroseconds(2000000); +} + +TEST(SessionTest, SharedVar) { + std::unique_ptr cluster; + TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 1, &cluster)); + const string master = cluster->targets()[0]; + CHECK_EQ(cluster->devices().size(), 1); + + GraphDef gdef; + string init_name; + string inc_name; + string get_name; + { + Graph g(OpRegistry::Global()); + Tensor one(DT_FLOAT, TensorShape({})); + one.scalar()() = 1.0; + Node* var = test::graph::Var(&g, DT_FLOAT, one.shape()); + Node* init = test::graph::Assign(&g, var, test::graph::Constant(&g, one)); + init_name = init->name(); + Node* update = test::graph::Assign( + &g, var, test::graph::Add(&g, var, test::graph::Constant(&g, one))); + inc_name = update->name(); + get_name = var->name(); + test::graph::ToGraphDef(&g, &gdef); + } + + // Init a variable + { + Session* sess = NewRemote(Options(master, 1)); + TF_CHECK_OK(sess->Create(gdef)); + std::vector> inp; + TF_CHECK_OK(sess->Run(inp, {}, {init_name}, nullptr)); + TF_CHECK_OK(sess->Close()); + delete sess; + } + + for (int rep = 1; 
rep < 10; ++rep) { + // Update a variable + { + Session* sess = NewRemote(Options(master, 1)); + TF_CHECK_OK(sess->Create(gdef)); + std::vector> inp; + TF_CHECK_OK(sess->Run(inp, {}, {inc_name}, nullptr)); + TF_CHECK_OK(sess->Close()); + delete sess; + } + + // Gets the variable's value. + { + Session* sess = NewRemote(Options(master, 1)); + TF_CHECK_OK(sess->Create(gdef)); + std::vector> inp; + std::vector ret; + TF_CHECK_OK(sess->Run(inp, {get_name}, {}, &ret)); + ASSERT_EQ(ret.size(), 1); + EXPECT_EQ(ret[0].scalar()(), 1.0 * (1 + rep)); + TF_CHECK_OK(sess->Close()); + delete sess; + } + } +} + +void CreateInvalidGraph(const string& graph_def_ascii, + const string& error_substring) { + GraphDef graph; + CHECK(protobuf::TextFormat::ParseFromString(graph_def_ascii, &graph)); + + std::unique_ptr cluster; + TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster)); + + std::unique_ptr session( + NewRemote(Options(cluster->targets()[0], 1))); + Status s = session->Create(graph); + + ASSERT_FALSE(s.ok()); + EXPECT_NE(s.error_message().find(error_substring), string::npos); +} + +TEST(SessionTest, InvalidOpName) { + CreateInvalidGraph(R"( + node { + name: 'a:b' op: 'Const' + attr { key: 'dtype' value { type: DT_FLOAT } } + attr { key: 'value' value { + tensor { dtype: DT_FLOAT tensor_shape { dim [{size:1}, {size:1}] } + float_val: [100] } + } } + } + )", + "Illegal op name"); + + CreateInvalidGraph(R"( + node { + name: 'a:0' op: 'Const' + attr { key: 'dtype' value { type: DT_FLOAT } } + attr { key: 'value' value { + tensor { dtype: DT_FLOAT tensor_shape { dim [{size:1}, {size:1}] } + float_val: [100] } + } } + } + )", + "Illegal op name"); + + CreateInvalidGraph(R"( + node { + name: '_a' op: 'Const' + attr { key: 'dtype' value { type: DT_FLOAT } } + attr { key: 'value' value { + tensor { dtype: DT_FLOAT tensor_shape { dim [{size:1}, {size:1}] } + float_val: [100] } + } } + } + )", + "Illegal op name"); +} + +TEST(SessionTest, InvalidOpInputName) { + 
CreateInvalidGraph(R"( + node { + name: 'a' op: 'const' + attr { key: 'dtype' value { type: DT_FLOAT } } + attr { key: 'value' value { + tensor { dtype: DT_FLOAT tensor_shape { dim [{size:1}, {size:1}] } + float_val: [100] } + } } + } + node { + name:'b' op:'MatMul' input:'a:first' input:'a' + attr { key: 'T' value { type: DT_FLOAT } } + attr { key: 'transpose_a' value { b: false } } + attr { key: 'transpose_b' value { b: false } } + attr { key: '_kernel' value { s: 'eigen' } } + } + )", + "Illegal op input name"); + + CreateInvalidGraph(R"( + node { + name: 'a' op: 'const' + attr { key: 'dtype' value { type: DT_FLOAT } } + attr { key: 'value' value { + tensor { dtype: DT_FLOAT tensor_shape { dim [{size:1}, {size:1}] } + float_val: [100] } + } } + } + node { + name:'b' op:'MatMul' input:'_a' input:'a' + attr { key: 'T' value { type: DT_FLOAT } } + attr { key: 'transpose_a' value { b: false } } + attr { key: 'transpose_b' value { b: false } } + attr { key: '_kernel' value { s: 'eigen' } } + } + )", + "Illegal op input name"); + + CreateInvalidGraph(R"( + node { + name: 'a' op: 'const' + attr { key: 'dtype' value { type: DT_FLOAT } } + attr { key: 'value' value { + tensor { dtype: DT_FLOAT tensor_shape { dim [{size:1}, {size:1}] } + float_val: [100] } + } } + } + node { + name:'b' op:'MatMul' input:'_a:0' input:'a' + attr { key: 'T' value { type: DT_FLOAT } } + attr { key: 'transpose_a' value { b: false } } + attr { key: 'transpose_b' value { b: false } } + attr { key: '_kernel' value { s: 'eigen' } } + } + )", + "Illegal op input name"); + + CreateInvalidGraph(R"( + node { + name: 'a' op: 'const' + attr { key: 'dtype' value { type: DT_FLOAT } } + attr { key: 'value' value { + tensor { dtype: DT_FLOAT tensor_shape { dim [{size:1}, {size:1}] } + float_val: [100] } + } } + } + node { + name:'b' op:'MatMul' input:'a:01' input:'a' + attr { key: 'T' value { type: DT_FLOAT } } + attr { key: 'transpose_a' value { b: false } } + attr { key: 'transpose_b' value { b: false } } 
+ attr { key: '_kernel' value { s: 'eigen' } } + } + )", + "Illegal op input name"); +} + +TEST(SessionTest, ExtendValidation) { + GraphDef graph; + bool success = protobuf::TextFormat::ParseFromString(R"( + node { + name: 'a' op: 'Const' + attr { key: 'dtype' value { type: DT_FLOAT } } + attr { key: 'value' value { + tensor { dtype: DT_FLOAT tensor_shape { dim [{size:1}, {size:1}] } + float_val: [100] } + } } + } + )", + &graph); + // NOTE(mrry): CHECK not done inline to avoid a compilation error in + // open-source (due to a multi-line string in a macro argument). + ASSERT_TRUE(success); + + std::unique_ptr cluster; + TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster)); + + std::unique_ptr session( + NewRemote(Options(cluster->targets()[0], 1))); + TF_CHECK_OK(session->Create(graph)); + + // 1. Fail with an unknown input name. + GraphDef extension; + success = protobuf::TextFormat::ParseFromString(R"( + node { + name:'b' op:'MatMul' input:'a:first' input:'a' + attr { key: 'T' value { type: DT_FLOAT } } + attr { key: 'transpose_a' value { b: false } } + attr { key: 'transpose_b' value { b: false } } + attr { key: '_kernel' value { s: 'eigen' } } + } + )", + &extension); + ASSERT_TRUE(success); + + Status s = session->Extend(extension); + ASSERT_FALSE(s.ok()); + EXPECT_NE(s.error_message().find("Illegal op input name"), string::npos); + + // 2. Succeed with a valid node. + success = protobuf::TextFormat::ParseFromString(R"( + node { + name:'b' op:'MatMul' input:'a' input:'a' + attr { key: 'T' value { type: DT_FLOAT } } + attr { key: 'transpose_a' value { b: false } } + attr { key: 'transpose_b' value { b: false } } + attr { key: '_kernel' value { s: 'eigen' } } + } + )", + &extension); + ASSERT_TRUE(success); + TF_CHECK_OK(session->Extend(extension)); + + // 2. Fail with a duplicate node. 
+ success = protobuf::TextFormat::ParseFromString(R"( + node { + name:'b' op:'MatMul' input:'a' input:'a' + attr { key: 'T' value { type: DT_FLOAT } } + attr { key: 'transpose_a' value { b: false } } + attr { key: 'transpose_b' value { b: false } } + attr { key: '_kernel' value { s: 'eigen' } } + } + )", + &extension); + ASSERT_TRUE(success); + s = session->Extend(extension); + ASSERT_FALSE(s.ok()); + EXPECT_NE(s.error_message().find("'b', which was created by a previous call"), + string::npos); +} + +} // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc b/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc new file mode 100644 index 00000000000..51f24fdbf63 --- /dev/null +++ b/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc @@ -0,0 +1,98 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "external/grpc/include/grpc++/grpc++.h" +#include "external/grpc/include/grpc++/security/credentials.h" +#include "external/grpc/include/grpc++/server_builder.h" + +#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/public/session_options.h" +#include "tensorflow/core/util/command_line_flags.h" + +// This binary starts a TensorFlow server (master and worker). +namespace tensorflow { +namespace { + +Status ParseFlagsForTask(int argc, char* argv[], GrpcServerOptions* options) { + string cluster_spec; + const bool parse_result = + ParseFlags(&argc, argv, {Flag("cluster_spec", &cluster_spec), // + Flag("job_name", &options->job_name), // + Flag("task_id", &options->task_index)}); + if (!parse_result) { + return errors::InvalidArgument("Error parsing command-line flags"); + } + + size_t my_num_tasks = 0; + for (const string& job_str : str_util::Split(cluster_spec, ',')) { + // Split each entry in the flag into 3 pieces, separated by "|". + const std::vector job_pieces = str_util::Split(job_str, '|'); + CHECK_EQ(2, job_pieces.size()) << job_str; + const string& job = job_pieces[0]; + // Does a bit more validation of the tasks_per_replica. + const StringPiece spec = job_pieces[1]; + // job_str is of form |. 
+ const std::vector host_ports = str_util::Split(spec, ';'); + size_t num_tasks = host_ports.size(); + if (job == options->job_name) { + my_num_tasks = num_tasks; + } + TF_RETURN_IF_ERROR( + options->channel_spec.AddHostPortsJob(job, host_ports, num_tasks)); + LOG(INFO) << "Peer " << job << " " << num_tasks << " {" + << str_util::Join(host_ports, ", ") << "}"; + } + if (my_num_tasks == 0) { + return errors::InvalidArgument("Job name \"", options->job_name, + "\" does not appear in the cluster spec"); + } + if (options->task_index >= my_num_tasks) { + return errors::InvalidArgument("Task index ", options->task_index, + " is invalid (job \"", options->job_name, + "\" contains ", my_num_tasks, " tasks"); + } + return Status::OK(); +} + +} // namespace +} // namespace tensorflow + +int main(int argc, char* argv[]) { + tensorflow::port::InitMain(argv[0], &argc, &argv); + tensorflow::GrpcServerOptions options; + tensorflow::Status s = tensorflow::ParseFlagsForTask(argc, argv, &options); + if (!s.ok()) { + std::cerr << "ERROR: " << s.error_message() << std::endl; + std::cerr << "Usage: " << argv[0] + << " --cluster_spec=SPEC --job_name=NAME --task_id=ID" + << std::endl; + std::cerr << "Where:" << std::endl; + std::cerr << " SPEC is (,)*" << std::endl; + std::cerr << " JOB is |(;)*" << std::endl; + std::cerr << " NAME is a valid job name ([a-z][0-9a-z]*)" << std::endl; + std::cerr << " HOST is a hostname or IP address" << std::endl; + std::cerr << " PORT is a port number" << std::endl; + return -1; + } + tensorflow::StartTensorFlowServer(options); +} diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server_lib.cc new file mode 100644 index 00000000000..ee5973c83a1 --- /dev/null +++ b/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server_lib.cc @@ -0,0 +1,123 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" + +#include "external/grpc/include/grpc++/grpc++.h" +#include "external/grpc/include/grpc++/security/credentials.h" +#include "external/grpc/include/grpc++/server_builder.h" + +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/distributed_runtime/graph_mgr.h" +#include "tensorflow/core/distributed_runtime/master_env.h" +#include "tensorflow/core/distributed_runtime/master_session.h" +#include "tensorflow/core/distributed_runtime/process_util.h" +#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_master_service.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h" +#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h" +#include "tensorflow/core/distributed_runtime/worker_env.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/public/session_options.h" +#include "tensorflow/core/util/command_line_flags.h" + +// This binary 
starts a TensorFlow server (master and worker) for test purposes. +namespace tensorflow { + +struct GrpcTaskOptions { + // This process belongs to the "job_name". + string job_name; + + // This process is the task-th task within the replica. 0th, 1st, + // 2nd, etc. + int32 task = 0; + + // Specification of peers. + GrpcChannelSpec channel_spec; + + SessionOptions default_session_options; +}; + +Status StartTensorFlowServer(const TaskOptions& task_options) { + thread::ThreadPool* thread_pool = + new thread::ThreadPool(Env::Default(), "server", 1); + thread_pool->Schedule([argc, argv, task_options]() { + // This process provides both the worker service and the master + // service. We let these two services share the same channel cache + // (rpc connections) and cpu devices (used by the master as the + // client device). These client devices require a worker service + // so that remote devices can copy the feeds from the client + // device in the master. + tensorflow::MasterEnv master_env; + string name_prefix = + strings::StrCat("/job:", task_optionss.job_name, "/replica:0", "/task:", + task_options.task); + DeviceFactory::AddDevices(task_options.default_session_options, name_prefix, + &master_env.local_devices); + + // Create the DeviceMgr before initializing the RPC layer, because that + // needs to know how many devices of each kind exist. + WorkerEnv worker_env; + worker_env.device_mgr = new DeviceMgr(master_env.local_devices); + + // Finish setting up Env for Worker service. 
+ string donotcare; + CHECK(DeviceNameUtils::SplitDeviceName(master_env.local_devices[0]->name(), + &worker_env.worker_name, + &donotcare)); + worker_env.env = Env::Default(); + + GrpcChannelCache* channel_cache = + NewGrpcChannelCache(task_options.channel_spec); + string server_address = channel_cache->TranslateTask(name_prefix); + worker_env.worker_cache = NewGrpcWorkerCache(channel_cache); + worker_env.graph_mgr = new GraphMgr(&worker_env); + worker_env.rendezvous_mgr = new RpcRendezvousMgr(&worker_env); + worker_env.compute_pool = ComputePool(task_options.default_session_options); + + // Finish setting up Env for Master service. + master_env.env = Env::Default(); + master_env.ops = OpRegistry::Global(); + master_env.worker_cache = worker_env.worker_cache; + master_env.master_session_factory = internal::NewMasterSession; + + ::grpc::ServerBuilder builder; + builder.AddListeningPort(server_address, + ::grpc::InsecureServerCredentials()); + auto master_service = NewGrpcMasterService(&master_env, &builder); + auto worker_service = NewGrpcWorkerService(&worker_env, &builder); + // Finally assemble the server. + auto server_ = builder.BuildAndStart(); + + std::unique_ptr master_thread(Env::Default()->StartThread( + ThreadOptions(), "master_service_thread", + [master_service]() { master_service->HandleRPCsLoop(); })); + + std::unique_ptr worker_thread(Env::Default()->StartThread( + ThreadOptions(), "worker_service_thread", + [worker_service]() { worker_service->HandleRPCsLoop(); })); + }); + + // The ThreadPool destructor waits until all work is done (i.e. forever). + delete thread_pool; + return Status::OK(); +} + +} // end namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_testlib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_testlib.cc new file mode 100644 index 00000000000..85b3dae56f3 --- /dev/null +++ b/tensorflow/core/distributed_runtime/rpc/grpc_testlib.cc @@ -0,0 +1,84 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/distributed_runtime/rpc/grpc_testlib.h" + +#include "tensorflow/core/distributed_runtime/rpc/grpc_session.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { +namespace test { + +Status TestCluster::MakeTestCluster(const SessionOptions& options, int n, + std::unique_ptr* out_cluster) { + CHECK_GE(n, 1); + std::unique_ptr ret(new TestCluster); + + ret->targets_.resize(n); + + std::vector port(n); + for (int i = 0; i < n; ++i) { + port[i] = testing::PickUnusedPortOrDie(); + ret->targets_[i] = strings::StrCat("localhost:", port[i]); + } + + const string tf_jobs = strings::StrCat("--tf_jobs=localhost|", + str_util::Join(ret->targets_, ";")); + + int num_cpus = 1; + int num_gpus = 0; + auto iter = options.config.device_count().find("CPU"); + if (iter != options.config.device_count().end()) { + num_cpus = iter->second; + } + iter = options.config.device_count().find("GPU"); + if (iter != options.config.device_count().end()) { + num_gpus = iter->second; + } + + for (int i = 0; i < n; ++i) { + const std::vector argv( + {strings::StrCat(testing::TensorFlowSrcRoot(), + "/core/distributed_runtime/rpc/grpc_testlib_server"), + /* see grpc_testlib_server.cc for flags */ + tf_jobs, "--tf_job=localhost", strings::StrCat("--tf_task=", i), + 
strings::StrCat("--num_cpus=", num_cpus), + strings::StrCat("--num_gpus=", num_gpus)}); + ret->subprocesses_.emplace_back(testing::CreateSubProcess(argv)); + bool success = ret->subprocesses_[i]->Start(); + if (!success) { + return errors::Internal("Could not start subprocess"); + } + } + + SessionOptions options_copy(options); + options_copy.target = strings::StrCat("grpc://", ret->targets_[0]); + std::unique_ptr session(new GrpcSession(options_copy)); + std::vector device_attributes; + ret->devices_ = session->ListDevices(); + + *out_cluster = std::move(ret); + return Status::OK(); +} + +TestCluster::~TestCluster() { + for (auto& subprocess : subprocesses_) { + subprocess->Kill(9); + } +} + +} // end namespace test +} // end namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_testlib.h b/tensorflow/core/distributed_runtime/rpc/grpc_testlib.h new file mode 100644 index 00000000000..7460c1c9b44 --- /dev/null +++ b/tensorflow/core/distributed_runtime/rpc/grpc_testlib.h @@ -0,0 +1,73 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_TESTLIB_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_TESTLIB_H_ + +#include +#include +#include + +#include "tensorflow/core/framework/device_attributes.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/public/session_options.h" + +namespace tensorflow { + +class Device; + +namespace test { + +// Provides a handle to a set of TensorFlow servers (masters and +// workers) for testing purposes. +// +// This class currently runs the servers in separate processes; the +// lifetime of this object is coterminous with the lifetimes of those +// processes. +class TestCluster { + public: + // Creates a new test cluster based on the given `options` (which + // configure the number of devices of each type) and a count of + // processes `n`. On success, the test cluster is stored in + // *out_cluster, and this function returns OK. Otherwise an error is + // returned. + static Status MakeTestCluster(const SessionOptions& options, int n, + std::unique_ptr* out_cluster); + ~TestCluster(); + + // Returns a vector of string ":" pairs that may be + // used as targets to construct a GrpcSession. + const std::vector& targets() const { return targets_; } + + // Returns a vector of devices available in this test cluster. 
+ const std::vector& devices() const { return devices_; } + + private: + TestCluster() = default; + + std::vector> subprocesses_; + std::vector targets_; + std::vector devices_; + + TF_DISALLOW_COPY_AND_ASSIGN(TestCluster); +}; + +} // end namespace test +} // end namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_TESTLIB_H_ diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_testlib_ops.cc b/tensorflow/core/distributed_runtime/rpc/grpc_testlib_ops.cc new file mode 100644 index 00000000000..e2518f8fced --- /dev/null +++ b/tensorflow/core/distributed_runtime/rpc/grpc_testlib_ops.cc @@ -0,0 +1,91 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/platform/macros.h" + +namespace tensorflow { +namespace test { + +// ErrorOp::Compute returns an error. 
+REGISTER_OP("Error") + .Input("in: T") + .Output("out: T") + .Attr("T: type") + .Attr("message: string"); +class ErrorOp : public OpKernel { + public: + explicit ErrorOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("message", &errmsg_)); + } + + void Compute(OpKernelContext* ctx) override { + ctx->SetStatus(errors::Internal(errmsg_)); + } + + private: + string errmsg_; +}; +REGISTER_KERNEL_BUILDER(Name("Error").Device(DEVICE_CPU), ErrorOp); + +REGISTER_OP("InvalidRefType") + .Output("out: Ref(TIn)") + .Attr("TIn: type") + .Attr("TOut: type"); +class InvalidRefType : public OpKernel { + public: + explicit InvalidRefType(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("TOut", &dtout_)); + output_ = Tensor(dtout_, TensorShape({})); + } + + void Compute(OpKernelContext* ctx) override { + ctx->set_output_ref(0, &mu_, &output_); + } + + private: + DataType dtout_; + mutex mu_; + Tensor output_; +}; +REGISTER_KERNEL_BUILDER(Name("InvalidRefType").Device(DEVICE_CPU), + InvalidRefType); + +// DelayOp::AsyncCompute sleeps for "micros"-econd and then returns +// its input. 
+REGISTER_OP("Delay") + .Input("in: T") + .Output("out: T") + .Attr("T: type") + .Attr("micros: int"); +class DelayOp : public AsyncOpKernel { + public: + explicit DelayOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("micros", µs_)); + } + + void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { + ctx->set_output(0, ctx->input(0)); + ctx->env()->SchedClosureAfter(micros_, done); + } + + private: + int64 micros_; +}; +REGISTER_KERNEL_BUILDER(Name("Delay").Device(DEVICE_CPU), DelayOp); + +} // namespace test +} // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_testlib_server.cc b/tensorflow/core/distributed_runtime/rpc/grpc_testlib_server.cc new file mode 100644 index 00000000000..62c88daa174 --- /dev/null +++ b/tensorflow/core/distributed_runtime/rpc/grpc_testlib_server.cc @@ -0,0 +1,92 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "external/grpc/include/grpc++/grpc++.h" +#include "external/grpc/include/grpc++/security/credentials.h" +#include "external/grpc/include/grpc++/server_builder.h" + +#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/public/session_options.h" +#include "tensorflow/core/util/command_line_flags.h" + +// This binary starts a TensorFlow server (master and worker) for test purposes. +namespace tensorflow { +namespace { + +Status ParseFlagsForTask(int argc, char* argv[], GrpcServerOptions* options) { + string job_spec; + int num_cpus = 1; + int num_gpus = 0; + const bool parse_result = + ParseFlags(&argc, argv, {Flag("tf_jobs", &job_spec), // + Flag("tf_job", &options->job_name), // + Flag("tf_task", &options->task_index), // + Flag("num_cpus", &num_cpus), // + Flag("num_gpus", &num_gpus)}); + if (!parse_result) { + return errors::InvalidArgument("Error parsing command-line flags"); + } + + uint32 my_tasks_per_replica = 0; + for (const string& job_str : str_util::Split(job_spec, ',')) { + // Split each entry in the flag into 3 pieces, separated by "|". + const std::vector job_pieces = str_util::Split(job_str, '|'); + CHECK_EQ(2, job_pieces.size()) << job_str; + const string& job = job_pieces[0]; + // Does a bit more validation of the tasks_per_replica. + const StringPiece spec = job_pieces[1]; + // job_str is of form |. 
+ const std::vector host_ports = str_util::Split(spec, ';'); + uint32 tasks_per_replica = host_ports.size(); + if (job == options->job_name) { + my_tasks_per_replica = tasks_per_replica; + } + TF_RETURN_IF_ERROR(options->channel_spec.AddHostPortsJob( + job, host_ports, tasks_per_replica)); + LOG(INFO) << "Peer " << job << " " << tasks_per_replica << " {" + << str_util::Join(host_ports, ", ") << "}"; + } + if (my_tasks_per_replica == 0) { + return errors::InvalidArgument("Invalid job specification"); + } + + (*options->default_session_options.config.mutable_device_count())["CPU"] = + num_cpus; + (*options->default_session_options.config.mutable_device_count())["GPU"] = + num_gpus; + return Status::OK(); +} + +} // namespace +} // namespace tensorflow + +int main(int argc, char* argv[]) { + tensorflow::port::InitMain(argv[0], &argc, &argv); + tensorflow::GrpcServerOptions options; + tensorflow::Status s = tensorflow::ParseFlagsForTask(argc, argv, &options); + if (!s.ok()) { + LOG(ERROR) << "Could not parse flags: " << s.error_message(); + return -1; + } + tensorflow::StartTensorFlowServer(options); + // NOTE(mrry): Unreachable code. + return 0; +} diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_util.h b/tensorflow/core/distributed_runtime/rpc/grpc_util.h new file mode 100644 index 00000000000..fc4b699e2a2 --- /dev/null +++ b/tensorflow/core/distributed_runtime/rpc/grpc_util.h @@ -0,0 +1,48 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_UTIL_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_UTIL_H_ + +#include + +#include "external/grpc/include/grpc++/grpc++.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +inline Status FromGrpcStatus(const ::grpc::Status& s) { + if (s.ok()) { + return Status::OK(); + } else { + return Status(static_cast(s.error_code()), + s.error_message()); + } +} + +inline ::grpc::Status ToGrpcStatus(const ::tensorflow::Status& s) { + if (s.ok()) { + return ::grpc::Status::OK; + } else { + return ::grpc::Status(static_cast<::grpc::StatusCode>(s.code()), + s.error_message()); + } +} + +typedef std::shared_ptr<::grpc::Channel> SharedGrpcChannelPtr; + +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_UTIL_H_ diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc new file mode 100644 index 00000000000..8658b2f31e3 --- /dev/null +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc @@ -0,0 +1,85 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h" + +#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h" +#include "tensorflow/core/distributed_runtime/worker_cache_logger.h" +#include "tensorflow/core/distributed_runtime/worker_cache_partial.h" +#include "tensorflow/core/distributed_runtime/worker_interface.h" +#include "tensorflow/core/platform/env.h" + +namespace tensorflow { + +class GrpcWorkerCache : public WorkerCachePartial { + public: + explicit GrpcWorkerCache(GrpcChannelCache* channel_cache) + : channel_cache_(channel_cache) { + // TODO(mrry): Investigate possible performance improvements by + // replacing this thread with a threadpool. + polling_thread_ = Env::Default()->StartThread( + ThreadOptions(), "grpc_worker_cache", [this]() { + void* tag; + bool ok; + while (completion_queue_.Next(&tag, &ok)) { + GrpcClientCQTag* callback_tag = static_cast(tag); + callback_tag->OnCompleted(ok); + delete callback_tag; + } + }); + } + + // Explicit destructor to control destruction order. + ~GrpcWorkerCache() override { + completion_queue_.Shutdown(); + delete polling_thread_; // Blocks until thread exits. 
+ delete channel_cache_; + } + + void ListWorkers(std::vector* workers) override { + channel_cache_->ListWorkers(workers); + } + + WorkerInterface* CreateWorker(const string& target) override { + SharedGrpcChannelPtr channel = channel_cache_->FindWorkerChannel(target); + CHECK(channel) << "Channel was null"; + if (!channel) return nullptr; + WorkerInterface* ret = + NewGrpcRemoteWorker(channel, &completion_queue_, &logger_); + return ret; + } + + void SetLogging(bool v) override { logger_.SetLogging(v); } + + void ClearLogs() override { logger_.ClearLogs(); } + + bool RetrieveLogs(int64 step_id, StepStats* ss) override { + return logger_.RetrieveLogs(step_id, ss); + } + + private: + GrpcChannelCache* channel_cache_; // Owned. + ::grpc::CompletionQueue completion_queue_; + Thread* polling_thread_; // Owned. + WorkerCacheLogger logger_; +}; + +WorkerCacheInterface* NewGrpcWorkerCache(GrpcChannelCache* cc) { + return new GrpcWorkerCache(cc); +} + +} // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h b/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h new file mode 100644 index 00000000000..9332d389223 --- /dev/null +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h @@ -0,0 +1,28 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_CACHE_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_CACHE_H_ + +#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" +#include "tensorflow/core/distributed_runtime/worker_cache.h" + +namespace tensorflow { + +// The returned WorkerCacheInterface object takes the ownership of "cc". +WorkerCacheInterface* NewGrpcWorkerCache(GrpcChannelCache* cc); + +} // namespace tensorflow +#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_CACHE_H_ diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc new file mode 100644 index 00000000000..ed69f4beb9c --- /dev/null +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc @@ -0,0 +1,415 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h" + +#include + +#include "external/grpc/include/grpc++/server_builder.h" + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/common_runtime/dma_helper.h" +#include "tensorflow/core/common_runtime/gpu/gpu_util.h" +#include "tensorflow/core/common_runtime/gpu/process_state.h" +#include "tensorflow/core/common_runtime/gpu_device_context.h" +#include "tensorflow/core/common_runtime/local_device.h" +#include "tensorflow/core/common_runtime/step_stats_collector.h" +#include "tensorflow/core/distributed_runtime/graph_mgr.h" +#include "tensorflow/core/distributed_runtime/process_util.h" +#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h" +#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_call.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" +#include "tensorflow/core/distributed_runtime/worker_cache.h" +#include "tensorflow/core/distributed_runtime/worker_env.h" +#include "tensorflow/core/framework/cancellation.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/tracing.h" +#include "tensorflow/core/protobuf/worker.pb.h" +#include "tensorflow/core/protobuf/worker_service.grpc.pb.h" +#include "tensorflow/core/protobuf/worker_service.pb.h" + +namespace tensorflow { + +namespace { + +static Tensor empty_tensor(DT_FLOAT); + +class GrpcWorkerService : public AsyncServiceInterface { + public: + GrpcWorkerService(WorkerEnv* env, ::grpc::ServerBuilder* builder) + : env_(env), cancellation_manager_(new CancellationManager) { + builder->RegisterService(&worker_service_); + cq_ = 
builder->AddCompletionQueue().release(); + } + + ~GrpcWorkerService() { delete cq_; } + +// This macro creates a new request for the given RPC method name +// (e.g., `ENQUEUE_REQUEST(GetStatus);`), and enqueues it on +// `this->cq_`. +// +// This macro is invoked one or more times for each RPC method to +// ensure that there are sufficient completion queue entries to +// handle incoming requests without blocking. +// +// The implementation of the request handler for each RPC method +// must ensure that it calls ENQUEUE_REQUEST() for that RPC method, +// to keep accepting new requests. +#define ENQUEUE_REQUEST(method) \ + do { \ + Call:: \ + EnqueueRequest(&worker_service_, cq_, \ + &grpc::WorkerService::AsyncService::Request##method, \ + &GrpcWorkerService::method##Handler); \ + } while (0) + + // This method blocks forever handling requests from the completion queue. + void HandleRPCsLoop() { + // TODO(mrry): This may require performance engineering. We can + // add more threads to service the completion queue, and add more + // of various request types if they are short and frequent. + // Currently we allow unbounded numbers of pending calls for each + // method, by re-enqueuing a request before the previous one + // completes, and we may decide to bound some of the request + // types. + ENQUEUE_REQUEST(GetStatus); + ENQUEUE_REQUEST(CleanupAll); + ENQUEUE_REQUEST(RegisterGraph); + ENQUEUE_REQUEST(DeregisterGraph); + + // TODO(mrry): Consider enqueuing more of these request types. + ENQUEUE_REQUEST(RecvTensor); + ENQUEUE_REQUEST(RunGraph); + + ENQUEUE_REQUEST(CleanupGraph); + ENQUEUE_REQUEST(Logging); + ENQUEUE_REQUEST(Tracing); + + void* tag; + bool ok; + while (cq_->Next(&tag, &ok)) { + UntypedCall::Tag* callback_tag = + static_cast::Tag*>(tag); + callback_tag->OnCompleted(this, ok); + delete callback_tag; + } + } + + private: + WorkerEnv* env_; // Not owned. + ::grpc::ServerCompletionQueue* cq_; // Owned. 
+
+  grpc::WorkerService::AsyncService worker_service_;
+
+  mutex mu_;
+  CancellationManager* cancellation_manager_ GUARDED_BY(mu_);
+
+  // The following section contains one request handler method per
+  // RPC. The `FooHandler` method is called (indirectly) by
+  // `HandleRPCsLoop()` when the next Foo RPC is received. Each
+  // `FooHandler` call schedules a closure on `env_->compute_pool`,
+  // and is responsible for requesting the next Foo call by calling
+  // `ENQUEUE_REQUEST(Foo)`.
+
+  template
+  using WorkerCall = Call;
+
+  void GetStatusHandler(WorkerCall* call) {
+    env_->compute_pool->Schedule([this, call]() {
+      DeviceMgr* dm = env_->device_mgr;
+      std::vector devices;
+      dm->ListDeviceAttributes(&devices);
+      call->response.mutable_device_attributes()->Reserve(devices.size());
+      for (size_t i = 0; i < devices.size(); i++) {
+        call->response.add_device_attributes()->Swap(&devices[i]);
+      }
+      call->SendResponse(::grpc::Status::OK);
+    });
+    ENQUEUE_REQUEST(GetStatus);
+  }
+
+  void CleanupAllHandler(
+      WorkerCall* call) {
+    env_->compute_pool->Schedule([this, call]() {
+      std::vector containers;
+      for (const auto& c : call->request.container()) containers.push_back(c);
+      env_->device_mgr->ClearContainers(containers);
+      call->SendResponse(::grpc::Status::OK);
+    });
+    ENQUEUE_REQUEST(CleanupAll);
+  }
+
+  void RegisterGraphHandler(
+      WorkerCall* call) {
+    env_->compute_pool->Schedule([this, call]() {
+      Status s = env_->graph_mgr->Register(
+          call->request.session_handle(), call->request.graph_def(),
+          call->request.graph_options(), call->response.mutable_graph_handle());
+      call->SendResponse(ToGrpcStatus(s));
+    });
+    ENQUEUE_REQUEST(RegisterGraph);
+  }
+
+  void DeregisterGraphHandler(
+      WorkerCall* call) {
+    env_->compute_pool->Schedule([this, call]() {
+      Status s = env_->graph_mgr->Deregister(call->request.graph_handle());
+      call->SendResponse(ToGrpcStatus(s));
+    });
+    ENQUEUE_REQUEST(DeregisterGraph);
+  }
+
+  void RunGraphHandler(WorkerCall* call) {
+    env_->compute_pool->Schedule([this, call]() { DoRunGraph(call); });
+    ENQUEUE_REQUEST(RunGraph);
+  }
+
+  void RecvTensorHandler(
+      WorkerCall* call) {
+    env_->compute_pool->Schedule([this, call]() { DoRecvTensor(call); });
+    ENQUEUE_REQUEST(RecvTensor);
+  }
+
+  void CleanupGraphHandler(
+      WorkerCall* call) {
+    env_->compute_pool->Schedule([this, call]() {
+      const int64 step_id = call->request.step_id();
+      env_->rendezvous_mgr->Cleanup(step_id);
+      call->SendResponse(::grpc::Status::OK);
+    });
+    ENQUEUE_REQUEST(CleanupGraph);
+  }
+
+  void LoggingHandler(WorkerCall* call) {
+    env_->compute_pool->Schedule([this, call]() {
+      Status s = DoLogging(call);
+      call->SendResponse(ToGrpcStatus(s));
+    });
+    ENQUEUE_REQUEST(Logging);
+  }
+
+  void TracingHandler(WorkerCall* call) {
+    SchedClosure([this, call]() {
+      Status s = DoTracing(call);
+      call->SendResponse(ToGrpcStatus(s));
+    });
+    ENQUEUE_REQUEST(Tracing);
+  }
+#undef ENQUEUE_REQUEST
+
+ private:
+  // The following section contains the implementation of RunGraph(),
+  // RecvTensor(), Logging(), and Tracing(), which are the four
+  // non-trivial and potentially long-running RPCs performed by a
+  // TensorFlow worker.
+
+  void AbortStep(int64 step_id) {
+    Rendezvous* rendez = env_->rendezvous_mgr->Find(step_id);
+    SchedNonBlockingClosureAfter(1000000, [rendez, step_id]() {
+      // Delay a bit before aborting the step. This way, the root
+      // cause may return first back to the client instead of this
+      // cancellation generated abort error.
+      rendez->StartAbort(errors::Aborted("Step ", step_id));
+      rendez->Unref();
+    });
+  }
+
+  Status PrepareRunGraph(const RunGraphRequest& req, GraphMgr::NamedTensors* in,
+                         GraphMgr::NamedTensors* out) {
+    if (req.send_size() > 0) {
+      // TODO(zhifengc): Let the caller decide on which device to
+      // allocate the tensor.
+ Device* cpu_dev = nullptr; + TF_RETURN_IF_ERROR(env_->device_mgr->LookupDevice("CPU:0", &cpu_dev)); + AllocatorAttributes alloc_attrs; + Tensor val; + for (const NamedTensor& entry : req.send()) { + TF_RETURN_IF_ERROR( + cpu_dev->MakeTensorFromProto(entry.val(), alloc_attrs, &val)); + in->insert({entry.key(), val}); + } + } + for (const string& key : req.recv_key()) { + out->insert({key, empty_tensor}); + } + return Status::OK(); + } + + void DoRunGraph(WorkerCall* call) { + const int64 step_id = call->request.step_id(); + TRACEPRINTF("RunGraph: %lld", step_id); + GraphMgr::NamedTensors in; + GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors; + Status s = PrepareRunGraph(call->request, &in, out); + if (!s.ok()) { + delete out; + call->SendResponse(ToGrpcStatus(s)); + return; + } + StepStatsCollector* collector = nullptr; + // TODO(mrry): Collect results from a profiler if available. + CancellationManager* cm = new CancellationManager; + call->SetCancelCallback([this, cm, step_id]() { + cm->StartCancel(); + AbortStep(step_id); + }); + CancellationToken token; + { + mutex_lock l(mu_); + token = cancellation_manager_->get_cancellation_token(); + cancellation_manager_->RegisterCallback(token, + [cm]() { cm->StartCancel(); }); + } + env_->graph_mgr->ExecuteAsync( + call->request.graph_handle(), step_id, call->request.exec_opts(), + collector, cm, in, out, [this, call, cm, out, token](Status s) { + call->ClearCancelCallback(); + { + mutex_lock l(mu_); + cancellation_manager_->DeregisterCallback(token); + } + delete cm; + + if (s.ok()) { + for (const auto& p : *out) { + const string& key = p.first; + const Tensor& val = p.second; + auto* recv = call->response.add_recv(); + recv->set_key(key); + // TODO(zhifengc): Deal with gpu -> cpu copy. + TensorProto* proto = recv->mutable_val(); + val.AsProtoField(proto); + } + } + delete out; + call->SendResponse(ToGrpcStatus(s)); + }); + } + + // Helper for RecvTensor. 
Validates "key" and returns the source + // device in "*src_dev". + Status PrepareRecvTensor(const string& key, Device** src_dev) { + // Validate the key. + Rendezvous::ParsedKey parsed; + TF_RETURN_IF_ERROR(Rendezvous::ParseKey(key, &parsed)); + + // Figures out which device the tensor is hosted on. + TF_RETURN_IF_ERROR( + env_->device_mgr->LookupDevice(parsed.src_device, src_dev)); + + // Does the device have the right incarnation number we expect? + if ((*src_dev)->attributes().incarnation() != parsed.src_incarnation) { + return errors::Aborted( + "RecvTensor expects a different device incarnation: ", + parsed.src_incarnation, " vs. ", + (*src_dev)->attributes().incarnation(), + ". Your worker job was probably restarted. Check your " + "worker job for the reason why it was restarted."); + } + + return Status::OK(); + } + + void DoRecvTensor(WorkerCall* call) { + const int64 step_id = call->request.step_id(); + const string& key = call->request.rendezvous_key(); + TRACEPRINTF("RecvTensor: %lld %s", step_id, key.c_str()); + Device* src_dev = nullptr; + Status s = PrepareRecvTensor(key, &src_dev); + if (!s.ok()) { + call->SendResponse(ToGrpcStatus(s)); + return; + } + + // Request the tensor associated with the rendezvous key. Any time + // while waiting for the tensor to be produced, up until the start + // of execution of the callback lambda body below, an RPC + // cancellation should abort the rendezvous. 
+ call->SetCancelCallback([this, step_id]() { AbortStep(step_id); }); + env_->rendezvous_mgr->RecvLocalAsync( + step_id, key, + [this, call, src_dev](const Status& status, + const Rendezvous::Args& send_args, + const Rendezvous::Args& recv_args, + const Tensor& val, const bool is_dead) { + call->ClearCancelCallback(); + Status s = status; + if (s.ok()) { + // DMA can only be used for Tensors that do not fall into + // the following three odd edge cases: 1) a zero-size + // buffer, 2) a dead tensor which has an uninit value, and + // 3) the tensor has the on_host allocation attribute, + // i.e. it's in CPU RAM *independent of its assigned + // device type*. + // const size_t bytes = is_dead ? 0 : val.TotalBytes(); + const bool on_host = send_args.alloc_attrs.on_host(); + const DeviceContext* send_dev_context = send_args.device_context; + call->response.set_is_dead(is_dead); + StatusCallback response_ready = [call](const Status& s) { + // The value is now ready to be returned on the wire. + call->response.set_send_start_micros(Env::Default()->NowMicros()); + call->SendResponse(ToGrpcStatus(s)); + }; + { + // Non-DMA cases. + if (src_dev->tensorflow_gpu_device_info() && (!on_host)) { + CHECK(send_dev_context) + << "send dev name: " << src_dev->name() + << " gpu_info: " << src_dev->tensorflow_gpu_device_info(); + // "val" is on a GPU. Uses GPUUtil to fill the response proto. + GPUUtil::SetProtoFromGPU(val, src_dev, send_dev_context, + call->response.mutable_tensor(), + is_dead, response_ready); + } else { + // "val" is in CPU memory. + TensorProto* proto = call->response.mutable_tensor(); + val.AsProtoTensorContent(proto); + response_ready(Status::OK()); + } + } + } else { + // !s.ok() + call->SendResponse(ToGrpcStatus(s)); + } + }); + } + + Status DoLogging(WorkerCall* call) { + // TODO(mrry): Platform-specific tracing support. + return errors::Unimplemented("Logging"); + } + + Status DoTracing(WorkerCall* call) { + // TODO(mrry): Platform-specific tracing support. 
+ return errors::Unimplemented("Tracing"); + } + + TF_DISALLOW_COPY_AND_ASSIGN(GrpcWorkerService); +}; + +} // namespace + +AsyncServiceInterface* NewGrpcWorkerService(WorkerEnv* env, + ::grpc::ServerBuilder* builder) { + return new GrpcWorkerService(env, builder); +} + +} // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h new file mode 100644 index 00000000000..4b46aed835b --- /dev/null +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h @@ -0,0 +1,34 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_H_ + +namespace grpc { +class ServerBuilder; +} // namespace grpc + +namespace tensorflow { + +class AsyncServiceInterface; +class WorkerEnv; + +// Returns an implementation of WorkerService rpc service. 
+AsyncServiceInterface* NewGrpcWorkerService(WorkerEnv* env, + ::grpc::ServerBuilder* builder); + +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_H_ diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc new file mode 100644 index 00000000000..6b69e37e159 --- /dev/null +++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc @@ -0,0 +1,196 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h" + +#include + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/common_runtime/dma_helper.h" +#include "tensorflow/core/distributed_runtime/process_util.h" +#include "tensorflow/core/distributed_runtime/worker_cache.h" +#include "tensorflow/core/distributed_runtime/worker_interface.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +namespace { + +class RpcRemoteRendezvous : public BaseRemoteRendezvous { + public: + RpcRemoteRendezvous(const WorkerEnv* env, int64 step_id) + : BaseRemoteRendezvous(env, step_id, false) {} + + protected: + void RecvFromRemoteAsync(const string& key, + const Rendezvous::ParsedKey& parsed, + const Rendezvous::Args& args, + DoneCallback done) override; + + private: + ~RpcRemoteRendezvous() override {} + + TF_DISALLOW_COPY_AND_ASSIGN(RpcRemoteRendezvous); +}; + +// Used only to retrieve tensors from remote processes. 
+class RpcRecvTensorCall : public BaseRecvTensorCall { + public: + RpcRecvTensorCall(WorkerCacheInterface* wc, WorkerInterface* wi, + int64 step_id, const string& key, + const string& remote_dev, Allocator* allocator, + Device* dst_device) + : wi_(wi), + wc_(wc), + remote_dev_(remote_dev), + allocator_(allocator), + dst_(dst_device) { + req_.set_step_id(step_id); + req_.set_rendezvous_key(key); + } + + ~RpcRecvTensorCall() override { delete wi_; } + + void Start(std::function recv_done) override { + StartRTCall(recv_done); + } + + void StartAbort(const Status& s) override { + { + mutex_lock l(mu_); + status_.Update(s); + } + opts_.StartCancel(); + } + + Status status() const override { + mutex_lock l(mu_); + return status_; + } + + const TensorProto& tensor_proto() const { return resp_.tensor(); } + + const RecvTensorResponse& response() const { return resp_; } + + bool is_dead() const { return resp_.is_dead(); } + + private: + // Start the main RecvTensor call, checking for an async abort. + void StartRTCall(std::function recv_done) { + wi_->RecvTensorAsync(&opts_, &req_, &resp_, + nullptr /* TensorBufAllocator */, + // done callback + [this, recv_done](const Status& s) { + { + mutex_lock l(mu_); + status_.Update(s); + } + recv_done(); + }); + } + + WorkerInterface* wi_; // Owned. + WorkerCacheInterface* wc_; // Not owned. + string remote_dev_; + Allocator* allocator_; + Device* dst_; + CallOptions opts_; + RecvTensorRequest req_; + RecvTensorResponse resp_; + + mutable mutex mu_; + Status status_ GUARDED_BY(mu_); + + TF_DISALLOW_COPY_AND_ASSIGN(RpcRecvTensorCall); +}; + + +void RpcRemoteRendezvous::RecvFromRemoteAsync( + const string& key, const Rendezvous::ParsedKey& parsed, + const Rendezvous::Args& recv_args, DoneCallback done) { + Status s; + + // key.src_device identifies a remote device. 
+ string src_worker; + string src_rel_device; + if (!DeviceNameUtils::SplitDeviceName(parsed.src_device, &src_worker, + &src_rel_device)) { + s = errors::Internal(parsed.src_device, + " is invalid remote source device."); + } + WorkerCacheInterface* worker_cache = env_->worker_cache; + if (s.ok() && worker_cache == nullptr) { + s = errors::Internal("No remote worker cache available."); + } + WorkerInterface* rwi = env_->worker_cache->CreateWorker(src_worker); + if (s.ok() && rwi == nullptr) { + s = errors::Internal("No worker known as ", src_worker); + } + + Device* dst_device; + if (s.ok()) { + s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_device); + } + if (!s.ok()) { + done(s, Args(), recv_args, Tensor{}, false); + return; + } + Allocator* allocator = dst_device->GetAllocator(recv_args.alloc_attrs); + + // Prepare a RecvTensor call that can handle being aborted. + RpcRecvTensorCall* call = + new RpcRecvTensorCall(worker_cache, rwi, step_id_, key, + parsed.src_device, allocator, dst_device); + + // Record "call" in active_ so that it can be aborted cleanly. + RegisterCall(call); + + // Start "call". + call->Start([this, call, parsed, recv_args, done]() { + // Removes "call" from active_. Prevent StartAbort(). + DeregisterCall(call); + // If StartAbort was called prior to DeregisterCall, then the + // current status should be bad. 
+ Status s = call->status(); + Tensor val; + if (s.ok()) { + Device* dst_device; + s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_device); + if (s.ok()) { + s = dst_device->MakeTensorFromProto(call->tensor_proto(), + recv_args.alloc_attrs, &val); + } + } + done(s, Args(), recv_args, val, call->is_dead()); + delete call; + }); +} + +} // namespace + +BaseRemoteRendezvous* RpcRendezvousMgr::Create(int64 step_id, + const WorkerEnv* worker_env) { + return new RpcRemoteRendezvous(worker_env, step_id); +} + + +} // end namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h new file mode 100644 index 00000000000..65b21b425cd --- /dev/null +++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h @@ -0,0 +1,57 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_RPC_RENDEZVOUS_MGR_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_RPC_RENDEZVOUS_MGR_H_ + +#include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h" +#include "tensorflow/core/distributed_runtime/worker_env.h" +#include "tensorflow/core/platform/macros.h" + +namespace tensorflow { + +// RendezvousMgr keeps track of a set of local rendezvous instances. 
+// All tensors sent by this worker are buffered in a RendezvousMgr +// until the tensor is received. Each global unique "step_id" +// corresponds to one local rendezvous instance managed by a +// RendezvousMgr. +// +// E.g., +// Rendezvous* rendez = worker_env->rendezvous_mgr->Find(0x8935); +// fork execution of an graph executor using "rendez" on thread 1; +// fork execution of another graph executor using "rendez" on thread 2; +// ... +// join threads 1 and 2; +// +// In the example above, execution in thread 1 and 2 communicates with +// each other by send/recv operations through the "rend". +// +// Tensors sent and recved through rendezvous managed by this +// RendezvousMgr must have keys generated by Rendezvous::CreateKey. +class RpcRendezvousMgr : public BaseRendezvousMgr { + public: + explicit RpcRendezvousMgr(const WorkerEnv* env) : BaseRendezvousMgr(env) {} + + protected: + BaseRemoteRendezvous* Create(int64 step_id, + const WorkerEnv* worker_env) override; + + private: + TF_DISALLOW_COPY_AND_ASSIGN(RpcRendezvousMgr); +}; + +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_RPC_RENDEZVOUS_MGR_H_ diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc new file mode 100644 index 00000000000..0f855e8f28d --- /dev/null +++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc @@ -0,0 +1,172 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h" + +#include "tensorflow/core/distributed_runtime/process_util.h" +#include "tensorflow/core/framework/control_flow.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +// string -> Tensor +Tensor V(const string& content) { + Tensor tensor(DT_STRING, TensorShape({})); + tensor.scalar()() = content; + return tensor; +} + +// Tensor -> string +string V(const Tensor& tensor) { + CHECK_EQ(tensor.dtype(), DT_STRING); + CHECK(TensorShapeUtils::IsScalar(tensor.shape())); + return tensor.scalar()(); +} + +TEST(RpcRendezvousMgrTest, LocalSendRecv) { + WorkerEnv env; + env.env = Env::Default(); + env.worker_name = "/job:mnist/replica:1/task:2"; + RpcRendezvousMgr rmgr(&env); + const int64 step_id = 123; + const string key = Rendezvous::CreateKey( + "/job:mnist/replica:1/task:2/cpu:0", 7890, + "/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0)); + { + Rendezvous* rendez = rmgr.Find(step_id); + core::ScopedUnref unref(rendez); + Rendezvous::Args args; + TF_ASSERT_OK(rendez->Send(key, args, V("peach"), false)); + } + { + Tensor val(DT_FLOAT); + bool val_dead = false; + TF_ASSERT_OK(rmgr.RecvLocal(step_id, key, &val, &val_dead)); + EXPECT_EQ(V(val), "peach"); + } + rmgr.Cleanup(step_id); +} + +TEST(RpcRendezvousMgrTest, LocalAbort) { + WorkerEnv env; + env.env = Env::Default(); + env.worker_name = "/job:mnist/replica:1/task:2"; + RpcRendezvousMgr rmgr(&env); + const string key = Rendezvous::CreateKey( + "/job:mnist/replica:1/task:2/cpu:0", 7890, + 
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0)); + { // Explicit Abort(). + const int64 step_id = 123; + Rendezvous* rendez = rmgr.Find(step_id); + core::ScopedUnref unref(rendez); + SchedClosure([env, rendez]() { + env.env->SleepForMicroseconds(100 * 1000); + rendez->StartAbort(errors::Aborted("")); + }); + Tensor val(DT_STRING); + bool val_dead = false; + Rendezvous::Args args; + EXPECT_TRUE(errors::IsAborted(rendez->Recv(key, args, &val, &val_dead))); + } + { // Cleanup causes Abort(). + const int64 step_id = 321; + Rendezvous* rendez = rmgr.Find(step_id); + core::ScopedUnref unref(rendez); + SchedClosure([env, &rmgr, step_id]() { + env.env->SleepForMicroseconds(100 * 1000); + rmgr.Cleanup(step_id); + }); + Tensor val(DT_STRING); + bool val_dead = false; + Rendezvous::Args args; + EXPECT_TRUE(errors::IsAborted(rendez->Recv(key, args, &val, &val_dead))); + } +} + +TEST(RpcRendezvousMgrTest, CleanupAll) { + WorkerEnv env; + env.env = Env::Default(); + env.worker_name = "/job:mnist/replica:1/task:2"; + RpcRendezvousMgr rmgr(&env); + const string key = Rendezvous::CreateKey( + "/job:mnist/replica:1/task:2/cpu:0", 7890, + "/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0)); + { + const int64 step_id = 123; + Rendezvous* rendez = rmgr.Find(step_id); + core::ScopedUnref unref(rendez); + Rendezvous::Args args; + TF_ASSERT_OK(rendez->Send(key, args, V("peach"), false)); + rmgr.CleanupAll(); + Tensor val(DT_STRING); + bool val_dead = false; + EXPECT_TRUE(errors::IsAborted(rendez->Recv(key, args, &val, &val_dead))); + } +} + +class DummyDeviceContext : public DeviceContext { + public: + explicit DummyDeviceContext(int stream_id) : stream_id_(stream_id) {} + ~DummyDeviceContext() override {} + int stream_id() const { return stream_id_; } + + private: + const int stream_id_; +}; + +TEST(RpcRendezvousMgrTest, TransferDummyDeviceContext) { + DummyDeviceContext* dc = new DummyDeviceContext(123); + + WorkerEnv env; + env.env = Env::Default(); + 
env.worker_name = "/job:mnist/replica:1/task:2"; + RpcRendezvousMgr rmgr(&env); + const int64 step_id = 123; + const string key = Rendezvous::CreateKey( + "/job:mnist/replica:1/task:2/cpu:0", 7890, + "/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0)); + { + Rendezvous* rendez = rmgr.Find(step_id); + core::ScopedUnref unref(rendez); + Rendezvous::Args args; + args.device_context = dc; + TF_ASSERT_OK(rendez->Send(key, args, V("peach"), false)); + } + { + Notification n; + rmgr.RecvLocalAsync( + step_id, key, [&n](const Status& s, const Rendezvous::Args send_args, + const Rendezvous::Args recv_args, const Tensor& val, + bool is_dead) { + auto send_dev_context = + static_cast(send_args.device_context); + CHECK_EQ(123, send_dev_context->stream_id()); + CHECK_EQ(V(val), "peach"); + n.Notify(); + }); + n.WaitForNotification(); + } + rmgr.Cleanup(step_id); + dc->Unref(); +} + +// NOTE: Remote Send/Recv is better tested in worker_test.cc + +} // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/simple_graph_execution_state.cc b/tensorflow/core/distributed_runtime/simple_graph_execution_state.cc new file mode 100644 index 00000000000..94714f47098 --- /dev/null +++ b/tensorflow/core/distributed_runtime/simple_graph_execution_state.cc @@ -0,0 +1,309 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/distributed_runtime/simple_graph_execution_state.h" + +#include +#include +#include +#include + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/simple_placer.h" +#include "tensorflow/core/distributed_runtime/process_util.h" +#include "tensorflow/core/framework/graph_def_util.h" +#include "tensorflow/core/graph/costmodel.h" +#include "tensorflow/core/graph/dot.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/subgraph.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/worker.pb.h" +#include "tensorflow/core/util/util.h" + +namespace tensorflow { + +string BuildGraphOptions::DebugString() const { + string rv = "Feed endpoints: "; + for (auto& s : feed_endpoints) { + strings::StrAppend(&rv, s, ", "); + } + strings::StrAppend(&rv, "\nFetch endpoints: "); + for (auto& s : fetch_endpoints) { + strings::StrAppend(&rv, s, ", "); + } + strings::StrAppend(&rv, "\nTarget nodes: "); + for (auto& s : target_nodes) { + strings::StrAppend(&rv, s, ", "); + } + return rv; +} + +SimpleGraphExecutionState::SimpleGraphExecutionState( + const OpRegistryInterface* ops, + const SimpleGraphExecutionStateOptions& options) + : ops_(ops), + device_set_(options.device_set), + session_options_(options.session_options), + base_(nullptr), + placed_(nullptr) { + // TODO(mrry): Publish placement visualizations or handle the log + // placement option. 
+} + +SimpleGraphExecutionState::~SimpleGraphExecutionState() { + mutex_lock l(mu_); + delete base_; + delete placed_; +} + +Status SimpleGraphExecutionState::Create(GraphDef* graph_def) { + if (original_graph_def_.node_size() > 0) { + return errors::InvalidArgument( + "Cannot call Create on SimpleGraphExecutionState twice"); + } + + original_graph_def_.Swap(graph_def); + VLOG(2) << "Incoming def: " << original_graph_def_.DebugString(); + return AddDefaultAttrsToGraphDef(&original_graph_def_, *ops_, 0); +} + +Status SimpleGraphExecutionState::Extend( + const GraphDef& extension_def, SimpleGraphExecutionState** out) const { + std::unordered_set new_names; + // 1. Build an index of the new node names. + for (const NodeDef& node : extension_def.node()) { + new_names.insert(node.name()); + } + + // 2. Add the non-duplicates from the old graph to the new graph. + // Return an error if the same node name appears in both the + // old graph and the extension. + GraphDef gdef; + for (const NodeDef& node : original_graph_def_.node()) { + if (new_names.count(node.name()) == 0) { + *gdef.add_node() = node; + } else { + return errors::InvalidArgument(tensorflow::strings::Printf( + "GraphDef argument to Extend includes node '%s', which was created " + "by a previous call to Create or Extend in this session.", + node.name().c_str())); + } + } + + int old_node_size = gdef.node_size(); + gdef.mutable_node()->MergeFrom(extension_def.node()); + TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(&gdef, *ops_, old_node_size)); + + // 3. Add the extension. 
+ SimpleGraphExecutionStateOptions combined_options; + combined_options.device_set = device_set_; + + SimpleGraphExecutionState* new_execution_state = + new SimpleGraphExecutionState(ops_, combined_options); + Status new_execution_state_status = new_execution_state->Create(&gdef); + if (!new_execution_state_status.ok()) { + delete new_execution_state; + return new_execution_state_status; + } + *out = new_execution_state; + + // Ensure that any state created in the precursor is accessible in the + // new graph. + { + mutex_lock l(mu_); + for (const auto& placement : stateful_placements_) { + (*out)->stateful_placements_.insert(placement); + } + } + + // TODO(mrry): This is likely to be used for non-throughput-sensitive + // interactive workloads, but in future we may want to transfer other + // parts of the placement and/or cost model. + return Status::OK(); +} + +Status SimpleGraphExecutionState::InitBaseGraph() { + std::unique_ptr new_base(new Graph(ops_)); + GraphConstructorOptions opts; + TF_RETURN_IF_ERROR( + ConvertGraphDefToGraph(opts, original_graph_def_, new_base.get())); + for (const Node* n : new_base->nodes()) { + VLOG(2) << "Mapping " << n->name() << " to " << n->cost_id(); + node_name_to_cost_id_map_[n->name()] = n->cost_id(); + } + + Status status = PreliminaryPlace(*new_base); + if (!status.ok()) { + node_name_to_cost_id_map_.clear(); + return status; + } + base_ = new_base.release(); + return Status::OK(); +} + +Status SimpleGraphExecutionState::GlobalNodeDefByName(const string& name, + NodeDef* out) { + NodeNameToCostIdMap::const_iterator iter = + node_name_to_cost_id_map_.find(name); + if (iter != node_name_to_cost_id_map_.end()) { + mutex_lock l(mu_); // could use reader lock + const Node* node = placed_->FindNodeId(iter->second); + if (node) { + *out = node->def(); + return Status::OK(); + } + } + return errors::NotFound("Node name: ", name); +} + +Status SimpleGraphExecutionState::PreliminaryPlace(const Graph& base) { + VLOG(1) << 
"PreliminaryPlace"; + Graph* ng = new Graph(ops_); + + CopyGraph(base, ng); + Status status = DoPlace(ng); + if (!status.ok()) { + delete ng; + } else { + delete placed_; + placed_ = ng; + FreezeStatefulNodes(true /*is_prelim*/); + } + return status; +} + +void SimpleGraphExecutionState::FreezeStatefulNodes(bool is_prelim) { + if (is_prelim) { + // During the preliminary placement every stateful Node got placed + // somewhere, and we need to remember where, so it doesn't move. + for (Node* n : placed_->nodes()) { + if (n->op_def().is_stateful()) { + stateful_placements_[n->name()] = n->assigned_device_name(); + } + } + } else { + // During later placements it's possible for new stateful nodes to + // appear. They are noticed while we're pinning the pre-existing + // stateful nodes to their prior positions, and after they've been + // placed this function is entered to record their placements. + for (Node* n : missing_stateful_placements_) { + CHECK(n->op_def().is_stateful()); + stateful_placements_[n->name()] = n->assigned_device_name(); + } + missing_stateful_placements_.clear(); + } +} + +void SimpleGraphExecutionState::PlaceStatefulNodes(Graph* graph) { + for (Node* n : graph->nodes()) { + if (n->op_def().is_stateful()) { + PlaceMap::const_iterator iter = stateful_placements_.find(n->name()); + if (iter == stateful_placements_.end()) { + // NOTE(tucker): I don't understand why this can occur. So far, + // I've only seen it in eval instances, started from a checkpoint. + missing_stateful_placements_.push_back(n); + } else { + n->set_assigned_device_name(iter->second); + } + } + } +} + +Status SimpleGraphExecutionState::DoPlace(Graph* graph) { + Status status; + // TODO(mrry): Port other placement algorithms from whitepaper. + return SimplePlacement(graph); +} + +Status SimpleGraphExecutionState::BuildGraph(const BuildGraphOptions& options, + ClientGraph** out) { + VLOG(1) << "BuildGraph"; + mutex_lock l(mu_); + // Lazily initialize the base graph. 
+ if (base_ == nullptr) { + TF_RETURN_IF_ERROR(InitBaseGraph()); + } + + if (!base_ || !placed_) { + return ::tensorflow::errors::Internal( + "There was a problem building the graph."); + } + + std::unique_ptr cgraph(new ClientGraph(ops_)); + CopyGraph(*placed_, &cgraph->graph); + + // Extract the subset of the graph that needs to be run, adding feed/fetch + // ops as needed. + TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution( + &cgraph->graph, options.feed_endpoints, options.fetch_endpoints, + options.target_nodes, device_set_->client_device()->attributes())); + + // Copy the extracted graph in order to make its node ids dense, + // since the local CostModel used to record its stats is sized by + // the largest node id. + { + std::unique_ptr dense_copy(new ClientGraph(ops_)); + CopyGraph(cgraph->graph, &dense_copy->graph); + cgraph = std::move(dense_copy); + } + + // TODO(vrv): We should check invariants of the graph here. + + *out = cgraph.release(); + + return Status::OK(); +} + +Status SimpleGraphExecutionState::DeviceIsCompatible( + Node* n, const Device* device) const { + if (!n->def().device().empty()) { + DeviceNameUtils::ParsedName pname; + if (!DeviceNameUtils::ParseFullName(n->def().device(), &pname)) { + return AttachDef( + errors::InvalidArgument("Malformed device specification '", + n->def().device(), "'"), + n->def()); + } + std::vector devices; + device_set_->FindMatchingDevices(pname, &devices); + for (auto d : devices) { + if (d == device) { + return Status::OK(); + } + } + + return AttachDef( + errors::InvalidArgument( + "Specified device '", n->def().device(), + "' not compatible with device of ref connection: ", device->name()), + n->def()); + } + return Status::OK(); +} + +Status SimpleGraphExecutionState::SimplePlacement(Graph* graph) { + SimplePlacer placer(graph, device_set_, &node_name_to_cost_id_map_, + session_options_); + // TODO(mrry): Consider making the SimplePlacer cancelable. 
+ return placer.Run(); +} + +} // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/simple_graph_execution_state.h b/tensorflow/core/distributed_runtime/simple_graph_execution_state.h new file mode 100644 index 00000000000..6d065437d87 --- /dev/null +++ b/tensorflow/core/distributed_runtime/simple_graph_execution_state.h @@ -0,0 +1,156 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SIMPLE_GRAPH_EXECUTION_STATE_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SIMPLE_GRAPH_EXECUTION_STATE_H_ + +#include +#include +#include +#include + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/device_set.h" +#include "tensorflow/core/distributed_runtime/build_graph_options.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/step_stats.pb.h" +#include "tensorflow/core/graph/costmodel.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +class SessionOptions; +class StepStats; +class Timeline; + +struct SimpleGraphExecutionStateOptions { + const DeviceSet* device_set = 
nullptr; + const SessionOptions* session_options = nullptr; +}; + +// A ClientGraph is simply a sub-graph of the full graph as induced by +// BuildGraphOptions. +struct ClientGraph { + Graph graph; + explicit ClientGraph(const OpRegistryInterface* ops) : graph(ops) {} + int32 placement_version; +}; + +// SimpleGraphExecutionState is responsible for generating an +// executable ClientGraph from the original GraphDef that specifies +// the complete graph and from BuildGraphOptions which specifies +// input/output nodes. +// +// An executable Graph differs from a GraphDef by being Placed, +// meaning that each Node is assigned to a single Device in the +// available set. +// +// When SimpleGraphExecutionState is first constructed it instantiates +// a full Graph from the provided GraphDef, and places it, using only +// the static device assignments from the GraphDef. Nodes without assignments are +// currently placed in a very naive way. Since stateful Nodes cannot +// be moved after initial placement, it is important that stateful +// Nodes get sensible initial device assignments in the graph +// definition. +// +// Subsequently, SimpleGraphExecutionState generates a ClientGraph on +// demand, which is a sub-graph of the latest placement of the full +// Graph. MasterSession uses such a ClientGraph to execute one or +// more similar client requests. +// +// SimpleGraphExecutionState is thread-safe. + +class SimpleGraphExecutionState { + public: + SimpleGraphExecutionState(const OpRegistryInterface* ops, + const SimpleGraphExecutionStateOptions& options); + + virtual ~SimpleGraphExecutionState(); + + // Initializes the SimpleGraphExecutionState with 'graph_def'. Can only be + // called once on an original SimpleGraphExecutionState. Callee may modify + // 'graph_def'. + Status Create(GraphDef* graph_def); + + // Creates a new SimpleGraphExecutionState representing the + // concatenation of this graph, and the graph defined by + // "extension_def". 
The same name may not be used to define a node + in both this graph and "extension_def". + // + // If successful, returns OK and the caller takes ownership of "*out". + // Otherwise returns an error and does not modify "*out". + // + // NOTE(mrry): This method respects the placement of stateful nodes in + // *this, but currently does not transfer any other placement + // or cost model information to the new graph. + Status Extend(const GraphDef& extension_def, + SimpleGraphExecutionState** out) const; + + // Builds a ClientGraph (a sub-graph of the full graph as induced by + // the Node set specified in "options"). If successful, returns OK + // and the caller takes the ownership of "*out". Otherwise, returns + // an error. + Status BuildGraph(const BuildGraphOptions& options, ClientGraph** out); + + // Returns OK if the named node is found in the placed full graph owned + // by this execution_state, and sets *out to the NodeDef for that node. + // It may not exist if name is of a Node added for a particular subgraph + // execution, e.g. a send, recv or feed node. + Status GlobalNodeDefByName(const string& name, NodeDef* out); + + private: + mutable mutex mu_; + + Status InitBaseGraph() EXCLUSIVE_LOCKS_REQUIRED(mu_); + Status PreliminaryPlace(const Graph& graph) EXCLUSIVE_LOCKS_REQUIRED(mu_); + void FreezeStatefulNodes(bool is_prelim) EXCLUSIVE_LOCKS_REQUIRED(mu_); + void PlaceStatefulNodes(Graph* graph) EXCLUSIVE_LOCKS_REQUIRED(mu_); + Status DoPlace(Graph* graph); + Status SimplePlacement(Graph* graph); + // Return an OK status if "n" can be assigned to "device". + Status DeviceIsCompatible(Node* n, const Device* device) const; + + const OpRegistryInterface* const ops_; // Not owned + GraphDef original_graph_def_; // Immutable after ctor. + const DeviceSet* device_set_; // Not owned + const SessionOptions* session_options_; // Not owned + + // Original graph before we make any placement decisions. 
+ Graph* base_ GUARDED_BY(mu_); + + // Full graph, placed on the complete set of devices, as a whole. + Graph* placed_ GUARDED_BY(mu_); + + // Map of placed stateful nodes, i.e. nodes for which is_stateful() + // is true, such as "params" and "queue" nodes. Once placed these + // nodes can not be moved to a different device. Maps node names to + // device names. + typedef std::unordered_map PlaceMap; + PlaceMap stateful_placements_ GUARDED_BY(mu_); + std::vector missing_stateful_placements_ GUARDED_BY(mu_); + + // Map from name to Node for the full graph in placed_. + NodeNameToCostIdMap node_name_to_cost_id_map_; + + TF_DISALLOW_COPY_AND_ASSIGN(SimpleGraphExecutionState); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SIMPLE_GRAPH_EXECUTION_STATE_H_ diff --git a/tensorflow/core/distributed_runtime/worker_cache.h b/tensorflow/core/distributed_runtime/worker_cache.h new file mode 100644 index 00000000000..5c20636deaf --- /dev/null +++ b/tensorflow/core/distributed_runtime/worker_cache.h @@ -0,0 +1,75 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_CACHE_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_CACHE_H_ + +#include +#include + +#include "tensorflow/core/distributed_runtime/worker_interface.h" // for CallOptions +#include "tensorflow/core/framework/device_attributes.pb.h" // for BusAdjacency +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +typedef std::function StatusCallback; + +class ChannelCache; +class StepStats; +class WorkerInterface; + +class WorkerCacheInterface { + public: + virtual ~WorkerCacheInterface() {} + + // Updates *workers with strings naming the remote worker tasks to + // which open channels have been established. + virtual void ListWorkers(std::vector* workers) = 0; + + // If "target" names a remote task for which an RPC channel exists + // or can be constructed, returns a new WorkerInterface object + // wrapping that channel. Ownership passes to the caller. + // TODO(tucker): rename this to CreateWorker() or something that + // makes it more obvious this is a constructor that transfers + // ownership, not a cache lookup. + virtual WorkerInterface* CreateWorker(const string& target) = 0; + + // Set *ba with the BusAdjacency of the specified remote device + // within its local environment. Returns true if the device bus + // affinity was set, using only locally cached data. Returns false + // if status data for that device was not available. Never blocks. + // TODO(mrry,tucker): Maybe remove. + virtual bool GetDeviceBusNonBlocking(const string& device, + BusAdjacency* ba) = 0; + + // Set *ba with the BusAdjacency of the specified remote device + // within its local environment. Callback gets Status::OK if the + // device bus affinity was set. + // TODO(mrry,tucker): Maybe remove. + virtual void GetDeviceBusAsync(const string& device, BusAdjacency* ba, + StatusCallback done) = 0; + + // Start/stop logging activity. 
+ virtual void SetLogging(bool active) {} + + // Discard any saved log data. + virtual void ClearLogs() {} + + // Return logs for the identified step in *ss. Any returned data will no + // longer be stored. + virtual bool RetrieveLogs(int64 step_id, StepStats* ss) { return false; } +}; +} // namespace tensorflow +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_CACHE_H_ diff --git a/tensorflow/core/distributed_runtime/worker_cache_logger.cc b/tensorflow/core/distributed_runtime/worker_cache_logger.cc new file mode 100644 index 00000000000..bd523ae03a1 --- /dev/null +++ b/tensorflow/core/distributed_runtime/worker_cache_logger.cc @@ -0,0 +1,110 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/distributed_runtime/worker_cache_logger.h" + +#include "tensorflow/core/common_runtime/step_stats_collector.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +namespace { +// Maximum number of step_ids for which RPC logs can be maintained. +// TODO(mrry): Make this configurable if necessary. 
+const int32 kWorkerCacheLoggerLimit = 1 << 10; +} // namespace + +void WorkerCacheLogger::SetLogging(bool v) { + mutex_lock l(count_mu_); + if (v) { + ++want_logging_count_; + } else { + --want_logging_count_; + // If RPCs get cancelled, it may be possible for the count + // to go negative. This should not be a fatal error, since + // logging is non-critical. + if (want_logging_count_ < 0) want_logging_count_ = 0; + } +} + +void WorkerCacheLogger::ClearLogs() { + mutex_lock l(mu_); + ClearLogsWithLock(); +} + +void WorkerCacheLogger::ClearLogsWithLock() { + for (auto& iter : log_map_) { + delete iter.second.collector; + } + log_map_.clear(); +} + +bool WorkerCacheLogger::RetrieveLogs(int64 step_id, StepStats* ss) { + mutex_lock l(mu_); + LogMap::iterator iter = log_map_.find(step_id); + if (iter != log_map_.end()) { + iter->second.collector->Swap(ss); + delete iter->second.collector; + log_map_.erase(iter); + return true; + } + return false; +} + +void WorkerCacheLogger::Save(const string& device, int64 step_id, + NodeExecStats* ns) { + mutex_lock l(mu_); + StepLog* sl = &log_map_[step_id]; + if (!sl->collector) { + sl->collector = new StepStatsCollector(&sl->step_stats); + } + sl->collector->Save(device, ns); + if (log_map_.size() > kWorkerCacheLoggerLimit) { + // Something's gone wrong. Just empty the cache. 
+ ClearLogsWithLock(); + } +} + +void WorkerCacheLogger::RecordRecvTensor(int64 step_id, int64 start_usecs, + int64 end_usecs, + const string& tensor_name, + const string& src_device, + const string& dst_device, + int64 bytes) { + NodeExecStats* ns = new NodeExecStats; + ns->set_node_name("RecvTensor"); + string byte_string = strings::StrCat("[", bytes, "B] "); + if (bytes >= 0.1 * 1048576.0) { + byte_string = strings::Printf("[%.1fMB] ", bytes / 1048576.0); + } + ns->set_timeline_label(strings::StrCat(byte_string, tensor_name, " from ", + src_device, " to ", dst_device)); + ns->set_all_start_micros(start_usecs); + ns->set_op_start_rel_micros(0); + ns->set_op_end_rel_micros(end_usecs - start_usecs); + NodeOutput* no = ns->add_output(); + no->set_slot(0); + // TODO(tucker): Maybe set the dimensions too, but then they'll + // need to be passed in. + no->mutable_tensor_description() + ->mutable_allocation_description() + ->set_requested_bytes(bytes); + Save(dst_device, step_id, ns); +} + +} // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/worker_cache_logger.h b/tensorflow/core/distributed_runtime/worker_cache_logger.h new file mode 100644 index 00000000000..46ba8a33bac --- /dev/null +++ b/tensorflow/core/distributed_runtime/worker_cache_logger.h @@ -0,0 +1,81 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_CACHE_LOGGER_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_CACHE_LOGGER_H_ + +#include +#include + +#include "tensorflow/core/framework/step_stats.pb.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +class StepStatsCollector; + +// WorkerCacheLogger is a thread-safe utility for use by a WorkerCache +// to optionally log some selected RPC activity. A single instance +// should be owned by a WorkerCache, for use by its RemoteWorker +// instances. + +class WorkerCacheLogger { + public: + // Start/Stop logging activity. This function increments/decrements + // a counter so that if two separate steps turn logging on/off, + // logging should be on for the union of the durations of both, + // regardless of relative timing. + void SetLogging(bool v); + + // Discard any saved log data. + void ClearLogs(); + + // Return logs for the identified step in *ss. Any returned data will no + // longer be stored. Returns true iff *ss was modified. + bool RetrieveLogs(int64 step_id, StepStats* ss); + + // Return true if there is any outstanding request for logging on + // the RPC channels. + bool LoggingActive() { + mutex_lock l(count_mu_); + return want_logging_count_ > 0; + } + + // Generates a NodeExecStats record with the given data, and saves for + // later retrieval by RetrieveLogs(). 
+ void RecordRecvTensor(int64 step_id, int64 start_usecs, int64 end_usecs, + const string& tensor_name, const string& src_device, + const string& dst_device, int64 bytes); + + private: + mutex count_mu_; + int32 want_logging_count_ GUARDED_BY(count_mu_); + + struct StepLog { + StepStats step_stats; + StepStatsCollector* collector; + }; + typedef std::unordered_map LogMap; + mutex mu_; + LogMap log_map_ GUARDED_BY(mu_); + + // Records "ns" in log_map_ under the given device and step. + void Save(const string& device, int64 step_id, NodeExecStats* ns); + + void ClearLogsWithLock() EXCLUSIVE_LOCKS_REQUIRED(mu_); +}; +} // namespace tensorflow +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_CACHE_LOGGER_H_ diff --git a/tensorflow/core/distributed_runtime/worker_cache_partial.cc b/tensorflow/core/distributed_runtime/worker_cache_partial.cc new file mode 100644 index 00000000000..62c73b5fd92 --- /dev/null +++ b/tensorflow/core/distributed_runtime/worker_cache_partial.cc @@ -0,0 +1,98 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/distributed_runtime/worker_cache_partial.h" + +#include "tensorflow/core/distributed_runtime/process_util.h" +#include "tensorflow/core/distributed_runtime/worker_interface.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { + +bool WorkerCachePartial::GetDeviceBusNonBlocking(const string& device_name, + BusAdjacency* ba) { + mutex_lock lock(mu_); // could use reader lock + const auto& iter = device_status_cache_.find(device_name); + if (iter != device_status_cache_.end()) { + *ba = iter->second.bus_adjacency(); + return true; + } + return false; +} + +void WorkerCachePartial::GetDeviceBusAsync(const string& device_name, + BusAdjacency* ba, + StatusCallback done) { + if (!GetDeviceBusNonBlocking(device_name, ba)) { + // If cache entry was empty, make one try to fill it by RPC. 
+ SchedClosure([this, &device_name, ba, done]() { + Status s = RefreshDeviceStatus(device_name); + if (s.ok()) { + if (!GetDeviceBusNonBlocking(device_name, ba)) { + mutex_lock lock(mu_); + const auto& iter = device_status_cache_.find(device_name); + if (iter == device_status_cache_.end()) { + s = errors::Unavailable("No known remote device: ", device_name); + } else { + s = errors::Internal("Failed to find bus_adjacency for ", + device_name); + } + } + } + done(s); + }); + return; + } + done(Status::OK()); +} + +Status WorkerCachePartial::RefreshDeviceStatus(const string& device_name) { + string task; + string device; + Status s; + if (!DeviceNameUtils::SplitDeviceName(device_name, &task, &device)) { + s = errors::InvalidArgument("Bad device name to RefreshDeviceStatus: ", + device_name); + } + std::unique_ptr rwi(CreateWorker(task)); + if (s.ok() && !rwi.get()) { + s = errors::Internal("RefreshDeviceStatus, unknown worker task: ", task); + } + + if (s.ok()) { + GetStatusRequest req; + GetStatusResponse resp; + s = rwi->GetStatus(&req, &resp); + if (s.ok()) { + mutex_lock lock(mu_); + for (auto& dev_attr : resp.device_attributes()) { + device_status_cache_[dev_attr.name()] = dev_attr; + } + } + } + return s; +} + +void WorkerCachePartial::FlushStatusCache() { + mutex_lock lock(mu_); + device_status_cache_.clear(); +} + +} // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/worker_cache_partial.h b/tensorflow/core/distributed_runtime/worker_cache_partial.h new file mode 100644 index 00000000000..5d8a56648d6 --- /dev/null +++ b/tensorflow/core/distributed_runtime/worker_cache_partial.h @@ -0,0 +1,56 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_CACHE_PARTIAL_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_CACHE_PARTIAL_H_ + +#include +#include + +#include "tensorflow/core/distributed_runtime/worker_cache.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/worker.pb.h" + +namespace tensorflow { + +// Implements the part of the interface that caches and returns remote +// device status attributes. +class WorkerCachePartial : public WorkerCacheInterface { + public: + bool GetDeviceBusNonBlocking(const string& device, BusAdjacency* ba) override; + + void GetDeviceBusAsync(const string& device, BusAdjacency* ba, + StatusCallback) override; + + ~WorkerCachePartial() override {} + + // Clear all entries from the DeviceStatus cache. + void FlushStatusCache(); + + private: + mutex mu_; + + // Initiate a GetStatusAsync to the remote task named by "task", and + // update the cache with all the DeviceAttributes reported. 
+ Status RefreshDeviceStatus(const string& device_name); + + typedef std::unordered_map StatusMap; + StatusMap device_status_cache_ GUARDED_BY(mu_); +}; + +} // namespace tensorflow +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_CACHE_PARTIAL_H_ diff --git a/tensorflow/core/distributed_runtime/worker_env.h b/tensorflow/core/distributed_runtime/worker_env.h new file mode 100644 index 00000000000..d26462570e9 --- /dev/null +++ b/tensorflow/core/distributed_runtime/worker_env.h @@ -0,0 +1,62 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_ENV_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_ENV_H_ + +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +namespace thread { +class ThreadPool; +} // namespace thread + +class DeviceMgr; +class Env; +class GraphMgr; +class RendezvousMgrInterface; +class WorkerCacheInterface; + +// The worker environment class, which holds a bag of pointers to +// per-worker singletons. +// +// WorkerEnv does not own its member pointers. +struct WorkerEnv { + Env* env = nullptr; + + // The name of the worker. E.g., /job:mnist/replica:1/task:0. + string worker_name; + + // Object from which WorkerInterface instances can be obtained. + WorkerCacheInterface* worker_cache = nullptr; + + // device_mgr manages local devices (cpu and gpu). 
The WorkerService + // is the network interface for managed devices. + DeviceMgr* device_mgr = nullptr; + + // graph_mgr keeps track of registered graphs of this worker. + GraphMgr* graph_mgr = nullptr; + + // A set of rendezvous keyed by step ids. + RendezvousMgrInterface* rendezvous_mgr = nullptr; + + // A pool of threads for scheduling compute work. + thread::ThreadPool* compute_pool = nullptr; +}; + +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_ENV_H_ diff --git a/tensorflow/core/distributed_runtime/worker_interface.h b/tensorflow/core/distributed_runtime/worker_interface.h new file mode 100644 index 00000000000..6e68d300e6d --- /dev/null +++ b/tensorflow/core/distributed_runtime/worker_interface.h @@ -0,0 +1,129 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_INTERFACE_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_INTERFACE_H_ + +#include + +#include "tensorflow/core/distributed_runtime/call_options.h" +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/worker.pb.h" + +namespace tensorflow { + +// Status callback. +typedef std::function StatusCallback; + +// Allocator callback for out-of-band transfers. 
+class TensorShape; +typedef std::function + TensorBufAllocator; + +// Interface for talking with the TensorFlow Worker service. +class WorkerInterface { + public: + virtual ~WorkerInterface() {} + + virtual void GetStatusAsync(const GetStatusRequest* request, + GetStatusResponse* response, + StatusCallback done) = 0; + + virtual void RegisterGraphAsync(const RegisterGraphRequest* request, + RegisterGraphResponse* response, + StatusCallback done) = 0; + + virtual void DeregisterGraphAsync(const DeregisterGraphRequest* request, + DeregisterGraphResponse* response, + StatusCallback done) = 0; + + virtual void RunGraphAsync(CallOptions* opts, const RunGraphRequest* request, + RunGraphResponse* response, + StatusCallback done) = 0; + + virtual void CleanupGraphAsync(const CleanupGraphRequest* request, + CleanupGraphResponse* response, + StatusCallback done) = 0; + + virtual void CleanupAllAsync(const CleanupAllRequest* request, + CleanupAllResponse* response, + StatusCallback done) = 0; + + virtual void RecvTensorAsync(CallOptions* opts, + const RecvTensorRequest* request, + RecvTensorResponse* response, + TensorBufAllocator allocator, + StatusCallback done) = 0; + + virtual void LoggingAsync(const LoggingRequest* request, + LoggingResponse* response, StatusCallback done) = 0; + + virtual void TracingAsync(const TracingRequest* request, + TracingResponse* response, StatusCallback done) = 0; + + Status GetStatus(const GetStatusRequest* request, + GetStatusResponse* response) { + return CallAndWait(&ME::GetStatusAsync, request, response); + } + + Status RegisterGraph(const RegisterGraphRequest* request, + RegisterGraphResponse* response) { + return CallAndWait(&ME::RegisterGraphAsync, request, response); + } + + Status DeregisterGraph(const DeregisterGraphRequest* request, + DeregisterGraphResponse* response) { + return CallAndWait(&ME::DeregisterGraphAsync, request, response); + } + + Status CleanupGraph(const CleanupGraphRequest* request, + CleanupGraphResponse* 
response) { + return CallAndWait(&ME::CleanupGraphAsync, request, response); + } + + Status CleanupAll(const CleanupAllRequest* request, + CleanupAllResponse* response) { + return CallAndWait(&ME::CleanupAllAsync, request, response); + } + + Status Logging(const LoggingRequest* request, LoggingResponse* response) { + return CallAndWait(&ME::LoggingAsync, request, response); + } + + Status Tracing(const TracingRequest* request, TracingResponse* response) { + return CallAndWait(&ME::TracingAsync, request, response); + } + + private: + typedef WorkerInterface ME; + + template + Status CallAndWait(Method func, const Req* req, Resp* resp) { + Status ret; + Notification n; + (this->*func)(req, resp, [&ret, &n](const Status& s) { + ret = s; + n.Notify(); + }); + n.WaitForNotification(); + return ret; + } +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_INTERFACE_H_ diff --git a/tensorflow/core/framework/load_library.cc b/tensorflow/core/framework/load_library.cc index 92a34001855..bd21ac8e345 100644 --- a/tensorflow/core/framework/load_library.cc +++ b/tensorflow/core/framework/load_library.cc @@ -65,7 +65,7 @@ Status LoadLibrary(const char* library_filename, void** result, string str; GetOpList(&str); char* str_buf = reinterpret_cast(operator new(str.length())); - strncpy(str_buf, str.data(), str.length()); + memcpy(str_buf, str.data(), str.length()); *buf = str_buf; *len = str.length(); diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl index efee64a7a84..633441f31bf 100644 --- a/tensorflow/core/platform/default/build_config.bzl +++ b/tensorflow/core/platform/default/build_config.bzl @@ -25,37 +25,58 @@ def tf_deps(deps, suffix): return tf_deps -def tf_proto_library(name, srcs = [], has_services = False, - deps = [], visibility = [], testonly = 0, - cc_api_version = 2, go_api_version = 2, - java_api_version = 2, - py_api_version = 2): +def tf_proto_library_cc(name, 
srcs = [], has_services = None, + deps = [], visibility = [], testonly = 0, + cc_libs = [], + cc_stubby_versions = None, + cc_grpc_version = None, + cc_api_version = 2, go_api_version = 2, + java_api_version = 2, + py_api_version = 2): native.filegroup(name=name + "_proto_srcs", srcs=srcs + tf_deps(deps, "_proto_srcs"), testonly=testonly,) + use_grpc_plugin = None + if cc_grpc_version: + use_grpc_plugin = True cc_proto_library(name=name + "_cc", srcs=srcs + tf_deps(deps, "_proto_srcs"), deps=deps + ["//google/protobuf:cc_wkt_protos"], - cc_libs = ["//google/protobuf:protobuf"], + cc_libs = cc_libs + ["//google/protobuf:protobuf"], + use_grpc_plugin = use_grpc_plugin, testonly=testonly, visibility=visibility,) - py_proto_library(name=name + "_py", - srcs=srcs + tf_deps(deps, "_proto_srcs"), - srcs_version="PY2AND3", - deps=deps + ["//google/protobuf:protobuf_python"], - testonly=testonly, - visibility=visibility,) - -def tf_proto_library_py(name, srcs=[], deps=[], visibility=[], testonly=0): +def tf_proto_library_py(name, srcs=[], deps=[], visibility=[], testonly=0, + srcs_version="PY2AND3"): py_proto_library(name = name + "_py", srcs = srcs, - srcs_version = "PY2AND3", + srcs_version = srcs_version, deps = deps, visibility = visibility, testonly = testonly) +def tf_proto_library(name, srcs = [], has_services = None, + deps = [], visibility = [], testonly = 0, + cc_libs = [], + cc_api_version = 2, go_api_version = 2, + java_api_version = 2, + py_api_version = 2): + tf_proto_library_cc(name=name, + srcs=srcs + tf_deps(deps, "_proto_srcs"), + deps=deps, + cc_libs=cc_libs, + testonly=testonly, + visibility=visibility,) + + tf_proto_library_py(name=name, + srcs=srcs + tf_deps(deps, "_proto_srcs"), + srcs_version="PY2AND3", + deps=deps + ["//google/protobuf:protobuf_python"], + testonly=testonly, + visibility=visibility,) + def tf_additional_lib_srcs(): return [ "platform/default/*.h", diff --git a/tensorflow/core/protobuf/master.proto 
b/tensorflow/core/protobuf/master.proto new file mode 100644 index 00000000000..e46581bdab1 --- /dev/null +++ b/tensorflow/core/protobuf/master.proto @@ -0,0 +1,190 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package tensorflow; +// option cc_enable_arenas = true; +option java_outer_classname = "DistributedRuntimeProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.distruntime"; + +import "tensorflow/core/framework/config.proto"; +import "tensorflow/core/framework/device_attributes.proto"; +import "tensorflow/core/framework/graph.proto"; +import "tensorflow/core/framework/tensor.proto"; + +//////////////////////////////////////////////////////////////////////////////// +// +// CreateSession method request/response protos. +// +//////////////////////////////////////////////////////////////////////////////// + +message CreateSessionRequest { + // The initial graph definition. + GraphDef graph_def = 1; + + // Configuration options. + ConfigProto config = 2; +} + +message CreateSessionResponse { + // The session handle to be used in subsequent calls for the created session. + // + // The client must arrange to call CloseSession with this returned + // session handle to close the session. 
+ string session_handle = 1; + + // The initial version number for the graph, to be used in the next call + // to ExtendSession. + int64 graph_version = 2; +} + +//////////////////////////////////////////////////////////////////////////////// +// +// ExtendSession method request/response protos. +// +// The "graph_def" specifies a set of nodes to be added to the session's graph. +// +// A typical "graph_def" will contain: +// +// * Zero or more new nodes with names that do not exist in the server-side +// graph. These will be added to the graph. +// +// PRECONDITION: The server-side current version is req.current_version. +// None of the names in req.graph_def appeared in previous successful calls to +// CreateSession or ExtendSession with the same session_handle. +// POSTCONDITION: The server-side current version is resp.new_version. +// +//////////////////////////////////////////////////////////////////////////////// + +message ExtendSessionRequest { + // REQUIRED: session_handle must be returned by a CreateSession call + // to the same master service. + string session_handle = 1; + + // REQUIRED: The nodes to be added to the session's graph. If any node has + // the same name as an existing node, the operation will fail with + // ILLEGAL_ARGUMENT. + GraphDef graph_def = 2; + + // REQUIRED: The version number of the graph to be extended. This will be + // tested against the current server-side version number, and the operation + // will fail with FAILED_PRECONDITION if they do not match. + int64 current_graph_version = 3; +} + +message ExtendSessionResponse { + // TODO(mrry): Return something about the operation? + + // The new version number for the extended graph, to be used in the next call + // to ExtendSession. + int64 new_graph_version = 4; +} + +//////////////////////////////////////////////////////////////////////////////// +// +// RunStep method request/response protos. 
+// +// The caller should provide the feeds needed by the graph and specify +// what nodes should be fetched. +// +//////////////////////////////////////////////////////////////////////////////// + +// A pair of tensor name and tensor values. +message NamedTensorProto { + // Name of the tensor. + string name = 1; + + // The client can populate a TensorProto using a tensorflow::Tensor`, or + // directly using the protobuf field accessors. + // + // The client specifies whether the returned tensor values should be + // filled tensor fields (float_val, int_val, etc.) or encoded in a + // compact form in tensor.tensor_content. + TensorProto tensor = 2; +} + +message RunStepRequest { + // REQUIRED: session_handle must be returned by a CreateSession call + // to the same master service. + string session_handle = 1; + + // Tensors to be fed in the step. Each feed is a named tensor. + repeated NamedTensorProto feed = 2; + + // Fetches. A list of tensor names. The caller expects a tensor to + // be returned for each fetch[i] (see RunStepResponse.tensor). The + // order of specified fetches does not change the execution order. + repeated string fetch = 3; + + // Target Nodes. A list of node names. The named nodes will be run + // to but their outputs will not be fetched. + repeated string target = 4; +} + +message RunStepResponse { + // NOTE: The order of the returned tensors may or may not match + // the fetch order specified in RunStepRequest. + repeated NamedTensorProto tensor = 1; + + // TODO(mrry): Optionally aggregate StepStats in some form here. +} + +//////////////////////////////////////////////////////////////////////////////// +// +// CloseSession method request/response protos. +// +//////////////////////////////////////////////////////////////////////////////// + +message CloseSessionRequest { + // REQUIRED: session_handle must be returned by a CreateSession call + // to the same master service. 
+ string session_handle = 1; +} + +message CloseSessionResponse { +} + +message ResetRequest { + // A list of container names, which may be empty. + // + // If 'container' is not empty, releases resoures in the given + // containers in all devices. + // + // If 'container' is empty, releases resources in the default + // container in all devices. + repeated string container = 1; +} + +message ResetResponse { +} + +//////////////////////////////////////////////////////////////////////////////// +// +// ListDevices method request/response protos. +// +// Returns information about the TensorFlow devices that are available +// to this master. +// +//////////////////////////////////////////////////////////////////////////////// + +message ListDevicesRequest { +} + +message ListDevicesResponse { + repeated DeviceAttributes local_device = 1; + repeated DeviceAttributes remote_device = 2; +} diff --git a/tensorflow/core/protobuf/master_service.proto b/tensorflow/core/protobuf/master_service.proto new file mode 100644 index 00000000000..13b0a97b11f --- /dev/null +++ b/tensorflow/core/protobuf/master_service.proto @@ -0,0 +1,105 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +syntax = "proto3"; + +package tensorflow.grpc; +option java_outer_classname = "MasterServiceProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.distruntime"; + +import "tensorflow/core/protobuf/master.proto"; + +//////////////////////////////////////////////////////////////////////////////// +// +// MasterService defines a TensorFlow service with which a client can +// interact to execute a distributed TensorFlow computation. +// +// A master service keeps track of multiple "master sessions". Each +// session encapsulates a computation graph and its associated state, +// and typically corresponds to a single "client session" (e.g. a +// `tensorflow::Session` instance). +// +// A session is responsible for the following: +// * assigning each node to a device (locally or remotely) using a +// placement algorithm. This may make decisions based on collected +// statistics from the workers in the system (e.g., memory usage, +// bandwidth consumption, etc.) +// +// * inserting intermediate nodes and edges to support cross-device +// and cross-process data flows and resource management. +// +// * issuing commands to workers to execute the subgraphs associated +// with those workers. +// +// Typically, a client carries out an iterative computation +// (e.g. training) by invoking RPCs against the master in a +// client-side loop. The client first creates a client session that +// connects to a particular master (using gRPC for example). The +// master creates a corresponding master session that is hosted on +// the master and caches state between the client's invocations. +// +// After the session is established, the master returns an opaque +// handle to the client that can be used to associate the client and +// master sessions. 
+// +// The client may send an initial graph to the master in the +// CreateSession call, and add nodes to the graph using ExtendSession. +// +// The most frequent operation a master is "RunStep", which implements +// the `Session::Run()` API. It supports feeding in arguments, +// executing a dataflow computation, and fetching arguments. +// +// Finally, when the client no longer needs the session, it should +// close the session by invoking CloseSession, which allows the master +// to reclaim resources associated with the session. The master may +// implement a garbage collection scheme that closes sessions that +// have been inactive for some time. +// +// For example, the following pseudo-code illustrates how a client +// interacts with a master: +// +// stub = NewStub("/job:mnist/replica:0/task:0") +// {handle} = stub->CreateSession({graph_def}) +// do { +// stub->RunStep({handle, {feeds}, {fetches}}) +// // The client can evaluate a predicate locally, based on the +// // result of `fetches`, to determine whether to terminate. For +// // example, it might fetch the loss and evaluate whether it is less +// // than some threshold. +// } whlie (!should_stop({fetches})); +// stub->CloseSession({handle}) +// +//////////////////////////////////////////////////////////////////////////////// + +service MasterService { + // Creates a session. + rpc CreateSession(CreateSessionRequest) returns (CreateSessionResponse); + + // Extends a session. + rpc ExtendSession(ExtendSessionRequest) returns (ExtendSessionResponse); + + // Drives the graph computation. + rpc RunStep(RunStepRequest) returns (RunStepResponse); + + // Closes a session. + rpc CloseSession(CloseSessionRequest) returns (CloseSessionResponse); + + // List the devices usable by the master. + rpc ListDevices(ListDevicesRequest) returns (ListDevicesResponse); + + // Close all existing sessions. 
+ rpc Reset(ResetRequest) returns (ResetResponse); +} diff --git a/tensorflow/core/protobuf/worker.proto b/tensorflow/core/protobuf/worker.proto new file mode 100644 index 00000000000..bb01b65d8b9 --- /dev/null +++ b/tensorflow/core/protobuf/worker.proto @@ -0,0 +1,311 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package tensorflow; +// option cc_enable_arenas = true; +option java_outer_classname = "WorkerProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.distruntime"; + +import "google/protobuf/any.proto"; +import "tensorflow/core/framework/config.proto"; +import "tensorflow/core/framework/step_stats.proto"; +import "tensorflow/core/framework/device_attributes.proto"; +import "tensorflow/core/framework/graph.proto"; +import "tensorflow/core/framework/tensor.proto"; + +//////////////////////////////////////////////////////////////////////////////// +// +// GetStatus method request/response messages +// +//////////////////////////////////////////////////////////////////////////////// + +message GetStatusRequest { +} + +message GetStatusResponse { + repeated DeviceAttributes device_attributes = 1; +} + +//////////////////////////////////////////////////////////////////////////////// +// +// RegisterGraph method request/response messages +// +// For each session, after the master placed every node 
on a device, +// it partitions the whole graph into many subgraphs. All the nodes in +// a subgraph were in the same worker, but potentially on many devices +// owned by that worker (e.g. cpu0, plus gpu0, gpu1, ..., gpu7). The +// master registers subgraphs for a worker before running any steps. A +// successful registration returns a graph handle to be used in latter +// RunGraph requests. +// +//////////////////////////////////////////////////////////////////////////////// + +message RegisterGraphRequest { + // Subgraphs are scoped within one session. + string session_handle = 1; + + // "graph_def" has the subgraph of nodes for this worker, with each node + // having its device_name filled in. + GraphDef graph_def = 2; + + // True iff the graph (before partitioning) contains control flow nodes. + // + // As of 01/11/2015, this is no longer set by clients. + bool has_control_flow = 3 [deprecated = true]; + + // Configuration options for the session in which this graph was created. + GraphOptions graph_options = 4; +} + +message RegisterGraphResponse { + // If the registration succeeds, returns an opaque graph_handle to + // the master. The master calls RunGraph with graph_handle to + // compute different steps. + string graph_handle = 1; +} + +//////////////////////////////////////////////////////////////////////////////// +// +// DeregisterGraph method request/response messages +// +// The master deregisters the given graph_handle when the graph is no +// longer needed (e.g., the overall graph is re-scheduled and nodes +// are re-placed). +// +// The worker deregisters a graph_handle automatically according to on +// a TTL-base policy in case of master restarts. +// +//////////////////////////////////////////////////////////////////////////////// + +message DeregisterGraphRequest { + // REQUIRED: graph_handle must be returned by a RegisterGraph call + // to the same WorkerService. 
+ string graph_handle = 1; +} + +message DeregisterGraphResponse { + // TODO(mrry): Optionally add summary stats for the graph. +} + +//////////////////////////////////////////////////////////////////////////////// +// +// CleanupAll method request/response messages +// +//////////////////////////////////////////////////////////////////////////////// + +message CleanupAllRequest { + // A list of container names. + // + // If 'container' is not empty, releases resoures in the given + // containers in all devices. + // + // If 'container' is empty, releases resources in the default + // container in all devices. + repeated string container = 1; +} + +message CleanupAllResponse { +} + +//////////////////////////////////////////////////////////////////////////////// +// +// RunGraph request / response messages +// +// The worker executes all subgraphs registered under graph_handle. +// RunGraph returns after the execution finishes or an error is +// encountered. +// +//////////////////////////////////////////////////////////////////////////////// + +// A pair of tensor name and tensor values. +message NamedTensor { + // The name of the named tensor. + string key = 1; + + // The value of the named tensor. + TensorProto val = 2; +} + +// Options specific to the execution of a single step. +message ExecutorOpts { + bool record_costs = 1; + bool record_timeline = 3; +}; + +message RunGraphRequest { + // REQUIRED: graph_handle must be returned by a RegisterGraph call + // to the same WorkerService. + string graph_handle = 1; + + // A unique ID to distinguish different runs of the same graph. + // + // The master generates a global unique `step_id` to dinstinguish + // different runs of the graph computation. Subgraphs communicate + // (e.g., send/recv ops) with each other using `step_id` to + // distinguish tensors generated by different runs. + int64 step_id = 2; + + // Options for this step. + ExecutorOpts exec_opts = 5; + + // Runs the graph. 
+ // + // Sends the tensors in "send" into the graph before the run and + // fetches the keys into `RunGraphResponse.recv` after the run. + repeated NamedTensor send = 3; + repeated string recv_key = 4; +} + +message RunGraphResponse { + // A list of tensors corresponding to those requested by + // `RunGraphRequest.recv_key`. + repeated NamedTensor recv = 1; + + // If the request asked for execution stats, these are returned here. + StepStats step_stats = 2; +} + +//////////////////////////////////////////////////////////////////////////////// +// +// CleanupGraph method request/response messages +// +// After the master receives RunGraph responses from all workers, the +// master instructs every worker to cleanup any remaining state of a +// step (e.g. tensors buffered by a `Send` op but not picked up by +// other workers). The master does not necessarily need to wait for +// completion of CleanupGraph calls. +// +// Workers should cleanup step states automatically according to a +// TTL-based policy in case of master restarts. +// +//////////////////////////////////////////////////////////////////////////////// + +message CleanupGraphRequest { + int64 step_id = 1; +} + +message CleanupGraphResponse { +} + +//////////////////////////////////////////////////////////////////////////////// +// +// RecvTensor method request/response messages +// +//////////////////////////////////////////////////////////////////////////////// + +message RecvTensorRequest { + // The step in which the tensor will be produced. + // + // REQUIRED: This must eventually correspond to the `step_id` passed + // into a RunGraph call on the same WorkerService. + int64 step_id = 1; + + // A key that identifies the tensor to be received. + string rendezvous_key = 2; + + // If true, use an out-of-band DMA mechanism to transfer the + // received tensor. 
+ bool dma_ok = 3; + // NIC bus preference on the request originator side + BusAdjacency client_bus_adjacency = 4; + // NIC bus preference on the request receiver side + BusAdjacency server_bus_adjacency = 5; +} + +message RecvTensorResponse { + // The tensor as a proto. + TensorProto tensor = 1; + + // If true, this tensor was the output of a dead node, and the + // content is invalid. + bool is_dead = 2; + + // The time at which tensor was available and started to be returned. + int64 send_start_micros = 3; + + // Optional additional information about how to receive the tensor, + // in the event that `RecvTensorRequest.dma_ok` was true. + google.protobuf.Any transport_options = 4; +} + +//////////////////////////////////////////////////////////////////////////////// +// +// Logging method request/response messages +// +// NOTE(mrry): This feature is not supported in the open-source +// version, and these messages are expected to change. +// +//////////////////////////////////////////////////////////////////////////////// + +// Out-of-band request to begin or end logging, or +// to retrieve logs for particular steps. +message LoggingRequest { + // If true, RPC logging will be activated. + bool rpc_logging = 1; + + // If true, discard any saved logging data (for all steps). + bool clear = 2; + + // When set, requests all saved log data pertaining to the step. + // Any log data retrieved is eliminated from the store and cannot be + // retrieved again. + repeated int64 fetch_step_id = 3; +} + +message LabeledStepStats { + int64 step_id = 1; + StepStats step_stats = 2; +} + +message LoggingResponse { + repeated LabeledStepStats step = 1; +} + +//////////////////////////////////////////////////////////////////////////////// +// +// Tracing method request/response messages +// +// NOTE(mrry): This feature is not supported in the open-source +// version, and these messages are expected to change. 
+// +//////////////////////////////////////////////////////////////////////////////// + +message TraceOpts { + // Length of the trace to be taken, in seconds. + double duration = 1; + // If true, capture step profile locally in each worker. Currently + // unimplemented. + bool use_step_profiler = 2; + // If true, capture kernel events from each worker. + bool use_kernel_profiler = 3; + // If true, capture extended profiling events from TensorFlow process. + bool use_extended_profiler = 4; + // If true, capture GPU profiling events locally on each + // machine. Currently unimplemented. + bool use_gpu_profiler = 5; + // If true, collect sampled profile events. Currently unimplemented. + bool use_sample_profiler = 6; +} + +// Out-of-band request to configure distributed tracing. +message TracingRequest { + TraceOpts options = 1; +} + +message TracingResponse { +} diff --git a/tensorflow/core/protobuf/worker_service.proto b/tensorflow/core/protobuf/worker_service.proto new file mode 100644 index 00000000000..2699e639db8 --- /dev/null +++ b/tensorflow/core/protobuf/worker_service.proto @@ -0,0 +1,67 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +syntax = "proto3"; + +package tensorflow.grpc; +option java_outer_classname = "WorkerServiceProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.distruntime"; + +import "tensorflow/core/protobuf/worker.proto"; + +//////////////////////////////////////////////////////////////////////////////// +// +// WorkerService defines a TensorFlow service that executes dataflow +// graphs on a set of local devices, on behalf of a MasterService. +// +// A worker service keeps track of multiple "registered graphs". Each +// registered graph is a subgraph of a client's graph, corresponding to +// only the nodes that should execute on this worker (and any +// additional nodes necessary for inter-process communication using +// the `RecvTensor` method). +// +//////////////////////////////////////////////////////////////////////////////// + +service WorkerService { + // See worker.proto for details. + rpc GetStatus(GetStatusRequest) returns (GetStatusResponse); + + // See worker.proto for details. + rpc RegisterGraph(RegisterGraphRequest) returns (RegisterGraphResponse); + + // See worker.proto for details. + rpc DeregisterGraph(DeregisterGraphRequest) returns (DeregisterGraphResponse); + + // See worker.proto for details. + rpc RunGraph(RunGraphRequest) returns (RunGraphResponse); + + // See worker.proto for details. + rpc CleanupGraph(CleanupGraphRequest) returns (CleanupGraphResponse); + + // See worker.proto for details. + rpc CleanupAll(CleanupAllRequest) returns (CleanupAllResponse); + + // See worker.proto for details. + rpc RecvTensor(RecvTensorRequest) returns (RecvTensorResponse) { + // RecvTensor Method + } + + // See worker.proto for details. + rpc Logging(LoggingRequest) returns (LoggingResponse); + + // See worker.proto for details. 
+ rpc Tracing(TracingRequest) returns (TracingResponse); +} diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 2f45b074842..58d00d3c7f2 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -920,6 +920,7 @@ tf_py_wrap_cc( ":py_record_writer_lib", ":python_op_gen", ":tf_session_helper", + "//tensorflow/core/distributed_runtime/rpc:grpc_session", "//util/python:python_headers", ], ) diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py index 96d61a327cc..16d4f287abd 100644 --- a/tensorflow/python/client/session_test.py +++ b/tensorflow/python/client/session_test.py @@ -750,8 +750,9 @@ class SessionTest(test_util.TensorFlowTestCase): self.assertEqual(c_list[1], out[1].decode('utf-8')) def testInvalidTargetFails(self): - with self.assertRaisesRegexp(RuntimeError, - 'Registered factories are {DIRECT_SESSION}'): + with self.assertRaisesRegexp( + RuntimeError, + 'No session factory registered for the given session options.'): session.Session('INVALID_TARGET') def testFetchByNameDifferentStringTypes(self): diff --git a/tensorflow/python/ops/nn.py b/tensorflow/python/ops/nn.py index bc5ba95348a..9db78bb13a6 100644 --- a/tensorflow/python/ops/nn.py +++ b/tensorflow/python/ops/nn.py @@ -135,7 +135,7 @@ have varying scale, and to aid generalization. @@l2_normalize @@local_response_normalization @@sufficient_statistics -@@aggregate_moments +@@normalize_moments @@moments ## Losses @@ -561,7 +561,7 @@ def sufficient_statistics(x, axes, shift=True, keep_dims=False, name=None): return counts, m_ss, v_ss, shift_value -def aggregate_moments(counts, mean_ss, variance_ss, shift, name=None): +def normalize_moments(counts, mean_ss, variance_ss, shift, name=None): """Calculate the mean and variance of based on the sufficient statistics. Args: @@ -577,7 +577,7 @@ def aggregate_moments(counts, mean_ss, variance_ss, shift, name=None): Returns: Two `Tensor` objects: `mean` and `variance`. 
""" - with ops.op_scope([counts, mean_ss, variance_ss, shift], name, "aggregate"): + with ops.op_scope([counts, mean_ss, variance_ss, shift], name, "normalize"): divisor = math_ops.inv(counts, name="divisor") if shift is not None: shifted_mean = math_ops.mul(mean_ss, divisor, name="shifted_mean") @@ -620,7 +620,7 @@ def moments(x, axes, name=None, keep_dims=False): axes, keep_dims=keep_dims, name=name) - return aggregate_moments(counts, m_ss, v_ss, shift, name=name) + return normalize_moments(counts, m_ss, v_ss, shift, name=name) def batch_normalization(x, diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py index 317a0748309..30c79769096 100644 --- a/tensorflow/python/ops/nn_test.py +++ b/tensorflow/python/ops/nn_test.py @@ -826,19 +826,19 @@ class SufficientStatisticsTest(tf.test.TestCase): self._testSuffStats([1, 2, 3], [0, 2], shift, keep_dims, has_shape) -class AggregateMomentsTest(tf.test.TestCase): +class NormalizeMomentsTest(tf.test.TestCase): - def _npAggregateMoments(self, counts, mean_ss, variance_ss, shift): + def _npNormalizeMoments(self, counts, mean_ss, variance_ss, shift): mean = mean_ss / counts variance = variance_ss / counts - mean * mean if shift is not None: mean += shift return mean, variance - def _opAggregateMoments(self, counts, mean_ss, variance_ss, shift): - return tf.nn.aggregate_moments(counts, mean_ss, variance_ss, shift) + def _opNormalizeMoments(self, counts, mean_ss, variance_ss, shift): + return tf.nn.normalize_moments(counts, mean_ss, variance_ss, shift) - def _testAggregateMoments(self, shape, shift): + def _testNormalizeMoments(self, shape, shift): counts = np.ones([1]).astype(np.float32) mean_ss = np.random.random_sample(shape).astype(np.float32) variance_ss = np.random.random_sample(shape).astype(np.float32) @@ -847,7 +847,7 @@ class AggregateMomentsTest(tf.test.TestCase): shift_v = np.random.random_sample(shape).astype(np.float32) else: shift_v = None - npm, npv = self._npAggregateMoments(counts, 
mean_ss, variance_ss, shift_v) + npm, npv = self._npNormalizeMoments(counts, mean_ss, variance_ss, shift_v) for use_gpu in [True, False]: with self.test_session(use_gpu=use_gpu) as sess: tf_counts = tf.constant(counts, name="counts") @@ -857,16 +857,16 @@ class AggregateMomentsTest(tf.test.TestCase): tf_shift_v = tf.constant(shift_v, name="shift") else: tf_shift_v = None - opm, opv = self._opAggregateMoments(tf_counts, tf_mean_ss, + opm, opv = self._opNormalizeMoments(tf_counts, tf_mean_ss, tf_variance_ss, tf_shift_v) tfm, tfv = sess.run([opm, opv]) self.assertAllClose(npm, tfm, atol=0.000001) self.assertAllClose(npv, tfv, atol=0.000001) - def testAggregateMoments(self): + def testNormalizeMoments(self): for shift in [True, False]: - self._testAggregateMoments([3], shift) - self._testAggregateMoments([2, 3], shift) + self._testNormalizeMoments([3], shift) + self._testNormalizeMoments([2, 3], shift) class MomentsTest(tf.test.TestCase): @@ -971,15 +971,15 @@ class MomentsTest(tf.test.TestCase): """Make sure the output names are stable.""" with self.test_session(): mean, var = tf.nn.moments(tf.constant([1]), [0], keep_dims=False) - self.assertEquals(mean.op.name, "moments/aggregate/mean") - self.assertEquals(var.op.name, "moments/aggregate/variance") + self.assertEquals(mean.op.name, "moments/normalize/mean") + self.assertEquals(var.op.name, "moments/normalize/variance") def testOutputNamesKeep(self): """Make sure the output names are stable.""" with self.test_session(): mean, var = tf.nn.moments(tf.constant([1]), [0], keep_dims=True) - self.assertEquals(mean.op.name, "moments/aggregate/mean") - self.assertEquals(var.op.name, "moments/aggregate/variance") + self.assertEquals(mean.op.name, "moments/normalize/mean") + self.assertEquals(var.op.name, "moments/normalize/variance") class ComputeSampledLogitsTest(tf.test.TestCase):