[TF:XLA] Add gradient for collective_permute operation

PiperOrigin-RevId: 218235356
This commit is contained in:
A. Unique TensorFlower 2018-10-22 14:34:23 -07:00 committed by TensorFlower Gardener
parent 52b9fab758
commit d341198905

View File

@ -137,6 +137,14 @@ if platform.system() != "Windows":
"""
return gen_tpu_ops.collective_permute(x, source_target_pairs, name=name)
@ops.RegisterGradient("CollectivePermute")
def _collective_permute_grad(op, grad):
# The gradient of a collective permute operation is also a collective
# permute, but with source/target pairs reversed. The gradient with respect
# to input argument `source_target_pairs` is `None`.
source_target_pairs = op.inputs[1][::-1, :]
return [gen_tpu_ops.collective_permute(grad, source_target_pairs), None]
@ops.RegisterGradient("CrossReplicaSum")
def _cross_replica_sum_grad(op, grad):
# The gradient of a cross replica sum is also a cross-replica sum.