Merge pull request #20360 from ppwwyyxx/patch-13
Fix gradient of nccl_ops
This commit is contained in:
commit
1b27fba599
@ -63,12 +63,12 @@ def _all_sum_grad(op, grad):
|
|||||||
Raises:
|
Raises:
|
||||||
LookupError: If `reduction` is not `sum`.
|
LookupError: If `reduction` is not `sum`.
|
||||||
"""
|
"""
|
||||||
if op.get_attr('reduction') != 'sum':
|
if op.get_attr('reduction') != b'sum':
|
||||||
raise LookupError('No gradient defined for NcclAllReduce except sum.')
|
raise LookupError('No gradient defined for NcclAllReduce except sum.')
|
||||||
|
|
||||||
_check_device(grad, expected=op.device)
|
_check_device(grad, expected=op.device)
|
||||||
num_devices = op.get_attr('num_devices')
|
num_devices = op.get_attr('num_devices')
|
||||||
shared_name = op.get_attr('shared_name') + '_grad'
|
shared_name = op.get_attr('shared_name') + b'_grad'
|
||||||
|
|
||||||
with ops.device(op.device):
|
with ops.device(op.device):
|
||||||
return gen_nccl_ops.nccl_all_reduce(
|
return gen_nccl_ops.nccl_all_reduce(
|
||||||
@ -162,7 +162,7 @@ def _reduce_sum_grad(op, grad):
|
|||||||
Raises:
|
Raises:
|
||||||
LookupError: If the reduction attribute of op is not `sum`.
|
LookupError: If the reduction attribute of op is not `sum`.
|
||||||
"""
|
"""
|
||||||
if op.get_attr('reduction') != 'sum':
|
if op.get_attr('reduction') != b'sum':
|
||||||
raise LookupError('No gradient defined for NcclReduce except sum.')
|
raise LookupError('No gradient defined for NcclReduce except sum.')
|
||||||
_check_device(grad, expected=op.device)
|
_check_device(grad, expected=op.device)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user