diff --git a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py index 5271b2d41e2..cf634701113 100644 --- a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py +++ b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py @@ -358,7 +358,8 @@ class OptimizerV2(trackable.Trackable): ValueError: In case any gradient cannot be computed (e.g. if gradient function not implemented). """ - grads = gradients.gradients(loss, params) + with backend.get_graph().as_default(): + grads = gradients.gradients(loss, params) if None in grads: raise ValueError("An operation has `None` for gradient. " "Please make sure that all of your ops have a " @@ -445,8 +446,7 @@ class OptimizerV2(trackable.Trackable): return apply_updates def get_updates(self, loss, params): - with backend.get_graph().as_default(): - grads = self.get_gradients(loss, params) + grads = self.get_gradients(loss, params) grads_and_vars = list(zip(grads, params)) self._assert_valid_dtypes([ v for g, v in grads_and_vars