From d34119890544633ee96e91bb87f2ff63d428028a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 22 Oct 2018 14:34:23 -0700 Subject: [PATCH] [TF:XLA] Add gradient for `collective_permute` operation PiperOrigin-RevId: 218235356 --- tensorflow/contrib/tpu/python/ops/tpu_ops.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tensorflow/contrib/tpu/python/ops/tpu_ops.py b/tensorflow/contrib/tpu/python/ops/tpu_ops.py index 968adccf2b8..261f70ba6b5 100644 --- a/tensorflow/contrib/tpu/python/ops/tpu_ops.py +++ b/tensorflow/contrib/tpu/python/ops/tpu_ops.py @@ -137,6 +137,14 @@ if platform.system() != "Windows": """ return gen_tpu_ops.collective_permute(x, source_target_pairs, name=name) + @ops.RegisterGradient("CollectivePermute") + def _collective_permute_grad(op, grad): + # The gradient of a collective permute operation is also a collective + # permute, but with source/target pairs reversed. The gradient with respect + # to input argument `source_target_pairs` is `None`. + source_target_pairs = op.inputs[1][::-1, :] + return [gen_tpu_ops.collective_permute(grad, source_target_pairs), None] + @ops.RegisterGradient("CrossReplicaSum") def _cross_replica_sum_grad(op, grad): # The gradient of a cross replica sum is also a cross-replica sum.