From 430666a7a6a62063f14a5844558c8a79392b95d9 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 4 Oct 2019 11:57:47 -0700 Subject: [PATCH] Improve constant inference for Transpose gradient. Currently Transpose op gradients don't compile on XLA when the Transpose is inside a while loop. PiperOrigin-RevId: 272927421 --- tensorflow/python/ops/array_grad.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py index 2c1bd445e54..1f027d3fe9f 100644 --- a/tensorflow/python/ops/array_grad.py +++ b/tensorflow/python/ops/array_grad.py @@ -701,6 +701,11 @@ def _SqueezeGrad(op, grad): def _TransposeGrad(op, grad): """Returns unshuffle(grad).""" p = op.inputs[1] + if not context.executing_eagerly(): + p_static = pywrap_tensorflow.TF_TryEvaluateConstant_wrapper( + p.graph._c_graph, p._as_tf_output()) # pylint: disable=protected-access + if p_static is not None: + p = constant_op.constant(p_static, dtype=p.dtype) return [array_ops.transpose(grad, array_ops.invert_permutation(p)), None]