Fix gradient of nccl_ops

This commit is contained in:
Yuxin Wu 2018-06-27 13:41:26 -07:00
parent 85cb8d48b7
commit f50df6f0be

View File

@ -63,12 +63,12 @@ def _all_sum_grad(op, grad):
Raises:
LookupError: If `reduction` is not `sum`.
"""
if op.get_attr('reduction') != 'sum':
if op.get_attr('reduction') != b'sum':
raise LookupError('No gradient defined for NcclAllReduce except sum.')
_check_device(grad, expected=op.device)
num_devices = op.get_attr('num_devices')
shared_name = op.get_attr('shared_name') + '_grad'
shared_name = op.get_attr('shared_name') + b'_grad'
with ops.device(op.device):
return gen_nccl_ops.nccl_all_reduce(
@ -162,7 +162,7 @@ def _reduce_sum_grad(op, grad):
Raises:
LookupError: If the reduction attribute of op is not `sum`.
"""
if op.get_attr('reduction') != 'sum':
if op.get_attr('reduction') != b'sum':
raise LookupError('No gradient defined for NcclReduce except sum.')
_check_device(grad, expected=op.device)