Remove libverbs code from tensorflow/(contrib)

It has moved to github.com/tensorflow/networking

PiperOrigin-RevId: 266019701
parent 7520b333d6
commit 25f08ffe5d
.bazelrc
@@ -94,7 +94,6 @@ build:sycl_trisycl --define=using_trisycl=true
 # Options extracted from configure script
 build:gdr --define=with_gdr_support=true
 build:ngraph --define=with_ngraph_support=true
-build:verbs --define=with_verbs_support=true
 build:numa --define=with_numa_support=true
 
 # Options to disable default on features
@@ -1504,7 +1504,6 @@ def main():
 config_info_line('mkl', 'Build with MKL support.')
 config_info_line('monolithic', 'Config for mostly static monolithic build.')
 config_info_line('gdr', 'Build with GDR support.')
-config_info_line('verbs', 'Build with libverbs support.')
 config_info_line('ngraph', 'Build with Intel nGraph support.')
 config_info_line('numa', 'Build with NUMA support.')
 config_info_line(
@@ -315,12 +315,6 @@ config_setting(
     visibility = ["//visibility:public"],
 )
 
-config_setting(
-    name = "with_verbs_support",
-    define_values = {"with_verbs_support": "true"},
-    visibility = ["//visibility:public"],
-)
-
 config_setting(
     name = "with_numa_support",
     define_values = {"with_numa_support": "true"},
@@ -1,162 +0,0 @@
|
||||
# Description:
|
||||
# Verbs RDMA communication interfaces and implementations for TensorFlow.
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cuda_library")
|
||||
|
||||
# For platform specific build config
|
||||
load(
|
||||
"//tensorflow/core/platform:default/build_config.bzl",
|
||||
"tf_proto_library_cc",
|
||||
)
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
"//tensorflow:__subpackages__",
|
||||
],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
filegroup(
|
||||
name = "c_srcs",
|
||||
data = glob([
|
||||
"**/*.cc",
|
||||
"**/*.h",
|
||||
]),
|
||||
)
|
||||
|
||||
tf_proto_library_cc(
|
||||
name = "verbs_service_proto",
|
||||
srcs = ["verbs_service.proto"],
|
||||
has_services = 1,
|
||||
cc_api_version = 2,
|
||||
visibility = [
|
||||
"//tensorflow:__subpackages__",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "verbs_util",
|
||||
srcs = ["verbs_util.cc"],
|
||||
hdrs = ["verbs_util.h"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "grpc_verbs_service",
|
||||
srcs = ["grpc_verbs_service.cc"],
|
||||
hdrs = ["grpc_verbs_service.h"],
|
||||
deps = [
|
||||
":grpc_verbs_service_impl",
|
||||
":rdma_mgr",
|
||||
":verbs_service_proto_cc",
|
||||
"//tensorflow:grpc++",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core/distributed_runtime:session_mgr",
|
||||
"//tensorflow/core/distributed_runtime/rpc:async_service_interface",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_call",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "grpc_verbs_service_impl",
|
||||
srcs = ["grpc_verbs_service_impl.cc"],
|
||||
hdrs = ["grpc_verbs_service_impl.h"],
|
||||
deps = [
|
||||
":verbs_service_proto_cc",
|
||||
"//tensorflow:grpc++",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "grpc_verbs_client",
|
||||
srcs = ["grpc_verbs_client.cc"],
|
||||
hdrs = ["grpc_verbs_client.h"],
|
||||
deps = [
|
||||
":grpc_verbs_service_impl",
|
||||
":verbs_service_proto_cc",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/distributed_runtime:call_options",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "rdma_rendezvous_mgr",
|
||||
srcs = ["rdma_rendezvous_mgr.cc"],
|
||||
hdrs = ["rdma_rendezvous_mgr.h"],
|
||||
deps = [
|
||||
":rdma_mgr",
|
||||
":verbs_util",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:gpu_runtime",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/distributed_runtime:base_rendezvous_mgr",
|
||||
"//tensorflow/core/distributed_runtime:worker_env",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
name = "rdma_mgr",
|
||||
srcs = ["rdma_mgr.cc"],
|
||||
hdrs = ["rdma_mgr.h"],
|
||||
deps = [
|
||||
":grpc_verbs_client",
|
||||
":rdma",
|
||||
":verbs_service_proto_cc",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core/distributed_runtime:session_mgr",
|
||||
"//tensorflow/core/distributed_runtime:worker_env",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_channel",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
name = "rdma",
|
||||
srcs = ["rdma.cc"],
|
||||
hdrs = ["rdma.h"],
|
||||
linkopts = select({
|
||||
"//tensorflow:with_verbs_support": ["-libverbs"],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
deps = [
|
||||
":grpc_verbs_client",
|
||||
":verbs_service_proto_cc",
|
||||
":verbs_util",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:gpu_runtime",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core/distributed_runtime:rendezvous_mgr_interface",
|
||||
"//tensorflow/core/distributed_runtime:session_mgr",
|
||||
"//tensorflow/core/distributed_runtime:worker_env",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "verbs_server_lib",
|
||||
srcs = ["verbs_server_lib.cc"],
|
||||
hdrs = ["verbs_server_lib.h"],
|
||||
linkstatic = 1, # Seems to be needed since alwayslink is broken in bazel
|
||||
deps = [
|
||||
":grpc_verbs_service",
|
||||
":rdma_mgr",
|
||||
":rdma_rendezvous_mgr",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/distributed_runtime:server_lib",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
@@ -1,167 +0,0 @@
|
||||
## How to compile, use and configure RDMA-enabled TensorFlow
|
||||
1. Follow the regular TF compilation instructions. During the configure step, if you want ibverbs-based RDMA support, answer yes to this question:

```Do you wish to build TensorFlow with VERBS-RDMA support [y/N]```

2. To turn on the RDMA connection, add the protocol "grpc+verbs" to the server definition:

```server = tf.train.Server(cluster, job_name="local", task_index=0, protocol='grpc+verbs') # default protocol is 'grpc'```
|
||||
|
||||
3. RDMA configuration is done by setting the following environment variables (a minimal sketch of how the fixed defaults could be resolved in code follows the list):
* **RDMA_DEVICE**: The RDMA device name to be used. If not defined by user, a default device with an active port will be selected, if one exists.
* **RDMA_DEVICE_PORT**: The port within the selected device. Not relevant if RDMA_DEVICE is not defined. If not defined by user, a default active port will be selected, if one exists.
* **RDMA_GID_INDEX**: The GID index of the port. If not defined by user, a default suitable GID index will be set (RoCEV2 is favourable as default).
* **RDMA_QP_PKEY_INDEX**: The Pkey for the QP. If not defined by user, the default value is 0.
* **RDMA_QP_QUEUE_DEPTH**: TX/RX queue size for the QP. If not defined by user, the default value is 1024.
* **RDMA_QP_TIMEOUT**: The retransmission timeout for QPs. If not defined by user, the default value is 14.
* **RDMA_QP_RETRY_COUNT**: Number of retransmissions for QPs. If not defined by user, the default value is 7.
* **RDMA_QP_SL**: Service level configuration for QOS and ECN; valid values are 0-7. If not defined by user, the default value is 0.
* **RDMA_QP_MTU**: MTU configuration for the QPs. If not defined by user, the default value is the active MTU from query_port.
* **RDMA_TRAFFIC_CLASS**: Traffic class configuration for QP, in case of DSCP trust level QoS configuration. If not defined by user, the default value is 0. For more info see [HowTo Configure Trust state on Mellanox Adapters](https://community.mellanox.com/docs/DOC-2866).
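
For illustration only, the sketch below shows one way the fixed defaults listed above (Pkey index, queue depth, timeout, retry count, SL, traffic class) could be resolved from the environment. The helper names are hypothetical; the actual parsing lived in the removed rdma.cc, and the device/port/GID/MTU settings additionally require querying the device.

```c++
// Hypothetical sketch, not the actual implementation.
#include <cstdint>
#include <cstdlib>
#include <string>

namespace {

// Returns the integer value of an environment variable, or `default_value`
// when the variable is unset.
int64_t EnvOrDefault(const char* name, int64_t default_value) {
  const char* value = std::getenv(name);
  return value != nullptr ? std::stoll(std::string(value)) : default_value;
}

struct RdmaQpDefaults {
  int64_t pkey_index;
  int64_t queue_depth;
  int64_t timeout;
  int64_t retry_cnt;
  int64_t sl;
  int64_t traffic_class;
};

RdmaQpDefaults ReadRdmaQpDefaults() {
  RdmaQpDefaults params;
  params.pkey_index    = EnvOrDefault("RDMA_QP_PKEY_INDEX", 0);
  params.queue_depth   = EnvOrDefault("RDMA_QP_QUEUE_DEPTH", 1024);
  params.timeout       = EnvOrDefault("RDMA_QP_TIMEOUT", 14);
  params.retry_cnt     = EnvOrDefault("RDMA_QP_RETRY_COUNT", 7);
  params.sl            = EnvOrDefault("RDMA_QP_SL", 0);
  params.traffic_class = EnvOrDefault("RDMA_TRAFFIC_CLASS", 0);
  return params;
}

}  // namespace
```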
|
||||
|
||||
## Overview
|
||||
The design is based on TensorFlow r1.0. An RDMA path is added between servers for tensor transfer (weights, gradients, etc). The existing GRPC path remains and is responsible for "administrative" tasks, such as setting up the RDMA path, exchanging computation graphs, etc.
|
||||
|
||||
During server setup, an RDMA manager is created to manage low-level RDMA components such as the RDMA channel and RDMA adapter, and an RDMA rendezvous manager is created to oversee send/recv operations between servers. Following the distributed TensorFlow design philosophy, the send operation is passive, i.e. merely placing a tensor in the local out-going table. It is the receive operation that actually initiates the tensor transfer.
|
||||
|
||||
TensorFlow dynamically allocates memory for tensors that are to be sent or received. This causes difficulty for RDMA operations, where pinned memory is required. A few remedies are possible:
1. The memory is pinned, transferred, then unpinned for each and every tensor to be transferred. This incurs significant operation overhead, since pinning and unpinning memory for each dynamically generated tensor is slow.
2. A buffer is pre-allocated and pinned for each tensor. This incurs large memory overhead and extra copying from the tensor to its pinned buffer, but may still be faster than the former.
3. Following HKUST research on the use of GPU direct, and their [GDR implementation](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/gdr/README.md), there is a smart way to benefit from the TensorFlow allocation scheme, which is mostly pool based, i.e. allocators pre-allocate a large memory block and allocate the tensors from there. By attaching a custom visitor to the relevant allocators, we can do a single registration of the entire memory block, which eliminates the registration overhead. Once the block is registered, each new tensor allocated will be at a registered address, which allows us to do direct RDMA writes to it (see the registration sketch below).
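
A minimal sketch of the single-registration idea in item 3, assuming a pool allocator that reports each newly mapped block through a callback. The class and callback names here are illustrative; in the removed code this role was played by the memory manager attached to TensorFlow's allocator visitors.

```c++
// Illustrative only: register a whole pool block once, so every tensor later
// carved out of it is already pinned and remotely writable.
#include <infiniband/verbs.h>
#include <cstddef>
#include <unordered_map>

class BlockRegistry {
 public:
  explicit BlockRegistry(ibv_pd* pd) : pd_(pd) {}

  // Called once per newly allocated pool block (hypothetical hook).
  void OnBlockAllocated(void* ptr, size_t num_bytes) {
    ibv_mr* mr = ibv_reg_mr(pd_, ptr, num_bytes,
                            IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE);
    if (mr != nullptr) {
      regions_[ptr] = mr;  // Remember the region for later rkey/lkey lookups.
    }
  }

  // Called when a pool block is returned to the system.
  void OnBlockFreed(void* ptr) {
    auto it = regions_.find(ptr);
    if (it != regions_.end()) {
      ibv_dereg_mr(it->second);
      regions_.erase(it);
    }
  }

 private:
  ibv_pd* pd_;  // Protection domain, owned elsewhere.
  std::unordered_map<void*, ibv_mr*> regions_;
};
```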
|
||||
|
||||
For best performance, we adopt HKUST's zero-copy approach in our solution. This means:

1. Tensor writes will be done directly from the source tensor to the **result** tensor, with no memory copies in between. This should be done for all DMAable tensors, which are located either on the CPU or on an RDMA-compatible GPU device (GPU direct).
2. Non-DMAable tensors (CanMemCopy == false) will be serialized to a TensorProto on the sender side, RDMA written to a registered buffer on the receiver side, and then deserialized by the receiver.
3. Tensors which are located on a non-RDMA-compatible GPU will be RDMA written to a registered CPU **proxy** buffer on the receiver side, and then copied to the GPU by the receiver.
|
||||
|
||||
## Design details
|
||||
|
||||
### Terminology
|
||||
|
||||
* **Sender** - The node which sends the tensor.
|
||||
* **Receiver** - The node which receives the tensor.
|
||||
* **Result tensor** - The destination tensor, allocated on its appropriate device.
|
||||
* **Proxy tensor** - A CPU allocated tensor, which will be used in the case where the result tensor cannot be RDMA written to directly (GPU direct is disabled or not available). The RDMA write will therefore be done to the proxy tensor, and afterwards we will do a manual local copy from it to the result tensor.
|
||||
|
||||
### Messages
|
||||
|
||||
* RDMA_MESSAGE_TENSOR_REQUEST
|
||||
* RDMA_MESSAGE_META_DATA_RESPONSE
|
||||
* RDMA_MESSAGE_TENSOR_RE_REQUEST
|
||||
|
||||
### Transport protocol
|
||||
|
||||
The tensor transfer process is initiated when the receiver requests a tensor. In code it is done by calling **Rendezvous::Recv()** or **Rendezvous::RecvAsync()**. The TensorFlow base implementation handles the case where the requested tensor is located on the same node. The more interesting case, where the requested tensor is located on a remote node (receiver != sender), is handled in a derivation of the pure virtual **BaseRemoteRendezvous::RecvFromRemoteAsync()**. TensorFlow provides a default GRPC-based implementation which comes in the vanilla version but suffers in scalability when running large models. Our RDMA-based implementation aims to be more scalable. HKUST's contrib GDR implementation is more scalable than GRPC, and less scalable than ours, since ours is an evolution of it.
|
||||
|
||||
Our entry point is the implementation of **RdmaRemoteRendezvous::RecvFromRemoteAsync()**, located in rdma_rendezvous_mgr.cc. The implementation creates a new **RdmaTensorRequest** object, keyed by a request index (uint32_t), stores it in a list of pending requests, and calls its **Start()** method. The **Start()** method does two things (see the sketch after this list):

1. Allocate the result tensor (and the proxy tensor if required).
2. Send a **RDMA_MESSAGE_TENSOR_REQUEST** to the sender, containing the address of the destination tensor (result/proxy) for the RDMA write.
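
A condensed, hypothetical sketch of these two steps (the real code was **RdmaTensorRequest::Start()** in the removed rdma.cc); the types and helpers below are simplified stand-ins.

```c++
// Illustrative sketch only.
#include <cstdint>

struct CachedMeta { bool known = false; uint64_t num_bytes = 0; };

class TensorRequestSketch {
 public:
  TensorRequestSketch(uint32_t request_index, CachedMeta cached)
      : request_index_(request_index), cached_(cached) {}

  void Start() {
    // 1. Allocate the result (or proxy) tensor if the meta-data is already
    //    cached; on a first-time request there is nothing to allocate yet.
    if (cached_.known) dst_addr_ = Allocate(cached_.num_bytes);
    // 2. Send RDMA_MESSAGE_TENSOR_REQUEST carrying the destination address
    //    and the cached meta-data, keyed by the request index.
    SendTensorRequest(request_index_, dst_addr_, cached_);
  }

 private:
  void* Allocate(uint64_t) { return nullptr; }                   // Stub.
  void SendTensorRequest(uint32_t, void*, const CachedMeta&) {}  // Stub.

  uint32_t request_index_;
  CachedMeta cached_;
  void* dst_addr_ = nullptr;
};
```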
|
||||
|
||||
In order to allocate the result and proxy tensors, we need to know the tensor's meta-data, i.e. shape and data-type for DMAable tensors, and proto-size for serialized tensors. Unfortunately, this information is only available on the sender side, which complicates matters. In order to avoid sending extra messages for querying the meta-data at each step, we keep a local meta-data cache per tensor, which is only updated when the meta-data changes. Based on the assumption that the meta-data of a tensor rarely changes between steps, we expect that in most cases the cache will only be updated once. The sender is responsible for detecting changes in the meta-data and updating the receiver. In order for the sender to know that the meta-data has changed, each **RDMA_MESSAGE_TENSOR_REQUEST** contains the meta-data that the receiver grabbed from its local cache. The sender then compares the meta-data from the message to the tensor's new meta-data.
|
||||
|
||||
When the sender receives an **RDMA_MESSAGE_TENSOR_REQUEST**, it will create a new **RdmaTensorResponse** object for the given request message, store it in a list of pending responses, and invoke its **Start()** method. The **Start()** method does the following (a sketch of the decision in steps 4-5 follows the list):

1. Grab the source tensor from the local table (in code, **RecvLocalAsync()**).
2. If the source tensor is not DMAable, serialize it to a TensorProto.
3. If the source tensor is located on a device which cannot be DMA written from, copy it to CPU.
4. If it is the first time this tensor is requested, or if the tensor's meta-data changed:
    1. Clone the tensor's data to be sent later.
    2. Send a **RDMA_MESSAGE_META_DATA_RESPONSE** containing the new meta-data.
5. Otherwise:
    1. RDMA write the tensor (or TensorProto) to the destination address and rkey specified in the request message. The immediate value for the write will be the request index.
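
Below is a hedged sketch of the decision made in steps 4-5, with simplified stand-in types; the actual logic lived in the removed rdma.cc (**RdmaTensorResponse**), and `RdmaWriteWithImm` stands in for the real ibverbs post-send path.

```c++
// Illustrative sketch only.
#include <cstdint>

struct MetaData {
  uint64_t shape_hash = 0;  // Stand-in for dtype/shape/is_dead/proto_size.
  bool operator==(const MetaData& o) const { return shape_hash == o.shape_hash; }
};

struct TensorRequestMsg {
  uint32_t request_index;
  uint64_t remote_addr;           // Destination (result/proxy) tensor address.
  uint32_t rkey;
  MetaData receiver_cached_meta;  // What the receiver currently believes.
  bool has_destination;           // False on a first-time request.
};

class ResponderSketch {
 public:
  void HandleTensorRequest(const TensorRequestMsg& msg, const void* src,
                           uint64_t num_bytes, const MetaData& current_meta) {
    if (!msg.has_destination || !(msg.receiver_cached_meta == current_meta)) {
      // Meta-data is new or changed: keep the data for the re-request and
      // tell the receiver to update its cache and reallocate.
      CloneForLater(msg.request_index, src, num_bytes);
      SendMetaDataResponse(msg.request_index, current_meta);
      return;
    }
    // Meta-data matches: write straight into the result/proxy tensor, using
    // the request index as the immediate value.
    RdmaWriteWithImm(src, num_bytes, msg.remote_addr, msg.rkey,
                     /*imm=*/msg.request_index);
  }

 private:
  void CloneForLater(uint32_t, const void*, uint64_t) {}                          // Stub.
  void SendMetaDataResponse(uint32_t, const MetaData&) {}                         // Stub.
  void RdmaWriteWithImm(const void*, uint64_t, uint64_t, uint32_t, uint32_t) {}   // Stub.
};
```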
|
||||
|
||||
|
||||
When the receiver receives the **RDMA_MESSAGE_META_DATA_RESPONSE**, it will locate the relevant **RdmaTensorRequest** using the request index specified in the message, and invoke its **RecvTensorMetaData()**, which does the following:

1. Update the local meta-data cache.
2. Reallocate the result/proxy tensors.
3. Re-send the tensor request. For traceability, the new message has a different name: **RDMA_MESSAGE_TENSOR_RE_REQUEST**.
|
||||
|
||||
When the sender receives a **RDMA_MESSAGE_TENSOR_RE_REQUEST**, it will locate the relevant **RdmaTensorResponse** using the request index specified in the message, and invoke its **Resume()** method, which RDMA writes the contents of the tensor that was cloned earlier to the new remote address specified in the re-request.

When the receiver receives the RDMA write, it will locate the relevant **RdmaTensorRequest** using the request index, which is the immediate value. It will then invoke its **RecvTensorContent()**, which does the following:

1. Proxy copy/deserialize if required.
2. Invoke the done callback.
3. Deallocate the result/proxy tensors and remove the request from the pending list.
|
||||
|
||||

|
||||
|
||||
### Additional design notes
|
||||
|
||||
1. When the sender receives a tensor request, the source tensor may or may not be ready yet. The situation is handled through a process of tag matching:
    * If the request arrives before the tensor is ready, then a callback is put in a local table, and will be invoked once the tensor arrives.
    * If the tensor is ready before the request arrives, then the tensor is put in a local table. When the request arrives, it will invoke the callback immediately.
    In code it is done by calling **RecvLocalAsync()**, which receives the tensor's key, step-id, and the callback.
2. When the callback is invoked, the relevant tensor is removed from the tag matching table. In the case where we need to send the tensor's meta-data, the **RdmaTensorResponse** will store a copy of the tensor until the re-request arrives.
3. The sending of protocol messages (**RDMA_MESSAGE_TENSOR_REQUEST**, **RDMA_MESSAGE_META_DATA_RESPONSE** and **RDMA_MESSAGE_TENSOR_RE_REQUEST**) is done by the class **RdmaMessageBuffer**. All messages are sent using RDMA writes from/to fixed message buffers. This implies that we cannot send more than one message at a time on a specific channel. In order to synchronize the messages, the **RdmaMessageBuffer** holds local and remote buffer statuses which can be either busy or idle. When a write is issued, both statuses are changed to busy. When the write-complete event is received, the local status is changed to idle. When the write is received on the remote side, the remote side parses the message and returns an ACK back to the sending side, upon which the sending side updates the remote status to idle. When both the local and remote statuses are idle, the next message can be sent.
4. ACK writes are empty writes (hence they require no buffer) with immediate value 0xFFFFFFFE. Message writes have the immediate value 0xFFFFFFFF. All other writes are tensor-content writes whose immediate value is the request index (see the dispatch sketch below).
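
A small sketch of how a receive-completion handler can dispatch on these immediate values; the constants follow the convention above (and the RdmaImmDataType enum in rdma.h), while the function and type names are illustrative.

```c++
// Illustrative sketch only.
#include <cstdint>

enum class RecvKind { kMessage, kAck, kTensorContent };

struct RecvEvent {
  RecvKind kind;
  uint32_t request_index;  // Valid only for kTensorContent.
};

// 0xFFFFFFFF marks a protocol message, 0xFFFFFFFE an ACK, and any other value
// is the request index of a tensor-content write.
RecvEvent ClassifyImmediate(uint32_t imm_data) {
  if (imm_data == 0xFFFFFFFFu) return {RecvKind::kMessage, 0};
  if (imm_data == 0xFFFFFFFEu) return {RecvKind::kAck, 0};
  return {RecvKind::kTensorContent, imm_data};
}
```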
|
||||
|
||||
### RDMA components
|
||||
|
||||
* **enum RdmaImmDataType** - Immediate types to distinguish between different RDMA writes on the remote side. Ack writes and control-message writes have a fixed immediate value. The rest of the writes are tensor writes and the immediate value is the relevant request index.
|
||||
* **enum RdmaWriteIDType** - Types to distinguish between different RDMA write-complete events: Ack, control message and tensor writes.
|
||||
* **class RdmaWriteID** - Context for RDMA write complete events. Holds the RdmaWriteIDType and additional data.
|
||||
* **class RdmaTensorMetaData** - Meta-data for a tensor (type, shape, is_dead, proto_size).
|
||||
* **class RdmaMemoryMgr** - Manages the meta-data cache, and the registered memory regions.
|
||||
* **class RdmaTensorRequest** - Holds and manages information for a single tensor request throughout the entire receive cycle. API:
|
||||
* **Start()** - Start the request sequence.
|
||||
* Allocate the result tensor (and proxy tensor if required).
|
||||
* Send RDMA_MESSAGE_TENSOR_REQUEST to the remote side.
|
||||
* **RecvTensorMetaData()** - Receive meta-data from the remote side.
|
||||
* Update the local meta-data cache.
|
||||
* Reallocate the result tensor (and proxy tensor if required).
|
||||
* Re-send the request to the remote side.
|
||||
* **RecvTensorContent()** - Receive tensor content from the remote side (RDMA write was completed).
|
||||
* Decode proto if required and/or move to GPU if the content was not written to it directly (GPU direct is not available).
|
||||
* Invoke the done callback.
|
||||
* **class RdmaTensorResponse** - Holds and manages information for a single tensor response throughout the entire send cycle. API:
|
||||
* **Start()** - Start the response sequence.
|
||||
* Find the tensor in the local tag-match table.
|
||||
* Compare the tensor's meta-data to the meta-data in the message (taken from the requester's local cache).
|
||||
* If meta-data changed:
|
||||
* Clone the tensor to be sent later.
|
||||
* Send a meta-data update message and wait for re-request.
|
||||
* Else:
|
||||
* Send the tensor's content (using direct RDMA write).
|
||||
* **Resume()** - Resume the response sequence after a re-request. Send the tensor's content that was cloned earlier.
|
||||
* **Destroy()** - Destroy the response's resources and remove it from the pending list.
|
||||
* **class RdmaAdapter** - The base for RDMA communications. It may contain multiple channels and buffers. It is responsible for handling various incoming RDMA messages.
|
||||
* **class RdmaChannel** - Responsible for the RDMA connection to a particular node. It manages message buffers. A channel has a request table which stores all the pending tensor requests.
|
||||
* **class RdmaMessageBuffer** - Responsible for sending or receiving messages. It has a fixed-size memory region to store the data and a queue to store pending jobs. A channel has two message buffers, one for tx and one for rx.
|
||||
* **class RdmaMgr** - Manages the adapter and channels, including channel creation, channel setup via GRPC service, channel lookup, etc.
|
||||
* **class RdmaRendezvousMgr** - Manages multiple RDMA rendezvous instances.
|
||||
* **class RdmaRemoteRendezvous** - A derived class of BaseRemoteRendezvous. This class is the back end for "send" and "recv" ops. When the sendrecv_op wants to send or receive a tensor, it calls the rendezvous' "send" and "recv" functions respectively. Rendezvous are identified by "step_id", a random number, so that tensors for different iterations don't get mixed up.
|
||||
|
||||
### Message structure:
|
||||
|
||||
| type | name_size | name | step_id | request_index | remote_addr/checksum | rkey | is_dead | data_type | tensor_shape | tensor_bytes | error_status |
|
||||
|------|---------- |------|---------|---------------|----------------------|------|---------|-----------|--------------|--------------|-----------------------|
|
||||
| 1B | 2B | 512 | 8B | 8B | 8B | 4B | 1B | XB | XB | 8B | Size - 4B, proto - XB |
|
||||
|
||||
* **RDMA_MESSAGE_TENSOR_REQUEST** - (receiver ==> sender) The original tensor request.
|
||||
* type - The message type.
|
||||
* name (name_size) - Name of the requested tensor.
|
||||
* step_id - Step ID.
|
||||
* request_index - Request index.
|
||||
* remote_addr/rkey - Address/rkey of the result/proxy tensor. Irrelevant for first-time request.
|
||||
* is_dead/data_type/tensor_shape/tensor_bytes - The current meta-data as stored in the receiver local cache. The sender will use that information to know if the receiver's cache requires updating.
|
||||
* **RDMA_MESSAGE_META_DATA_RESPONSE** - (sender ==> receiver) The meta-data update message in case meta-data had changed (or if it is the first time the tensor is requested).
|
||||
* type - The message type.
|
||||
* request_index - Request index.
|
||||
* is_dead/data_type/tensor_shape/tensor_bytes - The up-to-date meta-data.
|
||||
* checksum - In data validation mode, this will hold the checksum of the source tensor.
|
||||
* **RDMA_MESSAGE_TENSOR_RE_REQUEST** - (receiver ==> sender) Tensor re-request after meta-data update and reallocation of result/proxy tensors.
|
||||
* type - The message type.
|
||||
* name (name_size) - Name of the requested tensor.
|
||||
* step_id - Step ID.
|
||||
* request_index - Request index.
|
||||
* remote_addr/rkey - Address/rkey of the reallocated result/proxy tensor.
|
||||
* **RDMA_MESSAGE_ERROR_STATUS** - (sender ==> receiver) Notify the receiver that an error had occurred on the sender side, so it can propagate it to the upper levels.
|
||||
* type - The message type.
|
||||
* name (name_size) - Name of the requested tensor.
|
||||
* step_id - Step ID.
|
||||
* request_index - Request index.
|
||||
* error_status - The error status (code, message, details).
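
For reference, the fixed-field part of the layout above can be expressed as compile-time offsets. The sketch below follows the nominal byte sizes in the table; the authoritative values were the k*StartIndex constants in the removed rdma.h, which derive the sizes with sizeof and may therefore differ slightly.

```c++
// Illustrative layout sketch; not the actual constants from rdma.h.
#include <cstddef>

constexpr size_t kNameCapacity       = 512;                          // "name" field
constexpr size_t kTypeOffset         = 0;                            // 1B type
constexpr size_t kNameSizeOffset     = kTypeOffset + 1;              // 2B name_size
constexpr size_t kNameOffset         = kNameSizeOffset + 2;          // 512B name
constexpr size_t kStepIdOffset       = kNameOffset + kNameCapacity;  // 8B step_id
constexpr size_t kRequestIndexOffset = kStepIdOffset + 8;            // 8B request_index
constexpr size_t kRemoteAddrOffset   = kRequestIndexOffset + 8;      // 8B remote_addr/checksum
constexpr size_t kRkeyOffset         = kRemoteAddrOffset + 8;        // 4B rkey
constexpr size_t kIsDeadOffset       = kRkeyOffset + 4;              // 1B is_dead
// data_type, tensor_shape and the serialized error status have
// implementation-defined sizes, so they are not pinned down here.
```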
|
Binary file not shown.
tensorflow/contrib/verbs/grpc_verbs_client.cc
@@ -1,47 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/contrib/verbs/grpc_verbs_client.h"
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
Status GrpcVerbsClient::GetRemoteAddress(CallOptions* call_options,
|
||||
const GetRemoteAddressRequest* request,
|
||||
GetRemoteAddressResponse* response) {
|
||||
::grpc::ClientContext ctx;
|
||||
ctx.set_fail_fast(false);
|
||||
SetDeadline(&ctx, call_options->GetTimeout());
|
||||
return FromGrpcStatus(stub_->GetRemoteAddress(&ctx, *request, response));
|
||||
}
|
||||
|
||||
Status GrpcVerbsClient::GetRemoteAddress(const GetRemoteAddressRequest* request,
|
||||
GetRemoteAddressResponse* response) {
|
||||
CallOptions call_options;
|
||||
call_options.SetTimeout(-1); // no time out
|
||||
return GetRemoteAddress(&call_options, request, response);
|
||||
}
|
||||
|
||||
void GrpcVerbsClient::SetDeadline(::grpc::ClientContext* ctx,
|
||||
int64 time_in_ms) {
|
||||
if (time_in_ms > 0) {
|
||||
ctx->set_deadline(gpr_time_from_millis(time_in_ms, GPR_TIMESPAN));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
tensorflow/contrib/verbs/grpc_verbs_client.h
@@ -1,50 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CONTRIB_VERBS_GRPC_VERBS_CLIENT_H_
|
||||
#define TENSORFLOW_CONTRIB_VERBS_GRPC_VERBS_CLIENT_H_
|
||||
|
||||
#include "tensorflow/contrib/verbs/grpc_verbs_service_impl.h"
|
||||
#include "tensorflow/contrib/verbs/verbs_service.pb.h"
|
||||
#include "tensorflow/core/distributed_runtime/call_options.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// GrpcVerbsClient is a client that uses gRPC to talk to the Verbs service.
|
||||
class GrpcVerbsClient {
|
||||
public:
|
||||
explicit GrpcVerbsClient(SharedGrpcChannelPtr client_channel)
|
||||
: stub_(grpc::VerbsService::NewStub(client_channel)) {}
|
||||
~GrpcVerbsClient() {}
|
||||
|
||||
Status GetRemoteAddress(CallOptions* call_options,
|
||||
const GetRemoteAddressRequest* request,
|
||||
GetRemoteAddressResponse* response);
|
||||
Status GetRemoteAddress(const GetRemoteAddressRequest* request,
|
||||
GetRemoteAddressResponse* response);
|
||||
|
||||
private:
|
||||
std::unique_ptr<grpc::VerbsService::Stub> stub_;
|
||||
|
||||
void SetDeadline(::grpc::ClientContext* ctx, int64 time_in_ms);
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(GrpcVerbsClient);
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CONTRIB_VERBS_GRPC_VERBS_CLIENT_H_
|
tensorflow/contrib/verbs/grpc_verbs_service.cc
@@ -1,164 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifdef TENSORFLOW_USE_VERBS
|
||||
|
||||
#include "tensorflow/contrib/verbs/grpc_verbs_service.h"
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
|
||||
#include "tensorflow/core/distributed_runtime/session_mgr.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
GrpcVerbsService::GrpcVerbsService(const WorkerEnv* worker_env,
|
||||
::grpc::ServerBuilder* builder)
|
||||
: is_shutdown_(false), worker_env_(worker_env) {
|
||||
builder->RegisterService(&verbs_service_);
|
||||
cq_ = builder->AddCompletionQueue().release();
|
||||
}
|
||||
|
||||
GrpcVerbsService::~GrpcVerbsService() {
|
||||
delete shutdown_alarm_;
|
||||
delete cq_;
|
||||
}
|
||||
|
||||
void GrpcVerbsService::Shutdown() {
|
||||
bool did_shutdown = false;
|
||||
{
|
||||
mutex_lock l(shutdown_mu_);
|
||||
if (!is_shutdown_) {
|
||||
LOG(INFO) << "Shutting down GrpcWorkerService.";
|
||||
is_shutdown_ = true;
|
||||
did_shutdown = true;
|
||||
}
|
||||
}
|
||||
if (did_shutdown) {
|
||||
shutdown_alarm_ =
|
||||
new ::grpc::Alarm(cq_, gpr_now(GPR_CLOCK_MONOTONIC), nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
// This macro creates a new request for the given RPC method name
|
||||
// (e.g., `ENQUEUE_REQUEST(GetRemoteAddress, false);`), and enqueues it on
|
||||
// `this->cq_`.
|
||||
//
|
||||
// This macro is invoked one or more times for each RPC method to
|
||||
// ensure that there are sufficient completion queue entries to
|
||||
// handle incoming requests without blocking.
|
||||
//
|
||||
// The implementation of the request handler for each RPC method
|
||||
// must ensure that it calls ENQUEUE_REQUEST() for that RPC method,
|
||||
// to keep accepting new requests.
|
||||
#define ENQUEUE_REQUEST(method, supports_cancel) \
|
||||
do { \
|
||||
mutex_lock l(shutdown_mu_); \
|
||||
if (!is_shutdown_) { \
|
||||
Call<GrpcVerbsService, grpc::VerbsService::AsyncService, \
|
||||
method##Request, method##Response>:: \
|
||||
EnqueueRequest(&verbs_service_, cq_, \
|
||||
&grpc::VerbsService::AsyncService::Request##method, \
|
||||
&GrpcVerbsService::method##Handler, \
|
||||
(supports_cancel)); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
// This method blocks forever handling requests from the completion queue.
|
||||
void GrpcVerbsService::HandleRPCsLoop() {
|
||||
for (int i = 0; i < 10; ++i) {
|
||||
ENQUEUE_REQUEST(GetRemoteAddress, false);
|
||||
}
|
||||
|
||||
void* tag;
|
||||
bool ok;
|
||||
|
||||
while (cq_->Next(&tag, &ok)) {
|
||||
UntypedCall<GrpcVerbsService>::Tag* callback_tag =
|
||||
static_cast<UntypedCall<GrpcVerbsService>::Tag*>(tag);
|
||||
if (callback_tag) {
|
||||
callback_tag->OnCompleted(this, ok);
|
||||
} else {
|
||||
cq_->Shutdown();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void GrpcVerbsService::GetRemoteAddressHandler(
|
||||
WorkerCall<GetRemoteAddressRequest, GetRemoteAddressResponse>* call) {
|
||||
Status s = GetRemoteAddressSync(&call->request, &call->response);
|
||||
call->SendResponse(ToGrpcStatus(s));
|
||||
ENQUEUE_REQUEST(GetRemoteAddress, false);
|
||||
}
|
||||
|
||||
// synchronous method
|
||||
Status GrpcVerbsService::GetRemoteAddressSync(
|
||||
const GetRemoteAddressRequest* request,
|
||||
GetRemoteAddressResponse* response) {
|
||||
// analyzing request
|
||||
// the channel setting part is redundant.
|
||||
const string remote_host_name = request->host_name();
|
||||
RdmaChannel* rc = rdma_mgr_->FindChannel(remote_host_name);
|
||||
CHECK(rc);
|
||||
RdmaAddress ra;
|
||||
ra.lid = request->channel().lid();
|
||||
ra.qpn = request->channel().qpn();
|
||||
ra.psn = request->channel().psn();
|
||||
ra.snp = request->channel().snp();
|
||||
ra.iid = request->channel().iid();
|
||||
rc->SetRemoteAddress(ra, false);
|
||||
rc->Connect();
|
||||
int i = 0;
|
||||
int idx[] = {1, 0};
|
||||
std::vector<RdmaMessageBuffer*> mb(rc->message_buffers());
|
||||
CHECK_EQ(request->mr_size(), RdmaChannel::kNumMessageBuffers);
|
||||
for (const auto& mr : request->mr()) {
|
||||
// the connections are crossed, i.e.
|
||||
// local tx_message_buffer <---> remote rx_message_buffer_
|
||||
// local rx_message_buffer <---> remote tx_message_buffer_
|
||||
// hence idx[] = {1, 0}.
|
||||
RdmaMessageBuffer* rb = mb[idx[i]];
|
||||
RemoteMR rmr;
|
||||
rmr.remote_addr = mr.remote_addr();
|
||||
rmr.rkey = mr.rkey();
|
||||
rb->SetRemoteMR(rmr, false);
|
||||
i++;
|
||||
}
|
||||
CHECK(i == RdmaChannel::kNumMessageBuffers);
|
||||
|
||||
// setting up response
|
||||
response->set_host_name(
|
||||
worker_env_->session_mgr->LegacySession()->worker_name);
|
||||
Channel* channel_info = response->mutable_channel();
|
||||
channel_info->set_lid(rc->self().lid);
|
||||
channel_info->set_qpn(rc->self().qpn);
|
||||
channel_info->set_psn(rc->self().psn);
|
||||
channel_info->set_snp(rc->self().snp);
|
||||
channel_info->set_iid(rc->self().iid);
|
||||
for (int i = 0; i < RdmaChannel::kNumMessageBuffers; i++) {
|
||||
MemoryRegion* mr = response->add_mr();
|
||||
mr->set_remote_addr(reinterpret_cast<uint64>(mb[i]->buffer()));
|
||||
mr->set_rkey(mb[i]->self()->rkey);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Create a GrpcVerbsService, then assign it to a given handle.
|
||||
void SetNewVerbsService(GrpcVerbsService** handle, const WorkerEnv* worker_env,
|
||||
::grpc::ServerBuilder* builder) {
|
||||
*handle = new GrpcVerbsService(worker_env, builder);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_USE_VERBS
|
tensorflow/contrib/verbs/grpc_verbs_service.h
@@ -1,69 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CONTRIB_VERBS_GRPC_VERBS_SERVICE_H_
|
||||
#define TENSORFLOW_CONTRIB_VERBS_GRPC_VERBS_SERVICE_H_
|
||||
|
||||
#ifdef TENSORFLOW_USE_VERBS
|
||||
|
||||
#include "grpcpp/alarm.h"
|
||||
#include "grpcpp/grpcpp.h"
|
||||
#include "grpcpp/server_builder.h"
|
||||
#include "tensorflow/contrib/verbs/grpc_verbs_service_impl.h"
|
||||
#include "tensorflow/contrib/verbs/rdma_mgr.h"
|
||||
#include "tensorflow/contrib/verbs/verbs_service.pb.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_call.h"
|
||||
#include "tensorflow/core/lib/core/refcount.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class GrpcVerbsService : public AsyncServiceInterface {
|
||||
public:
|
||||
GrpcVerbsService(const WorkerEnv* worker_env, ::grpc::ServerBuilder* builder);
|
||||
~GrpcVerbsService();
|
||||
void HandleRPCsLoop() override;
|
||||
void Shutdown() override;
|
||||
void SetRdmaMgr(RdmaMgr* rdma_mgr) { rdma_mgr_ = rdma_mgr; }
|
||||
|
||||
private:
|
||||
template <class RequestMessage, class ResponseMessage>
|
||||
using WorkerCall = Call<GrpcVerbsService, grpc::VerbsService::AsyncService,
|
||||
RequestMessage, ResponseMessage>;
|
||||
void GetRemoteAddressHandler(
|
||||
WorkerCall<GetRemoteAddressRequest, GetRemoteAddressResponse>* call);
|
||||
Status GetRemoteAddressSync(const GetRemoteAddressRequest* request,
|
||||
GetRemoteAddressResponse* response);
|
||||
|
||||
::grpc::ServerCompletionQueue* cq_;
|
||||
grpc::VerbsService::AsyncService verbs_service_;
|
||||
mutex shutdown_mu_;
|
||||
bool is_shutdown_ GUARDED_BY(shutdown_mu_);
|
||||
::grpc::Alarm* shutdown_alarm_;
|
||||
// not owned
|
||||
RdmaMgr* rdma_mgr_;
|
||||
const WorkerEnv* const worker_env_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(GrpcVerbsService);
|
||||
};
|
||||
|
||||
// Create a GrpcVerbsService, then assign it to a given handle.
|
||||
void SetNewVerbsService(GrpcVerbsService** handle, const WorkerEnv* worker_env,
|
||||
::grpc::ServerBuilder* builder);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_USE_VERBS
|
||||
#endif // TENSORFLOW_CONTRIB_VERBS_GRPC_VERBS_SERVICE_H_
|
tensorflow/contrib/verbs/grpc_verbs_service_impl.cc
@@ -1,69 +0,0 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/contrib/verbs/grpc_verbs_service_impl.h"
|
||||
|
||||
#include "grpcpp/impl/codegen/async_stream.h"
|
||||
#include "grpcpp/impl/codegen/async_unary_call.h"
|
||||
#include "grpcpp/impl/codegen/channel_interface.h"
|
||||
#include "grpcpp/impl/codegen/client_unary_call.h"
|
||||
#include "grpcpp/impl/codegen/method_handler_impl.h"
|
||||
#include "grpcpp/impl/codegen/rpc_service_method.h"
|
||||
#include "grpcpp/impl/codegen/service_type.h"
|
||||
#include "grpcpp/impl/codegen/sync_stream.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace grpc {
|
||||
|
||||
static const char* grpcVerbsService_method_names[] = {
|
||||
"/tensorflow.VerbsService/GetRemoteAddress",
|
||||
};
|
||||
|
||||
std::unique_ptr<VerbsService::Stub> VerbsService::NewStub(
|
||||
const std::shared_ptr< ::grpc::ChannelInterface>& channel,
|
||||
const ::grpc::StubOptions& options) {
|
||||
std::unique_ptr<VerbsService::Stub> stub(new VerbsService::Stub(channel));
|
||||
return stub;
|
||||
}
|
||||
|
||||
VerbsService::Stub::Stub(
|
||||
const std::shared_ptr< ::grpc::ChannelInterface>& channel)
|
||||
: channel_(channel),
|
||||
rpcmethod_GetRemoteAddress_(grpcVerbsService_method_names[0],
|
||||
::grpc::internal::RpcMethod::NORMAL_RPC,
|
||||
channel) {}
|
||||
|
||||
::grpc::Status VerbsService::Stub::GetRemoteAddress(
|
||||
::grpc::ClientContext* context, const GetRemoteAddressRequest& request,
|
||||
GetRemoteAddressResponse* response) {
|
||||
return ::grpc::internal::BlockingUnaryCall(
|
||||
channel_.get(), rpcmethod_GetRemoteAddress_, context, request, response);
|
||||
}
|
||||
|
||||
VerbsService::AsyncService::AsyncService() {
|
||||
for (int i = 0; i < 1; ++i) {
|
||||
AddMethod(new ::grpc::internal::RpcServiceMethod(
|
||||
grpcVerbsService_method_names[i],
|
||||
::grpc::internal::RpcMethod::NORMAL_RPC, nullptr));
|
||||
::grpc::Service::MarkMethodAsync(i);
|
||||
}
|
||||
}
|
||||
|
||||
VerbsService::AsyncService::~AsyncService() {}
|
||||
|
||||
} // namespace grpc
|
||||
|
||||
} // namespace tensorflow
|
tensorflow/contrib/verbs/grpc_verbs_service_impl.h
@@ -1,81 +0,0 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CONTRIB_VERBS_GRPC_VERBS_SERVICE_IMPL_H_
|
||||
#define TENSORFLOW_CONTRIB_VERBS_GRPC_VERBS_SERVICE_IMPL_H_
|
||||
|
||||
#include "grpcpp/impl/codegen/async_stream.h"
|
||||
#include "grpcpp/impl/codegen/async_unary_call.h"
|
||||
#include "grpcpp/impl/codegen/proto_utils.h"
|
||||
#include "grpcpp/impl/codegen/rpc_method.h"
|
||||
#include "grpcpp/impl/codegen/service_type.h"
|
||||
#include "grpcpp/impl/codegen/status.h"
|
||||
#include "grpcpp/impl/codegen/stub_options.h"
|
||||
#include "grpcpp/impl/codegen/sync_stream.h"
|
||||
|
||||
#include "tensorflow/contrib/verbs/verbs_service.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace grpc {
|
||||
|
||||
// Implementation of `tensorflow.VerbsService`, based on the
|
||||
// definition in "//tensorflow/contrib/verbs/verbs_service.proto",
|
||||
// and the gRPC generated stub and service classes.
|
||||
// See the proto file for the definition of methods and messages.
|
||||
class VerbsService GRPC_FINAL {
|
||||
public:
|
||||
class StubInterface {
|
||||
public:
|
||||
virtual ~StubInterface() {}
|
||||
virtual ::grpc::Status GetRemoteAddress(
|
||||
::grpc::ClientContext* context, const GetRemoteAddressRequest& request,
|
||||
GetRemoteAddressResponse* response) = 0;
|
||||
};
|
||||
class Stub GRPC_FINAL : public StubInterface {
|
||||
public:
|
||||
Stub(const std::shared_ptr< ::grpc::ChannelInterface>& channel);
|
||||
::grpc::Status GetRemoteAddress(
|
||||
::grpc::ClientContext* context, const GetRemoteAddressRequest& request,
|
||||
GetRemoteAddressResponse* response) GRPC_OVERRIDE;
|
||||
|
||||
private:
|
||||
std::shared_ptr< ::grpc::ChannelInterface> channel_;
|
||||
const ::grpc::internal::RpcMethod rpcmethod_GetRemoteAddress_;
|
||||
};
|
||||
static std::unique_ptr<Stub> NewStub(
|
||||
const std::shared_ptr< ::grpc::ChannelInterface>& channel,
|
||||
const ::grpc::StubOptions& options = ::grpc::StubOptions());
|
||||
|
||||
class AsyncService : public ::grpc::Service {
|
||||
public:
|
||||
AsyncService();
|
||||
virtual ~AsyncService();
|
||||
void RequestGetRemoteAddress(
|
||||
::grpc::ServerContext* context, GetRemoteAddressRequest* request,
|
||||
::grpc::ServerAsyncResponseWriter<GetRemoteAddressResponse>* response,
|
||||
::grpc::CompletionQueue* new_call_cq,
|
||||
::grpc::ServerCompletionQueue* notification_cq, void* tag) {
|
||||
::grpc::Service::RequestAsyncUnary(0, context, request, response,
|
||||
new_call_cq, notification_cq, tag);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace grpc
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CONTRIB_VERBS_GRPC_VERBS_SERVICE_IMPL_H_
|
@@ -1,87 +0,0 @@
|
||||
## Verbs implementation to use direct tensor writes (0 copies)
|
||||
|
||||
### Motivation:
|
||||
|
||||
Following HKUST research on the use of GPU direct, and their [GDR implementation](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/gdr/README.md), we wish to adopt the 0 copies approach and apply it to the current verbs implementation, while keeping the current implementation advantages, such as configurability and the use of RDMA for control messages.
|
||||
|
||||
### Performance:
|
||||
|
||||
Compared with the existing GRPC, verbs and GDR implementations, the resulting implementation gave the best performance for every model, with any number of nodes. For VGG16 on 8 nodes with 4 P100 GPUs each, the prototype beat the second-best implementation by over 15%.
|
||||
|
||||
### Implementation requirements:
|
||||
|
||||
1. Tensor writes need to be done directly from the source Tensor to the destination Tensor, with no memory copies in between. This should be done for all DMAble tensors which are located either on CPU or on a RDMA compatible GPU device (GPU direct).
|
||||
2. Non DMAble tensors (CanMemCopy == false) will be serialized to proto on the sender side, RDMA written to a registered buffer on the receiver side, and then deserialized by the receiver.
|
||||
3. Tensors which are located on a non-RDMA-compatible GPU, will be RDMA written to a registered CPU proxy buffer on the receiver side, and then copied to GPU by the receiver.
|
||||
|
||||
### Implementation constraints:

For best stability and proof of correctness, we will divide the implementation into two stages:
1. In the first stage, we will keep changes to the current implementation to the minimum possible. The cost is that we may be left with unused or unnecessary code, which may also affect performance.
2. In the second stage, we will re-iterate over the code and remove the irrelevant parts.
The design of the solution aims to let us achieve both stages with relative ease.
|
||||
|
||||
### Design guidelines:
|
||||
|
||||
1. Since we do not want to do any unnecessary memory copying, we will no longer allocate a fixed CPU buffer as the destination for the RDMA write. Instead we will do the writing directly to the result tensor, or if the result tensor is on a device which does not support RDMA, we will do the writing to a proxy CPU tensor and then copy its content to the result tensor.
|
||||
2. The address of the destination Tensor needs to be sent to the sender side for writing, meaning that the result/proxy tensor should be pre-allocated on the receiver side, prior to sending the tensor request. In order to do that, we need to know its meta-data, i.e. shape and data-type for DMAble tensors, and proto-size for serialized tensors. Unfortunately, this information is only available on the sender side, which complicates matters. In order to avoid sending extra messages for querying the meta-data on each step, we store a local meta-data cache per tensor. Based on the assumption that the meta-data of a tensor rarely changes between steps, we expect that in most cases the cache will only be updated once. When the sender receives a request for a tensor, if it is the first time this tensor is requested, or in the rare case that the meta-data did change, the sender will first send a meta-data response, upon which the receiver will update the local cache and reallocate the result/proxy tensors if required. When the receiver sends the tensor request, it will also contain the meta-data currently stored in its local cache, so the sender can compare it to see if there was a change.
|
||||
3. When the sender writes the tensor content to the result tensor, no additional data is written with it. That means we need to rely on the ibverbs immediate value (uint32_t) to indicate which request we are responding to (in order to trigger the receive callback). The easiest and most elegant way is to key the recv callback with a unique request_index (uint32_t), instead of the current key_with_step_id (string).
|
||||
4. Since the sender no longer writes the tensor from/to fixed buffers, we no longer need to schedule the writes using the local/remote status. In addition, we no longer rely on the RdmaTensorBuffer members as the source/destination addresses and rkey/lkey. Instead, each RdmaTensorBuffer will hold multiple "Response" objects (one per step-id), from which we derive the destination address and rkey. The source address and lkey are always the ones of the source Tensor.
|
||||
5. With the addition of tensor pre-allocation, we noticed there is a large code similarity between sending the first tensor request and re-sending the request in case of meta-data changes. After implementing a common method for tensor pre-allocation, it turned out that implementation becomes much simpler by encapsulating the process of request sending/re-sending, meta-data response callback and content response callback, all in a single "Request" class. The request class holds all the relevant request information, which reduces excessive parameter passing and lambda capturing. This decision is purely for elegance and code simplicity, and we decided to implement it in first stage because it makes the implementation much easier.
|
||||
|
||||
### New types/classes:
|
||||
|
||||
* **enum RdmaImmDataType** - Immediate types to distinguish between different RDMA writes on the remote side. Ack writes and control-message writes have a fixed immediate value. The rest of the writes are tensor writes and the immediate value is the relevant request index.
|
||||
* **enum RdmaWriteIDType** - Types to distinguish between different RDMA write-complete events: Ack, control message, tensor DMA write and tensor proto write.
|
||||
* **class RdmaWriteID** - Context for RDMA write complete events. Holds the RdmaWriteIDType and additional data.
|
||||
* **class RemoteAddressContext** - Remote address information (address + mr). Will be passed as write context for tensor proto writes.
|
||||
* **class RdmaTensorMetaData** - Meta-data for a tensor (type, shape, is_dead, proto_size).
|
||||
* **class RdmaMemoryMgr** - Manages the meta-data cache, and the registered memory regions.
|
||||
* **class RdmaTensorRequest** - Holds and manages information for a single tensor request throughout the entire receive cycle. API:
|
||||
* Start() - Start the request.
|
||||
* RecvTensorMetaData() - Receive meta-data from the remote side.
|
||||
* RecvTensorContent() - Receive tensor content from the remote side and invoke the done() callback.
|
||||
* **class RdmaTensorResponse** - Holds information for a single tensor response, such as destination address and rkey.
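
A header-style sketch of the two classes as described by the API above; the names are suffixed with `Sketch` to make clear these are illustrative declarations, not the removed ones from rdma.h.

```c++
// Illustrative skeleton only; declarations without definitions.
#include <cstdint>
#include <functional>

class RdmaTensorRequestSketch {
 public:
  using DoneCallback = std::function<void()>;

  // Allocate the result/proxy tensor (when possible) and send the request.
  void Start();
  // Meta-data arrived: update the cache, reallocate, and re-send the request.
  void RecvTensorMetaData();
  // Tensor content arrived via RDMA write: deserialize/proxy-copy if needed,
  // then invoke the done() callback.
  void RecvTensorContent();

 private:
  uint32_t request_index_ = 0;
  DoneCallback done_;
};

class RdmaTensorResponseSketch {
 public:
  // Destination-side information for the direct write.
  uint64_t remote_addr = 0;
  uint32_t rkey = 0;
};
```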
|
||||
|
||||
### Protocol changes:
|
||||
|
||||
The protocol messages themselves will remain mostly unchanged in the first stage, but will be used differently, as described below. The current message structures already have most of the required fields for the new implementation. The only change is the "buffer_size" field, which is no longer used since we are no longer sending additional information with the tensor, and thus it is now always equal to the "tensor_bytes" field. Instead, we use that field to pass the "request_index".
|
||||
|
||||
### Message structure:
|
||||
|
||||
| type | name_size | name | step_id | request_index | remote_addr | rkey | is_dead | data_type | tensor_shape | tensor_bytes |
|
||||
|------|---------- |------|---------|---------------|-------------|------|---------|-----------|--------------|--------------|
|
||||
| 1B | 2B | 512 | 8B | 8B | 8B | 4B | 1B | XB | XB | 8B |
|
||||
|
||||
* **RDMA_MESSAGE_TENSOR_REQUEST** - (receiver ==> sender) The original tensor request.
|
||||
* type - The message type.
|
||||
* name (name_size) - Name of the requested tensor.
|
||||
* step_id - Step ID.
|
||||
* request_index - Request index.
|
||||
* remote_addr/rkey - Address/rkey of the result/proxy tensor. Irrelevant for first-time request.
|
||||
* is_dead/data_type/tensor_shape/tensor_bytes - The current meta-data as stored in the receiver local cache. The sender will use that information to know if the receiver's cache requires updating.
|
||||
* **RDMA_MESSAGE_BUFFER_REQUEST** - (sender ==> receiver) The meta-data update message in case meta-data had changed (or if it is the first time the tensor is requested).
|
||||
* type - The message type.
|
||||
* request_index - Request index.
|
||||
* is_dead/data_type/tensor_shape/tensor_bytes - The up-to-date meta-data.
|
||||
* **RDMA_MESSAGE_BUFFER_RESPONSE** - (receiver ==> sender) Tensor re-request after meta-data update and reallocation of result/proxy tensors.
|
||||
* type - The message type.
|
||||
* name (name_size) - Name of the requested tensor.
|
||||
* step_id - Step ID.
|
||||
* request_index - Request index.
|
||||
* remote_addr/rkey - Address/rkey of the reallocated result/proxy tensor.
|
||||
* is_dead/data_type/tensor_shape/tensor_bytes - The new meta-data. Will be removed in the next phase.
|
||||
* **RDMA_MESSAGE_TENSOR_WRITE** - (sender ==> receiver) No longer sent. There is only a direct write of the tensor content to the result/proxy tensor. Request index passed as the immediate value of the write.
|
||||
* **RDMA_MESSAGE_TENSOR_IDLE** - (receiver ==> sender) No longer sent.
|
||||
|
||||

|
||||
|
||||
### Second stage optimizations:
|
||||
1. Remove unused code leftovers.
|
||||
2. Remove the ACK buffer completely, since we can rely completely on its immediate value.
|
||||
|
||||
### Future optimizations:
|
||||
1. Map the tensor names to indexes, to significantly reduce the request message size.
|
||||
2. Understand the purpose of empty tensors and if we can skip remote fetching for them.
|
||||
3. Consider concatenating multiple requests and/or using multiple message buffers.
|
||||
4. Consider a no-request architecture.
|
File diff suppressed because it is too large
tensorflow/contrib/verbs/rdma.h
@@ -1,527 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CONTRIB_VERBS_RDMA_H_
|
||||
#define TENSORFLOW_CONTRIB_VERBS_RDMA_H_
|
||||
|
||||
#ifdef TENSORFLOW_USE_VERBS
|
||||
|
||||
#include <infiniband/verbs.h>
|
||||
#include <cstring> // for memset
|
||||
#include <functional>
|
||||
#include <memory> // for shared_ptr
|
||||
#include <queue>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/contrib/verbs/verbs_util.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_env.h"
|
||||
#include "tensorflow/core/framework/rendezvous.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
|
||||
namespace tensorflow {
|
||||
#define PKEY_DEFAULT 0
|
||||
#define QUEUE_DEPTH_DEFAULT 1024
|
||||
#define TIMEOUT_DEFAULT 14
|
||||
#define RETRY_CNT_DEFAULT 7
|
||||
#define SL_DEFAULT 0
|
||||
#define TRAFFIC_CLASS 0
|
||||
|
||||
#define RDMA_LOG_0 LOG(INFO)
|
||||
#define RDMA_LOG_1 VLOG(1)
|
||||
#define RDMA_LOG_2 VLOG(2)
|
||||
#define RDMA_LOG(LEVEL) RDMA_LOG_##LEVEL
|
||||
|
||||
struct RdmaParams {
|
||||
uint8_t port_num;
|
||||
uint8_t sgid_index;
|
||||
uint8_t pkey_index;
|
||||
uint32_t queue_depth;
|
||||
uint8_t timeout;
|
||||
uint8_t retry_cnt;
|
||||
uint8_t sl;
|
||||
enum ibv_mtu mtu;
|
||||
uint8_t traffic_class;
|
||||
};
|
||||
// structure to save the address of remote channels.
|
||||
struct RdmaAddress {
|
||||
uint32_t lid;
|
||||
uint32_t qpn;
|
||||
uint32_t psn;
|
||||
uint64_t snp;
|
||||
uint64_t iid;
|
||||
};
|
||||
// structure to save information for remote memory regions.
|
||||
struct RemoteMR {
|
||||
uint64_t remote_addr;
|
||||
uint32_t rkey;
|
||||
};
|
||||
enum BufferStatus { none, idle, busy };
|
||||
enum Location { local, remote };
|
||||
|
||||
enum RdmaMessageType {
|
||||
RDMA_MESSAGE_META_DATA_UPDATE,
|
||||
RDMA_MESSAGE_TENSOR_RE_REQUEST,
|
||||
RDMA_MESSAGE_TENSOR_REQUEST,
|
||||
RDMA_MESSAGE_ERROR_STATUS,
|
||||
};
|
||||
|
||||
struct RdmaMessage {
|
||||
RdmaMessageType type_;
|
||||
uint16_t name_size_;
|
||||
string name_;
|
||||
int64 step_id_;
|
||||
uint64_t request_index_;
|
||||
union {
|
||||
uint64_t remote_addr_;
|
||||
#ifdef RDMA_DATA_VALIDATION
|
||||
uint64_t checksum_;
|
||||
#endif
|
||||
};
|
||||
uint32_t rkey_;
|
||||
bool is_dead_;
|
||||
DataType data_type_;
|
||||
TensorShape tensor_shape_;
|
||||
size_t tensor_bytes_;
|
||||
|
||||
// For error status:
|
||||
Status status_;
|
||||
|
||||
// type|name_size|name|step_id|request_index|remote_addr/checksum|rkey|...
|
||||
// 1B| 2B | 512| 8B | 8B | 8B | 4B |...
|
||||
// ...|is_dead|data_type|tensor_shape|tensor_bytes|error_status |
|
||||
// ...| 1B | XB | XB | 8B |size - 4B, proto - XB |
|
||||
static const size_t kNameCapacity = 512;
|
||||
static const size_t kTypeStartIndex = 0;
|
||||
static const size_t kNameSizeStartIndex = kTypeStartIndex + sizeof(type_);
|
||||
static const size_t kNameStartIndex =
|
||||
kNameSizeStartIndex + sizeof(name_size_);
|
||||
static const size_t kStepIdStartIndex = kNameStartIndex + kNameCapacity;
|
||||
static const size_t kRequestIndexStartIndex =
|
||||
kStepIdStartIndex + sizeof(step_id_);
|
||||
static const size_t kRemoteAddrStartIndex =
|
||||
kRequestIndexStartIndex + sizeof(request_index_);
|
||||
static const size_t kChecksumStartIndex = kRemoteAddrStartIndex;
|
||||
static const size_t kRkeyStartIndex =
|
||||
kRemoteAddrStartIndex + sizeof(remote_addr_);
|
||||
static const size_t kIsDeadStartIndex = kRkeyStartIndex + sizeof(rkey_);
|
||||
static const size_t kDataTypeStartIndex =
|
||||
kIsDeadStartIndex + sizeof(is_dead_);
|
||||
static const size_t kTensorShapeStartIndex =
|
||||
kDataTypeStartIndex + sizeof(data_type_);
|
||||
static const size_t kTensorBytesStartIndex =
|
||||
kTensorShapeStartIndex + sizeof(TensorShape);
|
||||
static const size_t kErrorStatusStartIndex =
|
||||
kTensorBytesStartIndex + sizeof(tensor_bytes_);
|
||||
static const size_t kErrorStatusMaxSize = 4096;
|
||||
|
||||
static const size_t kMessageTotalBytes = kErrorStatusStartIndex;
|
||||
static const size_t kRdmaMessageBufferSize =
|
||||
kMessageTotalBytes + kErrorStatusMaxSize;
|
||||
static string CreateMessage(const RdmaMessage& rm);
|
||||
static void ParseMessage(RdmaMessage& rm, void* buffer);
|
||||
};
|
||||
|
||||
// Immediate types for RDMA write
|
||||
enum RdmaImmDataType {
|
||||
RDMA_IMM_MAX_REQUEST_ID = 0xFFFFFFFD,
|
||||
RDMA_IMM_DATA_ACK = 0xFFFFFFFE,
|
||||
RDMA_IMM_DATA_MESSAGE = 0xFFFFFFFF
|
||||
};
|
||||
|
||||
// Write types for RDMA write-complete events
|
||||
enum RdmaWriteIDType {
|
||||
RDMA_WRITE_ID_ACK,
|
||||
RDMA_WRITE_ID_MESSAGE,
|
||||
RDMA_WRITE_ID_TENSOR_WRITE
|
||||
};
|
||||
|
||||
// Context for RDMA write-complete events
|
||||
class RdmaWriteID {
|
||||
public:
|
||||
RdmaWriteID(RdmaWriteIDType write_type, void* write_context)
|
||||
: write_type(write_type), write_context(write_context) {}
|
||||
|
||||
RdmaWriteIDType write_type;
|
||||
void* write_context;
|
||||
};
|
||||
|
||||
// Tensor meta-data
|
||||
class TensorMetaData {
|
||||
public:
|
||||
TensorShape tensor_shape_;
|
||||
DataType data_type_;
|
||||
size_t proto_size_;
|
||||
bool is_dead_;
|
||||
|
||||
std::ostream& print(std::ostream& out) const {
|
||||
out << "Dtype = " << DataTypeString(data_type_)
|
||||
<< ", Shape = " << tensor_shape_.DebugString() << ", Proto size = 0x"
|
||||
<< std::hex << proto_size_ << ", Is dead = " << is_dead_;
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& out,
|
||||
const TensorMetaData& meta_data) {
|
||||
return meta_data.print(out);
|
||||
}
|
||||
|
||||
class RdmaChannel;
|
||||
|
||||
void MRDeleter(ibv_mr* mr);
|
||||
using MemoryRegionPtr = std::unique_ptr<ibv_mr, decltype(&MRDeleter)>;
|
||||
|
||||
// RdmaMemoryMgr
|
||||
// Manages the local meta-data cache, and the registered RDMA memory regions.
|
||||
class RdmaMemoryMgr {
|
||||
public:
|
||||
static RdmaMemoryMgr& Singleton() {
|
||||
static RdmaMemoryMgr instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
// Memory regions
|
||||
ibv_mr* FindMemoryRegion(void* addr, size_t length);
|
||||
void InsertMemoryRegion(void* addr, size_t length,
|
||||
const std::string& allocator_name);
|
||||
void EvictMemoryRegion(void* addr, size_t length);
|
||||
|
||||
// Tensor meta-data cache
|
||||
const TensorMetaData* GetTensorMetaData(const std::string& tensor_name);
|
||||
const TensorMetaData* SetTensorMetaData(const std::string& tensor_name,
|
||||
DataType dtype,
|
||||
const TensorShape& shape,
|
||||
bool is_dead, size_t proto_size);
|
||||
|
||||
struct ibv_pd* pd_;
|
||||
|
||||
protected:
|
||||
RdmaMemoryMgr() : pd_(nullptr) {}
|
||||
|
||||
static bool Comparator(const void* ptr, const MemoryRegionPtr& other) {
|
||||
return ptr < reinterpret_cast<char*>(other->addr) + other->length;
|
||||
}
|
||||
|
||||
private:
|
||||
mutex tensor_meta_data_mu_;
|
||||
std::unordered_map<std::string, TensorMetaData> tensors_meta_data_;
|
||||
|
||||
// Managed memory regions
|
||||
mutex mrs_mu_;
|
||||
std::vector<MemoryRegionPtr> mrs_ GUARDED_BY(mrs_mu_);
|
||||
};
|
||||
|
||||
// RdmaTensorRequest
|
||||
// Represents a single tensor request.
|
||||
class RdmaTensorRequest {
|
||||
public:
|
||||
typedef Rendezvous::DoneCallback RecvDoneCallback;
|
||||
|
||||
// Creates a tensor request identified by index.
|
||||
RdmaTensorRequest(uint32_t index, const string& key, int64 step_id,
|
||||
RdmaChannel* channel, Device* dst_dev,
|
||||
const Rendezvous::Args recv_args,
|
||||
const RecvDoneCallback& done);
|
||||
~RdmaTensorRequest();
|
||||
|
||||
// Request unique index.
|
||||
uint32_t index() { return index_; }
|
||||
|
||||
// Start the tensor request sequence.
|
||||
//
|
||||
// 1. Allocate the result tensor (and proxy tensor if required).
|
||||
// 2. Send RDMA_MESSAGE_TENSOR_REQUEST to the remote side.
|
||||
void Start();
|
||||
|
||||
// Receive tensor meta-data.
|
||||
//
|
||||
// 1. Update the local meta-data cache.
|
||||
// 2. Reallocate the result tensor (and proxy tensor if required).
|
||||
// 3. Re-send the request to the remote side.
|
||||
void RecvTensorMetaData(DataType dtype, TensorShape shape, bool is_dead,
|
||||
size_t proto_size);
|
||||
|
||||
// Receive tensor content (RDMA write was completed).
|
||||
//
|
||||
// Decode proto if required and/or move to GPU if the content was not
|
||||
// written to it directly (GPU direct is not available). Afterwards,
|
||||
// invoke Done().
|
||||
void RecvTensorContent();
|
||||
|
||||
// Receive error status (in case of a remote error).
|
||||
// Invoke Done() with the status code.
|
||||
void RecvErrorStatus(const Status& status);
|
||||
|
||||
#ifdef RDMA_DATA_VALIDATION
|
||||
// Receive tensor checksum
|
||||
//
|
||||
// For validation: Get and store the Tensor's expected checksum for the
|
||||
// current request. Compare the result Tensor's checksum with the stored
|
||||
// checksum right before invoking Done().
|
||||
void RecvTensorChecksum(uint64_t checksum) { checksum_ = checksum; }
|
||||
#endif
|
||||
|
||||
private:
|
||||
void Done(const Status& s);
|
||||
void Send(RdmaMessageType message_type);
|
||||
bool AllocateTensors();
|
||||
void AllocateTensorsAsync(StatusCallback done);
|
||||
void DeallocateTensors();
|
||||
|
||||
uint32_t index_;
|
||||
string key_;
|
||||
int64 step_id_;
|
||||
RdmaChannel* channel_;
|
||||
Device* dst_dev_;
|
||||
Rendezvous::Args recv_args_;
|
||||
const TensorMetaData* meta_data_;
|
||||
Tensor* result_tensor_;
|
||||
Tensor* proxy_tensor_;
|
||||
void* rdma_addr_;
|
||||
ibv_mr* mr_;
|
||||
RecvDoneCallback done_;
|
||||
#ifdef RDMA_DATA_VALIDATION
|
||||
uint64_t checksum_;
|
||||
#endif
|
||||
};
|
||||
|
||||
// RdmaTensorResponse
|
||||
// Represents a single tensor response.
|
||||
class RdmaTensorResponse {
|
||||
public:
|
||||
// Creates a response for request message.
|
||||
RdmaTensorResponse(RdmaChannel* channel, const RdmaMessage& rm)
|
||||
: channel_(channel), rm_(rm) {}
|
||||
|
||||
void Update(const RdmaMessage& rm) { rm_ = rm; }
|
||||
|
||||
// Start the tensor response sequence.
|
||||
//
|
||||
// 1. Find the tensor in the local tag-match table and invoke RecvHandler.
|
||||
// (Using RecvLocalAsync()).
|
||||
// 2. Compare the tensor's meta-data to the meta-data in the message (taken
|
||||
// from the requester's local cache).
|
||||
// If meta-data changed:
|
||||
// a. Clone the tensor to be sent later.
|
||||
// b. Send a meta-data update message and wait for re-request.
|
||||
// Else:
|
||||
// a. Send the tensor's content (using direct RDMA write).
|
||||
void Start();
|
||||
|
||||
// Resume the response sequence, after a re-request.
|
||||
//
|
||||
// 1. Send the tensor's content that was cloned earlier.
|
||||
void Resume();
|
||||
|
||||
// Destroy the response's resources and remove it from the pending list.
|
||||
void Destroy();
|
||||
|
||||
private:
|
||||
void RecvHandler(Rendezvous::ParsedKey parsed,
|
||||
const Rendezvous::Args& send_args,
|
||||
const Rendezvous::Args& recv_args, const Tensor& in,
|
||||
bool is_dead);
|
||||
void Clone(const Tensor& in, const TensorProto& proto, bool is_dead);
|
||||
void Send(const Tensor& in, const TensorProto& proto, bool is_dead,
|
||||
const Status& status);
|
||||
bool TensorMetaDataChanged(const Tensor& in, bool is_dead);
|
||||
Status PrepareRecvTensor(const Rendezvous::ParsedKey& parsed,
|
||||
Device** src_dev);
|
||||
void SendMetaData(const Tensor& in, const TensorProto& proto, bool is_dead);
|
||||
void SendContent(const Tensor& in, const TensorProto& proto, bool is_dead);
|
||||
void SendErrorStatus(const Status& status);
|
||||
|
||||
RdmaChannel* channel_;
|
||||
RdmaMessage rm_; // The request message
|
||||
Device* src_dev_ = nullptr;
|
||||
TensorBuffer* src_buffer_ = nullptr;
|
||||
void* src_addr_ = nullptr;
|
||||
ibv_mr* mr_ = nullptr;
|
||||
uint64_t checksum_ = 0;
|
||||
bool meta_data_changed_ = false;
|
||||
|
||||
// Re-item:
|
||||
TensorProto* proto_ = nullptr;
|
||||
Tensor* tensor_ = nullptr;
|
||||
bool is_dead_ = false;
|
||||
};
|
||||
|
||||
class RdmaMessageBuffer;
|
||||
// Class that represents the Rdma Adapter.
|
||||
// Responsible for creation of the completion queue, and handling
|
||||
// of work completions.
|
||||
class RdmaAdapter {
|
||||
friend class RdmaChannel;
|
||||
friend class RdmaMessageBuffer;
|
||||
friend class RdmaTensorResponse;
|
||||
friend class RdmaMgr;
|
||||
friend class RdmaRemoteRendezvous;
|
||||
|
||||
public:
|
||||
RdmaAdapter(const WorkerEnv* worker_env);
|
||||
~RdmaAdapter();
|
||||
// Adapter name, e.g. mlx5_0.
|
||||
string name() const;
|
||||
void StartPolling();
|
||||
void Process_CQ();
|
||||
|
||||
protected:
|
||||
static const int MAX_CONCURRENT_WRITES = 1000;
|
||||
ibv_context* context_;
|
||||
// RDMA configuration parameters
|
||||
RdmaParams params_;
|
||||
// ibverbs protection domain
|
||||
ibv_pd* pd_;
|
||||
// Completion event channel, to wait for work completions
|
||||
ibv_comp_channel* event_channel_;
|
||||
// Completion queue, to poll on work completions
|
||||
ibv_cq* cq_;
|
||||
// Pre-allocated work completions array used for polling
|
||||
ibv_wc wc_[MAX_CONCURRENT_WRITES * 2];
|
||||
// worker env for thread
|
||||
const WorkerEnv* worker_env_;
|
||||
// thread for cq.
|
||||
std::unique_ptr<Thread> polling_thread_;
|
||||
};
|
||||
|
||||
// Class that represents a connection to a remote Rdma peer.
|
||||
// Responsible for connecting queue pairs.
|
||||
class RdmaChannel {
|
||||
friend class RdmaAdapter;
|
||||
friend class RdmaMessageBuffer;
|
||||
friend class RdmaTensorBuffer;
|
||||
friend class RdmaTensorRequest;
|
||||
friend class RdmaTensorResponse;
|
||||
friend class RdmaMgr;
|
||||
friend class RdmaRemoteRendezvous;
|
||||
|
||||
public:
|
||||
explicit RdmaChannel(const RdmaAdapter* adapter, const string local_name,
|
||||
const string remote_name_);
|
||||
~RdmaChannel();
|
||||
inline const RdmaAddress& self() { return self_; }
|
||||
RdmaAddress address() const;
|
||||
inline const std::vector<RdmaMessageBuffer*>& message_buffers() const {
|
||||
return message_buffers_;
|
||||
}
|
||||
void Connect(const RdmaAddress& remoteAddr);
|
||||
void Connect();
|
||||
void Recv();
|
||||
void SetRemoteAddress(const RdmaAddress& ra, bool override);
|
||||
|
||||
// Requests:
|
||||
RdmaTensorRequest* InsertTensorRequest(
|
||||
const string& key, int64 step_id, Device* dst_dev,
|
||||
const Rendezvous::Args recv_args,
|
||||
const RdmaTensorRequest::RecvDoneCallback& done);
|
||||
void RemoveTensorRequest(uint32_t request_index);
|
||||
RdmaTensorRequest* GetTensorRequest(uint32_t request_index);
|
||||
|
||||
// Responses:
|
||||
RdmaTensorResponse* AddTensorResponse(const RdmaMessage& rm);
|
||||
RdmaTensorResponse* UpdateTensorResponse(const RdmaMessage& rm);
|
||||
void RemoveTensorResponse(uint32_t request_index);
|
||||
|
||||
static const int kNumMessageBuffers = 2;
|
||||
static const int kPingRecvWrid = 0;
|
||||
|
||||
private:
|
||||
static const int kPingBuffSize = 1024;
|
||||
char ping_buff_[kPingBuffSize];
|
||||
struct ibv_mr* mr_;
|
||||
struct ibv_sge ping_sge_list_;
|
||||
int PingPostRecv();
|
||||
int PingPostSend();
|
||||
|
||||
protected:
|
||||
const RdmaAdapter* adapter_;
|
||||
RdmaAddress self_;
|
||||
string local_name_;
|
||||
string remote_name_;
|
||||
ibv_qp* qp_;
|
||||
mutex mu_;
|
||||
bool connected_ GUARDED_BY(mu_) = false;
|
||||
RdmaAddress remote_ GUARDED_BY(mu_);
|
||||
bool remote_set_ GUARDED_BY(mu_) = false;
|
||||
mutex ct_mu_;
|
||||
typedef std::unordered_map<uint32_t, RdmaTensorRequest> RequestTable;
|
||||
RequestTable request_table_ GUARDED_BY(ct_mu_);
|
||||
uint32_t request_serial_ GUARDED_BY(ct_mu_);
|
||||
mutex responses_mu_;
|
||||
typedef std::unordered_map<uint32_t, RdmaTensorResponse> ResponsesTable;
|
||||
ResponsesTable responses_table_ GUARDED_BY(responses_mu_);
|
||||
RdmaMessageBuffer* tx_message_buffer_;
|
||||
RdmaMessageBuffer* rx_message_buffer_;
|
||||
std::vector<RdmaMessageBuffer*> message_buffers_;
|
||||
};
|
||||
|
||||
// Class that represents a buffer for Rdma message sending.
|
||||
class RdmaMessageBuffer {
|
||||
friend class RdmaChannel;
|
||||
friend class RdmaAdapter;
|
||||
friend class RdmaMgr;
|
||||
friend class RdmaRemoteRendezvous;
|
||||
|
||||
public:
|
||||
explicit RdmaMessageBuffer(RdmaChannel* channel, string name);
|
||||
~RdmaMessageBuffer();
|
||||
|
||||
inline void* buffer() const { return buffer_; }
|
||||
inline ibv_mr* self() const { return self_; }
|
||||
inline void SetBufferStatus(Location loc, BufferStatus status) {
|
||||
mu_.lock();
|
||||
if (loc == local) {
|
||||
local_status_ = status;
|
||||
} else {
|
||||
remote_status_ = status;
|
||||
}
|
||||
mu_.unlock();
|
||||
}
|
||||
void FreeBuffer();
|
||||
void EnqueueItem(string Item);
|
||||
void SendNextItem();
|
||||
void CreateCPUBuffer(size_t size, bool lock = true);
|
||||
void SetRemoteMR(RemoteMR rmi, bool override);
|
||||
void Write(uint32_t imm_data, size_t buffer_size);
|
||||
static void Write(const RdmaChannel* channel, uint32_t imm_data,
|
||||
size_t buffer_size, uint64_t src_addr, uint32_t lkey,
|
||||
uint64_t remote_addr, uint32_t rkey,
|
||||
RdmaWriteIDType write_type, void* write_context);
|
||||
static void SendAck(const RdmaChannel* channel);
|
||||
|
||||
protected:
|
||||
const RdmaChannel* channel_;
|
||||
void* buffer_ = nullptr;
|
||||
bool buffer_on_host_ = true;
|
||||
size_t size_ = 0;
|
||||
const string name_;
|
||||
ibv_mr* self_ = nullptr;
|
||||
mutex mu_;
|
||||
RemoteMR remote_;
|
||||
std::queue<string> queue_ GUARDED_BY(mu_);
|
||||
BufferStatus local_status_ GUARDED_BY(mu_) = none;
|
||||
BufferStatus remote_status_ GUARDED_BY(mu_) = none;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_USE_VERBS
|
||||
#endif // TENSORFLOW_CONTRIB_VERBS_RDMA_H_
|
@ -1,307 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifdef TENSORFLOW_USE_VERBS
|
||||
|
||||
#include "tensorflow/contrib/verbs/rdma_mgr.h"
|
||||
#include <fstream>
|
||||
#include <vector>
|
||||
#include "tensorflow/contrib/verbs/grpc_verbs_client.h"
|
||||
#include "tensorflow/contrib/verbs/verbs_service.pb.h"
|
||||
#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h"
|
||||
#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
|
||||
#include "tensorflow/core/common_runtime/pool_allocator.h"
|
||||
#include "tensorflow/core/common_runtime/process_state.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
|
||||
#include "tensorflow/core/distributed_runtime/session_mgr.h"
|
||||
#include "tensorflow/core/framework/allocator_registry.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
RdmaMgr::RdmaMgr(const WorkerEnv* const worker_env,
|
||||
GrpcChannelCache* const channel_cache)
|
||||
: worker_env_(worker_env), channel_cache_(channel_cache) {
|
||||
rdma_adapter_ = new RdmaAdapter(worker_env_);
|
||||
// hardcoded to default session (legacy_session_)
|
||||
// TODO: use WorkerSessionForSession
|
||||
// need to pass in session handle
|
||||
local_worker_ = worker_env_->session_mgr->LegacySession()->worker_name;
|
||||
std::vector<string> workers;
|
||||
worker_env_->session_mgr->LegacySession()->worker_cache->ListWorkers(
|
||||
&workers);
|
||||
num_remote_workers_ = workers.size() - 1;
|
||||
VLOG(2) << "rmda_mgr on local worker: " << local_worker_;
|
||||
for (size_t i = 0; i < workers.size(); i++) {
|
||||
if (local_worker_.compare(workers[i]) != 0) {
|
||||
channel_table_.insert(
|
||||
{workers[i],
|
||||
new RdmaChannel(rdma_adapter_, local_worker_, workers[i])});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Setup Rdma channels between peers.
|
||||
// This is done at the beginning of the server setup.
|
||||
|
||||
void RdmaMgr::SetupChannels() {
|
||||
for (const auto& p : channel_table_) {
|
||||
string worker_name = p.first;
|
||||
RDMA_LOG(2) << "Connecting to remote node " << worker_name;
|
||||
RdmaChannel* rc = p.second;
|
||||
GetRemoteAddressRequest req;
|
||||
GetRemoteAddressResponse resp;
|
||||
// get the channel cache
|
||||
SharedGrpcChannelPtr client_channel =
|
||||
channel_cache_->FindWorkerChannel(worker_name);
|
||||
GrpcVerbsClient* client = new GrpcVerbsClient(client_channel);
|
||||
CHECK(client != nullptr) << "No worker known as " << worker_name;
|
||||
|
||||
// setting up request
|
||||
req.set_host_name(local_worker_);
|
||||
Channel* channel_info = req.mutable_channel();
|
||||
channel_info->set_lid(rc->self_.lid);
|
||||
channel_info->set_qpn(rc->self_.qpn);
|
||||
channel_info->set_psn(rc->self_.psn);
|
||||
channel_info->set_snp(rc->self_.snp);
|
||||
channel_info->set_iid(rc->self_.iid);
|
||||
for (int i = 0; i < RdmaChannel::kNumMessageBuffers; i++) {
|
||||
MemoryRegion* mr = req.add_mr();
|
||||
mr->set_remote_addr(
|
||||
reinterpret_cast<uint64_t>(rc->message_buffers_[i]->buffer_));
|
||||
mr->set_rkey(rc->message_buffers_[i]->self_->rkey);
|
||||
}
|
||||
// synchronous call
|
||||
Status s;
|
||||
int attempts = 0;
|
||||
static const int max_num_attempts = 5;
|
||||
do {
|
||||
s = client->GetRemoteAddress(&req, &resp);
|
||||
// save obtained remote addresses
|
||||
// connect to the remote channel
|
||||
if (s.ok()) {
|
||||
CHECK(worker_name.compare(resp.host_name()) == 0);
|
||||
RdmaAddress ra;
|
||||
ra.lid = resp.channel().lid();
|
||||
ra.qpn = resp.channel().qpn();
|
||||
ra.psn = resp.channel().psn();
|
||||
ra.snp = resp.channel().snp();
|
||||
ra.iid = resp.channel().iid();
|
||||
rc->SetRemoteAddress(ra, false);
|
||||
rc->Connect();
|
||||
int i = 0;
|
||||
int idx[] = {1, 0};
|
||||
for (const auto& mr : resp.mr()) {
|
||||
// the connections are crossed, i.e.
|
||||
// local tx_message_buffer <---> remote rx_message_buffer_
|
||||
// local rx_message_buffer <---> remote tx_message_buffer_
|
||||
// hence idx[] = {1, 0}.
|
||||
RdmaMessageBuffer* rb = rc->message_buffers_[idx[i]];
|
||||
RemoteMR rmr;
|
||||
rmr.remote_addr = mr.remote_addr();
|
||||
rmr.rkey = mr.rkey();
|
||||
rb->SetRemoteMR(rmr, false);
|
||||
i++;
|
||||
}
|
||||
CHECK(i == RdmaChannel::kNumMessageBuffers);
|
||||
} else {
|
||||
LOG(ERROR) << "Connecting to " << worker_name << ": Got "
|
||||
<< s.error_message() << ". Retrying (" << (attempts + 1)
|
||||
<< "/" << max_num_attempts << ")...";
|
||||
if (++attempts == max_num_attempts) {
|
||||
break;
|
||||
}
|
||||
worker_env_->env->SleepForMicroseconds(2000000);
|
||||
}
|
||||
} while (!s.ok());
|
||||
RDMA_LOG(0) << "Connected to remote node " << worker_name;
|
||||
delete client;
|
||||
}
|
||||
}
|
||||
|
||||
// Check connectivity by pinging every channel
|
||||
bool RdmaMgr::ConnectivityCheck() {
|
||||
int i, rcnt = 0, scnt = 0;
|
||||
|
||||
for (const auto& p : channel_table_) {
|
||||
string worker_name = p.first;
|
||||
RdmaChannel* rc = p.second;
|
||||
|
||||
VLOG(2) << "Ping to " << worker_name;
|
||||
CHECK(rc->PingPostSend() == 0) << "Couldn't post send to " << worker_name
|
||||
<< " with error: " << std::strerror(errno);
|
||||
for (i = 0; i < rc->adapter_->params_.queue_depth - 1; i++) {
|
||||
rc->Recv();
|
||||
}
|
||||
}
|
||||
|
||||
while (rcnt < num_remote_workers_ || scnt < num_remote_workers_) {
|
||||
int ne;
|
||||
do {
|
||||
ne = ibv_poll_cq(rdma_adapter_->cq_, 2 * num_remote_workers_,
|
||||
rdma_adapter_->wc_);
|
||||
CHECK(ne >= 0) << "poll CQ failed " << ne << "with error"
|
||||
<< std::strerror(errno);
|
||||
} while (ne < 1);
|
||||
|
||||
for (i = 0; i < ne; ++i) {
|
||||
ibv_wc_status s = rdma_adapter_->wc_[i].status;
|
||||
// recv complete
|
||||
if ((int)rdma_adapter_->wc_[i].wr_id == RdmaChannel::kPingRecvWrid) {
|
||||
CHECK(s == IBV_WC_SUCCESS)
|
||||
<< ": " << ibv_wc_status_str(rdma_adapter_->wc_[i].status) << "("
|
||||
<< rdma_adapter_->wc_[i].status << ") for PING_RECV_WRID";
|
||||
++rcnt;
|
||||
// send complete
|
||||
} else {
|
||||
RdmaChannel* rc =
|
||||
reinterpret_cast<RdmaChannel*>(rdma_adapter_->wc_[i].wr_id);
|
||||
CHECK(s == IBV_WC_SUCCESS)
|
||||
<< ": " << ibv_wc_status_str(rdma_adapter_->wc_[i].status) << "("
|
||||
<< rdma_adapter_->wc_[i].status << ") to " << rc->remote_name_;
|
||||
++scnt;
|
||||
}
|
||||
} // for
|
||||
} // while
|
||||
CHECK(rcnt == scnt) << "Connectivity check failed!";
|
||||
rdma_adapter_->StartPolling();
|
||||
return (num_remote_workers_ == rcnt) && (num_remote_workers_ == scnt);
|
||||
}
|
||||
|
||||
RdmaMgr::~RdmaMgr() {
|
||||
for (const auto& p : channel_table_) delete p.second;
|
||||
channel_table_.clear();
|
||||
delete rdma_adapter_;
|
||||
}
|
||||
|
||||
// Find a channel via the given name.
|
||||
// Args:
|
||||
// name: peer name, e.g. worker1
|
||||
// Returns
|
||||
// channel object that is connected to the named peer.
|
||||
RdmaChannel* RdmaMgr::FindChannel(const string& name) {
|
||||
ChannelTable::iterator iter = channel_table_.find(name);
|
||||
CHECK(iter != channel_table_.end());
|
||||
return iter->second;
|
||||
}
|
||||
|
||||
bool IsGDRAvailable() {
|
||||
#if defined(__APPLE__)
|
||||
return false;
|
||||
#elif defined(PLATFORM_WINDOWS)
|
||||
return false;
|
||||
#else
|
||||
std::ifstream ifs("/proc/modules");
|
||||
string line;
|
||||
while (std::getline(ifs, line)) {
|
||||
auto sep = line.find(' ');
|
||||
CHECK_NE(sep, std::string::npos);
|
||||
if (line.substr(0, sep) == "nv_peer_mem") {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
#endif
|
||||
}
|
||||
|
||||
int TryToReadNumaNode(ibv_device* device) {
|
||||
#if defined(__APPLE__)
|
||||
LOG(INFO) << "OS X does not support NUMA - returning NUMA node 0";
|
||||
return 0;
|
||||
#elif defined(PLATFORM_WINDOWS)
|
||||
// Windows support for NUMA is not currently implemented. Return node 0.
|
||||
return 0;
|
||||
#else
|
||||
VLOG(2) << "Trying to read NUMA node for device: " << device->name;
|
||||
static const int kUnknownNumaNode = -1;
|
||||
|
||||
auto filename = string(device->ibdev_path) + "/device/numa_node";
|
||||
|
||||
std::ifstream ifs(filename.c_str());
|
||||
string content;
|
||||
CHECK(std::getline(ifs, content));
|
||||
|
||||
int32 value;
|
||||
if (strings::safe_strto32(content, &value)) {
|
||||
if (value < 0) {
|
||||
LOG(INFO) << "Successful NUMA node read from SysFS had negative value ("
|
||||
<< value
|
||||
<< "), but there must be at least one NUMA node"
|
||||
", so returning NUMA node zero";
|
||||
return 0;
|
||||
}
|
||||
LOG(INFO) << "NUMA node for device: " << device->name << " is " << value;
|
||||
return value;
|
||||
}
|
||||
return kUnknownNumaNode;
|
||||
#endif
|
||||
}
|
||||
|
||||
void MRDeleter(ibv_mr* mr) {
|
||||
if (mr) {
|
||||
ibv_dereg_mr(mr);
|
||||
}
|
||||
}
|
||||
|
||||
void RdmaMgr::InitAllocators() {
|
||||
static std::once_flag flag;
|
||||
std::call_once(
|
||||
flag, [this]() { RdmaMemoryMgr::Singleton().pd_ = rdma_adapter_->pd_; });
|
||||
}
|
||||
|
||||
/*static*/ void RdmaMgr::RegMemVisitors() {
|
||||
SubAllocator::Visitor alloc_visitor = [](void* ptr, int numa_node,
|
||||
size_t num_bytes) {
|
||||
RdmaMemoryMgr::Singleton().InsertMemoryRegion(
|
||||
ptr, num_bytes, strings::StrCat("CPU:", numa_node));
|
||||
};
|
||||
SubAllocator::Visitor free_visitor = [](void* ptr, int numa_node,
|
||||
size_t num_bytes) {
|
||||
RdmaMemoryMgr::Singleton().EvictMemoryRegion(ptr, num_bytes);
|
||||
};
|
||||
|
||||
ProcessState::singleton()->AddCPUAllocVisitor(alloc_visitor);
|
||||
ProcessState::singleton()->AddCPUFreeVisitor(free_visitor);
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
GPUProcessState::singleton()->AddGpuHostAllocVisitor(0, alloc_visitor);
|
||||
GPUProcessState::singleton()->AddGpuHostFreeVisitor(0, free_visitor);
|
||||
|
||||
if (IsGDRAvailable()) {
|
||||
// Note we don't free allocated GPU memory so there is no free visitor
|
||||
|
||||
// TODO: This is to fix the 'invalid use of member in static member function
|
||||
// bug'.
|
||||
// Waiting for better implementation.
|
||||
// int32_t bus_id = TryToReadNumaNode(rdma_adapter_->context_->device)
|
||||
// + 1;
|
||||
int32_t bus_id = 0;
|
||||
|
||||
SubAllocator::Visitor cuda_alloc_visitor = [](void* ptr, int gpu_id,
|
||||
size_t num_bytes) {
|
||||
RdmaMemoryMgr::Singleton().InsertMemoryRegion(
|
||||
ptr, num_bytes, strings::StrCat("GPU:", gpu_id));
|
||||
};
|
||||
GPUProcessState::singleton()->AddGPUAllocVisitor(bus_id,
|
||||
cuda_alloc_visitor);
|
||||
LOG(INFO) << "Instrumenting GPU allocator with bus_id " << bus_id;
|
||||
}
|
||||
#endif // GOOGLE_CUDA
|
||||
}
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
||||
#endif
|
@ -1,59 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CONTRIB_VERBS_RDMA_MGR_H_
|
||||
#define TENSORFLOW_CONTRIB_VERBS_RDMA_MGR_H_
|
||||
|
||||
#ifdef TENSORFLOW_USE_VERBS
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "tensorflow/contrib/verbs/rdma.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_env.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class RdmaMgr {
|
||||
friend class RdmaChannel;
|
||||
friend class RdmaAdapter;
|
||||
|
||||
public:
|
||||
explicit RdmaMgr(const WorkerEnv* const worker_env,
|
||||
GrpcChannelCache* const channel_cache);
|
||||
~RdmaMgr();
|
||||
RdmaChannel* FindChannel(const string& key);
|
||||
void SetupChannels();
|
||||
bool ConnectivityCheck();
|
||||
void InitAllocators();
|
||||
static void RegMemVisitors();
|
||||
const string& local_worker() { return local_worker_; }
|
||||
|
||||
private:
|
||||
string local_worker_;
|
||||
size_t num_remote_workers_;
|
||||
const WorkerEnv* const worker_env_;
|
||||
GrpcChannelCache* const channel_cache_;
|
||||
RdmaAdapter* rdma_adapter_;
|
||||
typedef std::unordered_map<string, RdmaChannel*> ChannelTable;
|
||||
ChannelTable channel_table_;
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(RdmaMgr);
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_USE_VERBS
|
||||
#endif // TENSORFLOW_CONTRIB_VERBS_RDMA_MGR_H_
|
@ -1,92 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifdef TENSORFLOW_USE_VERBS
|
||||
|
||||
#include "tensorflow/contrib/verbs/rdma_rendezvous_mgr.h"
|
||||
#include <unordered_set>
|
||||
#include "tensorflow/contrib/verbs/verbs_util.h"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
#include "tensorflow/core/common_runtime/dma_helper.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/strings/numbers.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class RdmaRemoteRendezvous : public BaseRemoteRendezvous {
|
||||
public:
|
||||
RdmaRemoteRendezvous(const WorkerEnv* env, int64 step_id, RdmaMgr* rdma_mgr)
|
||||
: BaseRemoteRendezvous(env, step_id), rdma_mgr_(rdma_mgr) {}
|
||||
|
||||
protected:
|
||||
void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed,
|
||||
const Rendezvous::Args& args,
|
||||
DoneCallback done) override;
|
||||
|
||||
private:
|
||||
~RdmaRemoteRendezvous() override {}
|
||||
RdmaMgr* rdma_mgr_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(RdmaRemoteRendezvous);
|
||||
};
|
||||
|
||||
void RdmaRemoteRendezvous::RecvFromRemoteAsync(
|
||||
const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& recv_args,
|
||||
DoneCallback done) {
|
||||
Status s;
|
||||
// parse src_name and dst_name
|
||||
string src_name, dst_name, unused;
|
||||
if (!DeviceNameUtils::SplitDeviceName(parsed.src_device, &src_name,
|
||||
&unused) ||
|
||||
!DeviceNameUtils::SplitDeviceName(parsed.dst_device, &dst_name,
|
||||
&unused)) {
|
||||
s = errors::Internal("Could not parse src or dst name.");
|
||||
}
|
||||
if (!s.ok()) {
|
||||
LOG(ERROR) << "s is not ok, error code " << s.error_message();
|
||||
done(s, Args(), recv_args, Tensor{}, false);
|
||||
return;
|
||||
}
|
||||
CHECK(dst_name.compare(rdma_mgr_->local_worker()) == 0);
|
||||
RdmaChannel* rc = rdma_mgr_->FindChannel(src_name);
|
||||
string key(parsed.FullKey());
|
||||
string key_with_step_id = VerbsUtil::AppendStepidToKey(key, step_id_);
|
||||
|
||||
Device* dst_dev;
|
||||
s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_dev);
|
||||
CHECK(s.ok()) << "s is not ok, error code " << s.error_message();
|
||||
if (!s.ok()) {
|
||||
done(s, Args(), recv_args, Tensor(), true);
|
||||
return;
|
||||
}
|
||||
|
||||
RdmaTensorRequest* request =
|
||||
rc->InsertTensorRequest(key, step_id_, dst_dev, recv_args, done);
|
||||
request->Start();
|
||||
}
|
||||
|
||||
RdmaRendezvousMgr::RdmaRendezvousMgr(const WorkerEnv* env)
|
||||
: BaseRendezvousMgr(env) {}
|
||||
|
||||
BaseRemoteRendezvous* RdmaRendezvousMgr::Create(int64 step_id,
|
||||
const WorkerEnv* worker_env) {
|
||||
return new RdmaRemoteRendezvous(worker_env, step_id, rdma_mgr_);
|
||||
}
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
||||
#endif
|
@ -1,63 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CONTRIB_VERBS_RDMA_RENDEZVOUS_MGR_H_
|
||||
#define TENSORFLOW_CONTRIB_VERBS_RDMA_RENDEZVOUS_MGR_H_
|
||||
|
||||
#ifdef TENSORFLOW_USE_VERBS
|
||||
|
||||
#include "tensorflow/contrib/verbs/rdma_mgr.h"
|
||||
#include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_env.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// RendezvousMgr keeps track of a set of local rendezvous instances.
|
||||
// All tensors sent by this worker are buffered in a RendezvousMgr
|
||||
// until the tensor is received. Each global unique "step_id"
|
||||
// corresponds to one local rendezvous instance managed by a
|
||||
// RendezvousMgr.
|
||||
//
|
||||
// E.g.,
|
||||
// Rendezvous* rendez = worker_env->rendezvous_mgr->Find(0x8935);
|
||||
// fork execution of an graph executor using "rendez" on thread 1;
|
||||
// fork execution of another graph executor using "rendez" on thread 2;
|
||||
// ...
|
||||
// join threads 1 and 2;
|
||||
//
|
||||
// In the example above, execution in thread 1 and 2 communicates with
|
||||
// each other by send/recv operations through the "rend".
|
||||
//
|
||||
// Tensors sent and recved through rendezvous managed by this
|
||||
// RendezvousMgr must have keys generated by Rendezvous::CreateKey.
|
||||
class RdmaRendezvousMgr : public BaseRendezvousMgr {
|
||||
public:
|
||||
explicit RdmaRendezvousMgr(const WorkerEnv* env);
|
||||
void SetRdmaMgr(RdmaMgr* rdma_mgr) { rdma_mgr_ = rdma_mgr; }
|
||||
|
||||
protected:
|
||||
BaseRemoteRendezvous* Create(int64 step_id,
|
||||
const WorkerEnv* worker_env) override;
|
||||
|
||||
private:
|
||||
RdmaMgr* rdma_mgr_;
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(RdmaRendezvousMgr);
|
||||
};
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_USE_VERBS
|
||||
#endif // TENSORFLOW_CONTRIB_VERBS_RDMA_RENDEZVOUS_MGR_H_
|
@ -1,176 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifdef TENSORFLOW_USE_VERBS
|
||||
|
||||
#include "tensorflow/contrib/verbs/verbs_server_lib.h"
|
||||
|
||||
#include "grpc/support/alloc.h"
|
||||
|
||||
#include "tensorflow/contrib/verbs/rdma_mgr.h"
|
||||
#include "tensorflow/contrib/verbs/rdma_rendezvous_mgr.h"
|
||||
#include "tensorflow/core/distributed_runtime/server_lib.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
// static utility function
|
||||
RendezvousMgrInterface* NewRdmaRendezvousMgr(const WorkerEnv* env) {
|
||||
return new RdmaRendezvousMgr(env);
|
||||
}
|
||||
|
||||
std::once_flag reg_mem_visitors_call;
|
||||
|
||||
} // namespace
|
||||
|
||||
VerbsServer::VerbsServer(const ServerDef& server_def, Env* env)
|
||||
: GrpcServer(server_def, env), verbs_state_(DISCONNECTED) {}
|
||||
|
||||
VerbsServer::~VerbsServer() {
|
||||
TF_CHECK_OK(Stop());
|
||||
TF_CHECK_OK(Join());
|
||||
delete rdma_mgr_;
|
||||
delete verbs_service_;
|
||||
delete channel_cache_;
|
||||
}
|
||||
|
||||
Status VerbsServer::ChannelCacheFactory(const ServerDef& server_def,
|
||||
GrpcChannelCache** channel_cache) {
|
||||
string name_prefix =
|
||||
strings::StrCat("/job:", server_def.job_name(), "/replica:0",
|
||||
"/task:", server_def.task_index());
|
||||
|
||||
GrpcChannelSpec channel_spec;
|
||||
TF_RETURN_IF_ERROR(ParseChannelSpec(server_def, &channel_spec));
|
||||
|
||||
*channel_cache =
|
||||
NewGrpcChannelCache(channel_spec, GetChannelCreationFunction());
|
||||
|
||||
const string host_port = (*channel_cache)->TranslateTask(name_prefix);
|
||||
int requested_port;
|
||||
|
||||
if (!strings::safe_strto32(str_util::Split(host_port, ':')[1],
|
||||
&requested_port)) {
|
||||
return errors::Internal("Could not parse port for local server from \"",
|
||||
(*channel_cache)->TranslateTask(name_prefix),
|
||||
"\".");
|
||||
}
|
||||
if (requested_port != bound_port()) {
|
||||
return errors::InvalidArgument("Requested port ", requested_port,
|
||||
" differs from expected port ",
|
||||
bound_port());
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status VerbsServer::Init(ServiceInitFunction service_func,
|
||||
RendezvousMgrCreationFunction rendezvous_mgr_func) {
|
||||
std::call_once(reg_mem_visitors_call, []() { RdmaMgr::RegMemVisitors(); });
|
||||
GrpcServerOptions opts;
|
||||
opts.service_func = service_func;
|
||||
opts.rendezvous_mgr_func = rendezvous_mgr_func;
|
||||
Status s = GrpcServer::Init(opts);
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
CHECK_EQ(verbs_state_, DISCONNECTED);
|
||||
CHECK(ChannelCacheFactory(server_def(), &channel_cache_).ok());
|
||||
rdma_mgr_ = new RdmaMgr(worker_env(), channel_cache_);
|
||||
// set rdma_mgr for verbs_service and rdma_rendezvous_mgr
|
||||
verbs_service_->SetRdmaMgr(rdma_mgr_);
|
||||
dynamic_cast<RdmaRendezvousMgr*>(worker_env()->rendezvous_mgr)
|
||||
->SetRdmaMgr(rdma_mgr_);
|
||||
}
|
||||
return s;
|
||||
}
|
||||
|
||||
Status VerbsServer::Start() {
|
||||
Status s = GrpcServer::Start();
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
if (verbs_state_ == DISCONNECTED) {
|
||||
// verbs_thread needs to be initiated
|
||||
// before rdma_mgr sets up the rdma channels.
|
||||
verbs_thread_.reset(worker_env()->env->StartThread(
|
||||
ThreadOptions(), "TF_verbs_service",
|
||||
[this] { verbs_service_->HandleRPCsLoop(); }));
|
||||
rdma_mgr_->SetupChannels();
|
||||
CHECK(rdma_mgr_->ConnectivityCheck()) << "Connectivity check failed!";
|
||||
rdma_mgr_->InitAllocators();
|
||||
verbs_state_ = CONNECTED;
|
||||
}
|
||||
}
|
||||
return s;
|
||||
}
|
||||
|
||||
Status VerbsServer::Join() {
|
||||
Status s = GrpcServer::Join();
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
if (verbs_state_ == CONNECTED) {
|
||||
verbs_state_ = DISCONNECTED;
|
||||
verbs_thread_.reset();
|
||||
}
|
||||
}
|
||||
return s;
|
||||
}
|
||||
|
||||
/* static */
|
||||
Status VerbsServer::Create(const ServerDef& server_def, Env* env,
|
||||
std::unique_ptr<ServerInterface>* out_server) {
|
||||
std::unique_ptr<VerbsServer> ret(new VerbsServer(server_def, Env::Default()));
|
||||
ServiceInitFunction service_func = [&ret](const WorkerEnv* worker_env,
|
||||
::grpc::ServerBuilder* builder) {
|
||||
return SetNewVerbsService(&ret->verbs_service_, worker_env, builder);
|
||||
};
|
||||
TF_RETURN_IF_ERROR(ret->Init(service_func, NewRdmaRendezvousMgr));
|
||||
*out_server = std::move(ret);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
class VerbsServerFactory : public ServerFactory {
|
||||
public:
|
||||
bool AcceptsOptions(const ServerDef& server_def) override {
|
||||
return server_def.protocol() == "grpc+verbs";
|
||||
}
|
||||
|
||||
Status NewServer(const ServerDef& server_def,
|
||||
std::unique_ptr<ServerInterface>* out_server) override {
|
||||
return VerbsServer::Create(server_def, Env::Default(), out_server);
|
||||
}
|
||||
};
|
||||
|
||||
// Registers a `ServerFactory` for `VerbsServer` instances.
|
||||
class VerbsServerRegistrar {
|
||||
public:
|
||||
VerbsServerRegistrar() {
|
||||
gpr_allocation_functions alloc_fns;
|
||||
alloc_fns.malloc_fn = port::Malloc;
|
||||
alloc_fns.realloc_fn = port::Realloc;
|
||||
alloc_fns.free_fn = port::Free;
|
||||
gpr_set_allocation_functions(alloc_fns);
|
||||
ServerFactory::Register("VERBS_SERVER", new VerbsServerFactory());
|
||||
}
|
||||
};
|
||||
static VerbsServerRegistrar registrar;
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif
|
@ -1,66 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CONTRIB_VERBS_VERBS_SERVER_LIB_H_
|
||||
#define TENSORFLOW_CONTRIB_VERBS_VERBS_SERVER_LIB_H_
|
||||
|
||||
#ifdef TENSORFLOW_USE_VERBS
|
||||
|
||||
#include "tensorflow/contrib/verbs/grpc_verbs_service.h"
|
||||
#include "tensorflow/contrib/verbs/rdma_mgr.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class VerbsServer : public GrpcServer {
|
||||
protected:
|
||||
VerbsServer(const ServerDef& server_def, Env* env);
|
||||
|
||||
public:
|
||||
static Status Create(const ServerDef& server_def, Env* env,
|
||||
std::unique_ptr<ServerInterface>* out_server);
|
||||
|
||||
// Destruction is only supported in the factory method. Clean
|
||||
// shutdown is not currently implemented for this server type.
|
||||
virtual ~VerbsServer() override;
|
||||
|
||||
// Implementations of ServerInterface methods.
|
||||
Status Start() override;
|
||||
Status Join() override;
|
||||
|
||||
protected:
|
||||
Status Init(ServiceInitFunction service_func,
|
||||
RendezvousMgrCreationFunction rendezvous_mgr_func);
|
||||
Status ChannelCacheFactory(const ServerDef& server_def,
|
||||
GrpcChannelCache** channel_cache);
|
||||
|
||||
private:
|
||||
RdmaMgr* rdma_mgr_;
|
||||
|
||||
// Guards state transitions.
|
||||
mutex mu_;
|
||||
|
||||
enum State { DISCONNECTED, CONNECTED };
|
||||
State verbs_state_ GUARDED_BY(mu_);
|
||||
|
||||
GrpcVerbsService* verbs_service_ = nullptr;
|
||||
std::unique_ptr<Thread> verbs_thread_ GUARDED_BY(mu_);
|
||||
GrpcChannelCache* channel_cache_ = nullptr;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_USE_VERBS
|
||||
#endif // TENSORFLOW_CONTRIB_VERBS_VERBS_SERVER_LIB_H_
|
@ -1,68 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
syntax = "proto3";
|
||||
|
||||
package tensorflow;
|
||||
option java_outer_classname = "VerbsServiceProtos";
|
||||
option java_multiple_files = true;
|
||||
option java_package = "org.tensorflow.contrib.verbs";
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// GRPC Helper messages used to exchange RDMA information.
|
||||
//
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
message Channel {
|
||||
int32 lid = 1;
|
||||
int32 qpn = 2;
|
||||
int32 psn = 3;
|
||||
uint64 snp = 4;
|
||||
uint64 iid = 5;
|
||||
}
|
||||
|
||||
message MemoryRegion {
|
||||
uint64 remote_addr = 1;
|
||||
uint32 rkey = 2;
|
||||
}
|
||||
message GetRemoteAddressRequest {
|
||||
string host_name = 1;
|
||||
Channel channel = 2;
|
||||
repeated MemoryRegion mr = 3;
|
||||
}
|
||||
|
||||
message GetRemoteAddressResponse {
|
||||
string host_name = 1;
|
||||
Channel channel = 2;
|
||||
repeated MemoryRegion mr = 3;
|
||||
}
|
||||
|
||||
message ErrorStatusProto {
|
||||
int32 error_code = 1;
|
||||
string error_message = 2;
|
||||
string error_details = 3;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// VerbsService
|
||||
//
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
service VerbsService {
|
||||
rpc GetRemoteAddress(GetRemoteAddressRequest)
|
||||
returns (GetRemoteAddressResponse);
|
||||
}
|
@ -1,50 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/contrib/verbs/verbs_util.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
#include "tensorflow/core/lib/strings/numbers.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// static
|
||||
string VerbsUtil::AppendStepidToKey(const string& key, int64 step_id) {
|
||||
return strings::StrCat(key, ";", step_id);
|
||||
}
|
||||
|
||||
// static
|
||||
void VerbsUtil::GetKeyAndStepId(const string& key_with_step_id, string& key,
|
||||
int64& step_id) {
|
||||
StringPiece s(key_with_step_id);
|
||||
// a key (with step_id) has exact 6 parts if split by ";"
|
||||
// part 1: src_device;
|
||||
// part 2: src_incarnation;
|
||||
// part 3: dst_device;
|
||||
// part 4: name;
|
||||
// part 5: frame_iter.frame_id:frame_iter.iter_id
|
||||
// part 6: step_id
|
||||
std::vector<string> parts = str_util::Split(s, ';');
|
||||
CHECK(parts.size() == 6) << "Key with step_id must have 6 parts";
|
||||
strings::safe_strto64(parts[5], &step_id);
|
||||
parts.pop_back(); // remove step_id
|
||||
key.assign(absl::StrJoin(parts, ";")); // stitch them together
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
@ -1,33 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CONTRIB_VERBS_VERBS_UTIL_H_
|
||||
#define TENSORFLOW_CONTRIB_VERBS_VERBS_UTIL_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class VerbsUtil {
|
||||
public:
|
||||
static string AppendStepidToKey(const string& key, int64 step_id);
|
||||
static void GetKeyAndStepId(const string& key_with_step_id, string& key,
|
||||
int64& step_id);
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
#endif // TENSORFLOW_CONTRIB_VERBS_VERBS_UTIL_H_
|
Binary file not shown. (Deleted diagram image, 61 KiB, together with its draw.io XML source.)
Binary file not shown. (Deleted diagram image, 87 KiB, together with its draw.io XML source.)
@ -125,7 +125,6 @@ load(
|
||||
"tf_additional_numa_deps",
|
||||
"tf_additional_numa_lib_defines",
|
||||
"tf_additional_test_deps",
|
||||
"tf_additional_verbs_lib_defines",
|
||||
"tf_grpc_service_all",
|
||||
"tf_jspb_proto_library",
|
||||
"tf_kernel_tests_linkstatic",
|
||||
@ -2423,8 +2422,7 @@ LIB_INTERNAL_PUBLIC_HEADERS = [
|
||||
LIB_INTERNAL_DEFINES = (
|
||||
tf_additional_lib_defines() + [
|
||||
"TF_USE_SNAPPY",
|
||||
] + tf_additional_verbs_lib_defines() +
|
||||
tf_additional_gdr_lib_defines() +
|
||||
] + tf_additional_gdr_lib_defines() +
|
||||
tf_additional_numa_lib_defines()
|
||||
)
|
||||
|
||||
|
@ -731,12 +731,6 @@ def tf_lib_proto_compiler_deps():
|
||||
"@com_google_protobuf//:protoc_lib",
|
||||
]
|
||||
|
||||
def tf_additional_verbs_lib_defines():
|
||||
return select({
|
||||
"//tensorflow:with_verbs_support": ["TENSORFLOW_USE_VERBS"],
|
||||
"//conditions:default": [],
|
||||
})
|
||||
|
||||
def tf_additional_gdr_lib_defines():
|
||||
return select({
|
||||
"//tensorflow:with_gdr_support": ["TENSORFLOW_USE_GDR"],
|
||||
|
@ -40,15 +40,6 @@ def tf_additional_license_deps():
|
||||
"//conditions:default": [],
|
||||
})
|
||||
|
||||
def tf_additional_verbs_deps():
|
||||
return select({
|
||||
str(Label("//tensorflow:with_verbs_support")): [
|
||||
str(Label("//tensorflow/contrib/verbs:verbs_server_lib")),
|
||||
str(Label("//tensorflow/contrib/verbs:grpc_verbs_client")),
|
||||
],
|
||||
"//conditions:default": [],
|
||||
})
|
||||
|
||||
def tf_additional_gdr_deps():
|
||||
return select({
|
||||
str(Label("//tensorflow:with_gdr_support")): [
|
||||
|
@ -26,7 +26,7 @@ load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
|
||||
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
|
||||
load("//tensorflow:tensorflow.bzl", "cuda_py_tests")
|
||||
load("//tensorflow/core/platform:default/build_config.bzl", "pyx_library", "tf_additional_all_protos", "tf_additional_cupti_test_flags", "tf_additional_lib_deps", "tf_proto_library", "tf_proto_library_py", "tf_protos_grappler") # @unused
|
||||
load("//tensorflow/core/platform:default/build_config_root.bzl", "if_static", "tf_additional_gdr_deps", "tf_additional_plugin_deps", "tf_additional_verbs_deps")
|
||||
load("//tensorflow/core/platform:default/build_config_root.bzl", "if_static", "tf_additional_gdr_deps", "tf_additional_plugin_deps")
|
||||
load("//tensorflow/python:build_defs.bzl", "tf_gen_op_wrapper_private_py")
|
||||
load(
|
||||
"//third_party/ngraph:build_defs.bzl",
|
||||
@ -5059,7 +5059,6 @@ tf_py_wrap_cc(
|
||||
"//tensorflow/python/eager:pywrap_tfe_lib",
|
||||
] + (tf_additional_lib_deps() +
|
||||
tf_additional_plugin_deps() +
|
||||
tf_additional_verbs_deps() +
|
||||
tf_additional_gdr_deps()) + if_ngraph([
|
||||
"@ngraph_tf//:ngraph_tf",
|
||||
]),
|
||||
|