Improve constant inference for Transpose gradient.
Currently Transpose op gradients don't compile on XLA when the Transpose is inside a while loop. PiperOrigin-RevId: 272927421
This commit is contained in:
parent
567af03528
commit
430666a7a6
@ -701,6 +701,11 @@ def _SqueezeGrad(op, grad):
|
||||
def _TransposeGrad(op, grad):
|
||||
"""Returns unshuffle(grad)."""
|
||||
p = op.inputs[1]
|
||||
if not context.executing_eagerly():
|
||||
p_static = pywrap_tensorflow.TF_TryEvaluateConstant_wrapper(
|
||||
p.graph._c_graph, p._as_tf_output()) # pylint: disable=protected-access
|
||||
if p_static is not None:
|
||||
p = constant_op.constant(p_static, dtype=p.dtype)
|
||||
return [array_ops.transpose(grad, array_ops.invert_permutation(p)), None]
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user