Move backend.get_graph() inside get_gradients()

This commit is contained in:
Daniel Salvadori 2019-03-06 09:21:38 -03:00
parent 8ea85915b9
commit 9ac2380a80

View File

@ -358,6 +358,7 @@ class OptimizerV2(trackable.Trackable):
ValueError: In case any gradient cannot be computed (e.g. if gradient
function not implemented).
"""
with backend.get_graph().as_default():
grads = gradients.gradients(loss, params)
if None in grads:
raise ValueError("An operation has `None` for gradient. "
@ -445,7 +446,6 @@ class OptimizerV2(trackable.Trackable):
return apply_updates
def get_updates(self, loss, params):
with backend.get_graph().as_default():
grads = self.get_gradients(loss, params)
grads_and_vars = list(zip(grads, params))
self._assert_valid_dtypes([