From 40874676a61cc1823355016953a6ac6222f8b404 Mon Sep 17 00:00:00 2001
From: Yuefeng Zhou
Date: Thu, 10 Jan 2019 16:43:49 -0800
Subject: [PATCH] Make nccl work in eager mode: wrap nccl ops in a defun;
 remove control dependencies on NcclAllReduce

PiperOrigin-RevId: 228801431
---
 tensorflow/python/BUILD                            |  2 ++
 .../python/framework/auto_control_deps.py          |  1 +
 tensorflow/python/ops/nccl_ops.py                  | 32 ++++++++++++-------
 3 files changed, 24 insertions(+), 11 deletions(-)

diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index c5a8994b80c..82c542f51d3 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -5869,6 +5869,8 @@ py_library(
     deps = [
         ":framework_for_generated_wrappers",
         ":nccl_ops_gen",
+        "//tensorflow/python/eager:context",
+        "//tensorflow/python/eager:def_function",
     ],
 )
 
diff --git a/tensorflow/python/framework/auto_control_deps.py b/tensorflow/python/framework/auto_control_deps.py
index a7d61417bf6..da76a84e55e 100644
--- a/tensorflow/python/framework/auto_control_deps.py
+++ b/tensorflow/python/framework/auto_control_deps.py
@@ -35,6 +35,7 @@ ASYNC_STATEFUL_OPS = [
     "CollectiveReduce",
     "CollectiveBcastSend",
     "CollectiveBcastRecv",
+    "NcclAllReduce",
 ]
 
 
diff --git a/tensorflow/python/ops/nccl_ops.py b/tensorflow/python/ops/nccl_ops.py
index 6259ce0f948..6c8685cf63a 100644
--- a/tensorflow/python/ops/nccl_ops.py
+++ b/tensorflow/python/ops/nccl_ops.py
@@ -19,6 +19,8 @@ from __future__ import print_function
 
 import threading
 
+from tensorflow.python.eager import context
+from tensorflow.python.eager import def_function
 from tensorflow.python.framework import device
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import gen_nccl_ops
@@ -211,19 +213,27 @@ def _apply_all_reduce(reduction, tensors):
     raise ValueError('Must pass >0 tensors to all reduce operations')
 
   shared_name = _get_shared_name()
 
-  res = []
-  for t in tensors:
-    _check_device(t)
-    with ops.device(t.device):
-      res.append(
-          gen_nccl_ops.nccl_all_reduce(
-              input=t,
-              reduction=reduction,
-              num_devices=len(tensors),
-              shared_name=shared_name))
+  def _all_reduce():
+    """Call nccl allreduce."""
+    res = []
+    for t in tensors:
+      _check_device(t)
+      with ops.device(t.device):
+        res.append(
+            gen_nccl_ops.nccl_all_reduce(
+                input=t,
+                reduction=reduction,
+                num_devices=len(tensors),
+                shared_name=shared_name))
+    return res
 
-  return res
+  if context.executing_eagerly():
+    # Nccl ops will block unless they are executed concurrently such as in a
+    # graph or a defun.
+    return def_function.function(_all_reduce)()
+  else:
+    return _all_reduce()
 
 
 def _apply_reduce(reduction, tensors):
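
Usage note (reviewer sketch, not part of the patch): a minimal example of the
behavior this change enables, assuming a machine with at least two GPUs and
eager execution already enabled. Device strings and tensor values are
illustrative; nccl_ops.all_sum is the public wrapper that reaches
_apply_all_reduce.

    # Assumes eager execution is enabled (tf.enable_eager_execution() in
    # TF 1.x) and that at least two GPUs are visible.
    from tensorflow.python.framework import ops
    from tensorflow.python.ops import array_ops
    from tensorflow.python.ops import nccl_ops

    # Place one tensor on each GPU; NcclAllReduce expects each input to live
    # on a different GPU participating in the reduction.
    tensors = []
    for i in range(2):
      with ops.device('/device:GPU:%d' % i):
        tensors.append(array_ops.ones([4]))

    # Before this patch, running this eagerly would hang: each NcclAllReduce
    # kernel blocks until all peers have launched, and eager mode executes
    # ops one at a time. Wrapping the ops in a defun launches them together.
    summed = nccl_ops.all_sum(tensors)  # one reduced tensor per input device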