Get XlaCompile attr errors out of all of our gradient stack traces

Python helpfully appends the error from the "try" portion to new errors from calling grad_fn in "except", but it's 100% irrelevant in this case.

PiperOrigin-RevId: 335689319
Change-Id: Ibd12257e585d7d28eb720e217aeff7a4b5915664
This commit is contained in:
Allen Lavoie 2020-10-06 12:00:28 -07:00 committed by TensorFlower Gardener
parent c41b1f7e72
commit 04ea37b5ed

View File

@ -333,7 +333,7 @@ def _MaybeCompile(scope, op, func, grad_fn):
"_XlaSeparateCompiledGradients")
xla_scope = op.get_attr("_XlaScope").decode()
except ValueError:
return grad_fn() # Exit early
xla_compile = False
if not xla_compile:
return grad_fn() # Exit early