Make nccl work in eager mode: wrap nccl ops in a defun; remove control dependencies on NcclAllReduce

PiperOrigin-RevId: 228801431
This commit is contained in:
Yuefeng Zhou 2019-01-10 16:43:49 -08:00 committed by TensorFlower Gardener
parent d2490b91e3
commit 40874676a6
3 changed files with 24 additions and 11 deletions

View File

@ -5869,6 +5869,8 @@ py_library(
deps = [
":framework_for_generated_wrappers",
":nccl_ops_gen",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:def_function",
],
)

View File

@ -35,6 +35,7 @@ ASYNC_STATEFUL_OPS = [
"CollectiveReduce",
"CollectiveBcastSend",
"CollectiveBcastRecv",
"NcclAllReduce",
]

View File

@ -19,6 +19,8 @@ from __future__ import print_function
import threading
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import device
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_nccl_ops
@ -211,19 +213,27 @@ def _apply_all_reduce(reduction, tensors):
raise ValueError('Must pass >0 tensors to all reduce operations')
shared_name = _get_shared_name()
res = []
for t in tensors:
_check_device(t)
with ops.device(t.device):
res.append(
gen_nccl_ops.nccl_all_reduce(
input=t,
reduction=reduction,
num_devices=len(tensors),
shared_name=shared_name))
def _all_reduce():
  """Create the nccl all-reduce ops, one per entry of `tensors`.

  Reads `tensors`, `reduction` and `shared_name` from the enclosing scope.
  Each tensor is first passed to `_check_device`; its op is then created
  inside that tensor's own device scope.

  Returns:
    A list holding the result of `gen_nccl_ops.nccl_all_reduce` for every
    input tensor, in the same order as `tensors`.
  """
  outputs = []
  for tensor in tensors:
    _check_device(tensor)
    # Place each op on the device that already holds its input.
    with ops.device(tensor.device):
      reduced = gen_nccl_ops.nccl_all_reduce(
          input=tensor,
          reduction=reduction,
          num_devices=len(tensors),
          shared_name=shared_name)
    outputs.append(reduced)
  return outputs
return res
if context.executing_eagerly():
# Nccl ops will block unless they are executed concurrently, such as in a
# graph or a defun.
return def_function.function(_all_reduce)()
else:
return _all_reduce()
def _apply_reduce(reduction, tensors):