Make nccl work in eager mode: wrap nccl ops in a defun; remove control dependencies on NcclAllReduce
PiperOrigin-RevId: 228801431
This commit is contained in:
parent
d2490b91e3
commit
40874676a6
@ -5869,6 +5869,8 @@ py_library(
|
||||
deps = [
|
||||
":framework_for_generated_wrappers",
|
||||
":nccl_ops_gen",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//tensorflow/python/eager:def_function",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -35,6 +35,7 @@ ASYNC_STATEFUL_OPS = [
|
||||
"CollectiveReduce",
|
||||
"CollectiveBcastSend",
|
||||
"CollectiveBcastRecv",
|
||||
"NcclAllReduce",
|
||||
]
|
||||
|
||||
|
||||
|
@ -19,6 +19,8 @@ from __future__ import print_function
|
||||
|
||||
import threading
|
||||
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import device
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import gen_nccl_ops
|
||||
@ -211,19 +213,27 @@ def _apply_all_reduce(reduction, tensors):
|
||||
raise ValueError('Must pass >0 tensors to all reduce operations')
|
||||
|
||||
shared_name = _get_shared_name()
|
||||
res = []
|
||||
|
||||
for t in tensors:
|
||||
_check_device(t)
|
||||
with ops.device(t.device):
|
||||
res.append(
|
||||
gen_nccl_ops.nccl_all_reduce(
|
||||
input=t,
|
||||
reduction=reduction,
|
||||
num_devices=len(tensors),
|
||||
shared_name=shared_name))
|
||||
def _all_reduce():
|
||||
"""Call nccl allreduce."""
|
||||
res = []
|
||||
for t in tensors:
|
||||
_check_device(t)
|
||||
with ops.device(t.device):
|
||||
res.append(
|
||||
gen_nccl_ops.nccl_all_reduce(
|
||||
input=t,
|
||||
reduction=reduction,
|
||||
num_devices=len(tensors),
|
||||
shared_name=shared_name))
|
||||
return res
|
||||
|
||||
return res
|
||||
if context.executing_eagerly():
|
||||
# NCCL ops will block unless they are executed concurrently, such as in a
|
||||
# graph or a defun.
|
||||
return def_function.function(_all_reduce)()
|
||||
else:
|
||||
return _all_reduce()
|
||||
|
||||
|
||||
def _apply_reduce(reduction, tensors):
|
||||
|
Loading…
Reference in New Issue
Block a user