Move backend.get_graph() inside get_gradients()
This commit is contained in:
parent
8ea85915b9
commit
9ac2380a80
@ -358,7 +358,8 @@ class OptimizerV2(trackable.Trackable):
|
||||
ValueError: In case any gradient cannot be computed (e.g. if gradient
|
||||
function not implemented).
|
||||
"""
|
||||
grads = gradients.gradients(loss, params)
|
||||
with backend.get_graph().as_default():
|
||||
grads = gradients.gradients(loss, params)
|
||||
if None in grads:
|
||||
raise ValueError("An operation has `None` for gradient. "
|
||||
"Please make sure that all of your ops have a "
|
||||
@ -445,8 +446,7 @@ class OptimizerV2(trackable.Trackable):
|
||||
return apply_updates
|
||||
|
||||
def get_updates(self, loss, params):
|
||||
with backend.get_graph().as_default():
|
||||
grads = self.get_gradients(loss, params)
|
||||
grads = self.get_gradients(loss, params)
|
||||
grads_and_vars = list(zip(grads, params))
|
||||
self._assert_valid_dtypes([
|
||||
v for g, v in grads_and_vars
|
||||
|
Loading…
Reference in New Issue
Block a user