[TF:XLA] Add gradient for collective_permute
operation
PiperOrigin-RevId: 218235356
This commit is contained in:
parent
52b9fab758
commit
d341198905
@ -137,6 +137,14 @@ if platform.system() != "Windows":
|
||||
"""
|
||||
return gen_tpu_ops.collective_permute(x, source_target_pairs, name=name)
|
||||
|
||||
@ops.RegisterGradient("CollectivePermute")
|
||||
def _collective_permute_grad(op, grad):
|
||||
# The gradient of a collective permute operation is also a collective
|
||||
# permute, but with source/target pairs reversed. The gradient with respect
|
||||
# to input argument `source_target_pairs` is `None`.
|
||||
source_target_pairs = op.inputs[1][::-1, :]
|
||||
return [gen_tpu_ops.collective_permute(grad, source_target_pairs), None]
|
||||
|
||||
@ops.RegisterGradient("CrossReplicaSum")
|
||||
def _cross_replica_sum_grad(op, grad):
|
||||
# The gradient of a cross replica sum is also a cross-replica sum.
|
||||
|
Loading…
Reference in New Issue
Block a user