diff --git a/tensorflow/python/training/tracking/util.py b/tensorflow/python/training/tracking/util.py index 7a4ad538eca..b0554faef70 100644 --- a/tensorflow/python/training/tracking/util.py +++ b/tensorflow/python/training/tracking/util.py @@ -64,6 +64,14 @@ keras_backend = lazy_loader.LazyLoader( "tensorflow.python.keras.backend") +def get_session(): + # Prefer TF's default session since get_session from Keras has side-effects. + session = ops.get_default_session() + if session is None: + session = keras_backend.get_session() + return session + + class _ObjectGraphProtoPrettyPrinter(object): """Lazily traverses an object graph proto to pretty print names. @@ -613,7 +621,7 @@ def streaming_restore(status, session=None): # Streaming restore is the default/only behavior when executing eagerly. return if session is None: - session = keras_backend.get_session() + session = get_session() if isinstance(status, NameBasedSaverStatus): raise NotImplementedError( "Streaming restore not supported from name-based checkpoints when " @@ -756,7 +764,7 @@ class CheckpointLoadStatus(_LoadStatus): if context.executing_eagerly(): return # Run eagerly if session is None: - session = keras_backend.get_session() + session = get_session() session.run(self._checkpoint.restore_ops, feed_dict=self._feed_dict) def initialize_or_restore(self, session=None): @@ -777,7 +785,7 @@ class CheckpointLoadStatus(_LoadStatus): if context.executing_eagerly(): return # Initialization and restoration ops are run eagerly if session is None: - session = keras_backend.get_session() + session = get_session() all_objects = self._graph_view.list_objects() already_initialized_objects = object_identity.ObjectIdentitySet( self._checkpoint.object_by_proto_id.values()) @@ -855,7 +863,7 @@ class InitializationOnlyStatus(_LoadStatus): if context.executing_eagerly(): return # run eagerly if session is None: - session = keras_backend.get_session() + session = get_session() trackable_objects = self._graph_view.list_objects() initializers = [ c.initializer for c in trackable_objects @@ -937,7 +945,7 @@ class NameBasedSaverStatus(_LoadStatus): if context.executing_eagerly(): return # Nothing to do, variables are restored on creation. if session is None: - session = keras_backend.get_session() + session = get_session() with ops.device("/cpu:0"): saveables = self._gather_saveable_objects() v1_saver_lib.Saver(saveables).restore( @@ -1109,7 +1117,7 @@ class TrackableSaver(object): if not use_session: session = None elif session is None: - session = keras_backend.get_session() + session = get_session() if session: return session.run(save_path, feed_dict=feed_dict) @@ -1492,7 +1500,7 @@ class CheckpointV1(tracking.AutoTrackable): "update metadata. tf.train.latest_checkpoint and related APIs will " "not see this checkpoint.") if session is None: - session = keras_backend.get_session() + session = get_session() if self._save_counter is None: # When graph building, if this is a new save counter variable then it # needs to be initialized before assign_add. This is only an issue if @@ -1822,7 +1830,7 @@ class Checkpoint(tracking.AutoTrackable): "tf.train.Checkpoint.write(), a lower-level API which does not " "update metadata. tf.train.latest_checkpoint and related APIs will " "not see this checkpoint.") - session = keras_backend.get_session() + session = get_session() if self._save_counter is None: # When graph building, if this is a new save counter variable then it # needs to be initialized before assign_add. This is only an issue if diff --git a/tensorflow/python/training/tracking/util_test.py b/tensorflow/python/training/tracking/util_test.py index 1ec5466e0ab..3969d259550 100644 --- a/tensorflow/python/training/tracking/util_test.py +++ b/tensorflow/python/training/tracking/util_test.py @@ -1412,7 +1412,9 @@ class TemplateTests(parameterized.TestCase, test.TestCase): optimizer.minimize(v1_save.read_value, var_list=[v1_save]) self.evaluate([v.initializer for v in save_template.variables]) - self.evaluate([v.initializer for v in optimizer.variables()]) + optimizer_variables = optimizer.variables() + list( + optimizer._hyper.values()) + self.evaluate([v.initializer for v in optimizer_variables]) self.evaluate(v1_save.assign([12.])) self.evaluate(v2_save.assign([14.])) checkpoint_directory = self.get_temp_dir()