Add try...finally block in _override_gradient_function
Without it errors in the user program may show up as an strange assertion failure. (assert not self._gradient_function_map) PiperOrigin-RevId: 351883333 Change-Id: I2cd3e40d682a638268f46415cdbe4683d11fefa6
This commit is contained in:
parent
631fcba18c
commit
0c3d87c81e
@ -4963,10 +4963,13 @@ class Graph(object):
|
||||
"""Specify gradient function for the given op type."""
|
||||
|
||||
# This is an internal API and we don't need nested context for this.
|
||||
# TODO(mdan): make it a proper context manager.
|
||||
assert not self._gradient_function_map
|
||||
self._gradient_function_map = gradient_function_map
|
||||
yield
|
||||
self._gradient_function_map = {}
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self._gradient_function_map = {}
|
||||
|
||||
# pylint: disable=g-doc-return-or-yield
|
||||
@tf_contextlib.contextmanager
|
||||
|
Loading…
Reference in New Issue
Block a user