Move backend.get_graph() inside get_gradients()
This commit is contained in:
parent
8ea85915b9
commit
9ac2380a80
@ -358,6 +358,7 @@ class OptimizerV2(trackable.Trackable):
|
|||||||
ValueError: In case any gradient cannot be computed (e.g. if gradient
|
ValueError: In case any gradient cannot be computed (e.g. if gradient
|
||||||
function not implemented).
|
function not implemented).
|
||||||
"""
|
"""
|
||||||
|
with backend.get_graph().as_default():
|
||||||
grads = gradients.gradients(loss, params)
|
grads = gradients.gradients(loss, params)
|
||||||
if None in grads:
|
if None in grads:
|
||||||
raise ValueError("An operation has `None` for gradient. "
|
raise ValueError("An operation has `None` for gradient. "
|
||||||
@ -445,7 +446,6 @@ class OptimizerV2(trackable.Trackable):
|
|||||||
return apply_updates
|
return apply_updates
|
||||||
|
|
||||||
def get_updates(self, loss, params):
|
def get_updates(self, loss, params):
|
||||||
with backend.get_graph().as_default():
|
|
||||||
grads = self.get_gradients(loss, params)
|
grads = self.get_gradients(loss, params)
|
||||||
grads_and_vars = list(zip(grads, params))
|
grads_and_vars = list(zip(grads, params))
|
||||||
self._assert_valid_dtypes([
|
self._assert_valid_dtypes([
|
||||||
|
Loading…
Reference in New Issue
Block a user