From f50df6f0bed7ff0fe86464545515e6735ff8cf9e Mon Sep 17 00:00:00 2001 From: Yuxin Wu Date: Wed, 27 Jun 2018 13:41:26 -0700 Subject: [PATCH] Fix gradient of nccl_ops --- tensorflow/contrib/nccl/python/ops/nccl_ops.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/contrib/nccl/python/ops/nccl_ops.py b/tensorflow/contrib/nccl/python/ops/nccl_ops.py index 029b01412d9..fa597cf3efc 100644 --- a/tensorflow/contrib/nccl/python/ops/nccl_ops.py +++ b/tensorflow/contrib/nccl/python/ops/nccl_ops.py @@ -63,12 +63,12 @@ def _all_sum_grad(op, grad): Raises: LookupError: If `reduction` is not `sum`. """ - if op.get_attr('reduction') != 'sum': + if op.get_attr('reduction') != b'sum': raise LookupError('No gradient defined for NcclAllReduce except sum.') _check_device(grad, expected=op.device) num_devices = op.get_attr('num_devices') - shared_name = op.get_attr('shared_name') + '_grad' + shared_name = op.get_attr('shared_name') + b'_grad' with ops.device(op.device): return gen_nccl_ops.nccl_all_reduce( @@ -162,7 +162,7 @@ def _reduce_sum_grad(op, grad): Raises: LookupError: If the reduction attribute of op is not `sum`. """ - if op.get_attr('reduction') != 'sum': + if op.get_attr('reduction') != b'sum': raise LookupError('No gradient defined for NcclReduce except sum.') _check_device(grad, expected=op.device)