Improve constant inference for Transpose gradient.

Currently Transpose op gradients don't compile on XLA when the Transpose is
inside a while loop.

PiperOrigin-RevId: 272927421
This commit is contained in:
A. Unique TensorFlower 2019-10-04 11:57:47 -07:00 committed by TensorFlower Gardener
parent 567af03528
commit 430666a7a6

View File

@ -701,6 +701,11 @@ def _SqueezeGrad(op, grad):
def _TransposeGrad(op, grad):
"""Returns unshuffle(grad)."""
p = op.inputs[1]
if not context.executing_eagerly():
p_static = pywrap_tensorflow.TF_TryEvaluateConstant_wrapper(
p.graph._c_graph, p._as_tf_output()) # pylint: disable=protected-access
if p_static is not None:
p = constant_op.constant(p_static, dtype=p.dtype)
return [array_ops.transpose(grad, array_ops.invert_permutation(p)), None]