Add try...finally block in _override_gradient_function

Without it errors in the user program may show up as an strange assertion
failure. (assert not self._gradient_function_map)

PiperOrigin-RevId: 351883333
Change-Id: I2cd3e40d682a638268f46415cdbe4683d11fefa6
This commit is contained in:
Ran Chen 2021-01-14 15:00:00 -08:00 committed by TensorFlower Gardener
parent 631fcba18c
commit 0c3d87c81e

View File

@ -4963,10 +4963,13 @@ class Graph(object):
"""Specify gradient function for the given op type."""
# This is an internal API and we don't need nested context for this.
# TODO(mdan): make it a proper context manager.
assert not self._gradient_function_map
self._gradient_function_map = gradient_function_map
yield
self._gradient_function_map = {}
try:
yield
finally:
self._gradient_function_map = {}
# pylint: disable=g-doc-return-or-yield
@tf_contextlib.contextmanager