From eabe66e503eac2a253dfee441068e9ff662940d9 Mon Sep 17 00:00:00 2001 From: Allen Lavoie Date: Mon, 3 Jun 2019 18:47:25 -0700 Subject: [PATCH] Checkpointing: Don't use Keras get_session if we can avoid it since it has side effects Some people rely on their Keras session being used, but automatic initialization is causing device placement issues. So if there's a regular session registered we'll avoid the Keras initialization. PiperOrigin-RevId: 251353716 --- tensorflow/python/training/tracking/util.py | 24 ++++++++++++------- .../python/training/tracking/util_test.py | 4 +++- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/tensorflow/python/training/tracking/util.py b/tensorflow/python/training/tracking/util.py index 7a4ad538eca..b0554faef70 100644 --- a/tensorflow/python/training/tracking/util.py +++ b/tensorflow/python/training/tracking/util.py @@ -64,6 +64,14 @@ keras_backend = lazy_loader.LazyLoader( "tensorflow.python.keras.backend") +def get_session(): + # Prefer TF's default session since get_session from Keras has side-effects. + session = ops.get_default_session() + if session is None: + session = keras_backend.get_session() + return session + + class _ObjectGraphProtoPrettyPrinter(object): """Lazily traverses an object graph proto to pretty print names. @@ -613,7 +621,7 @@ def streaming_restore(status, session=None): # Streaming restore is the default/only behavior when executing eagerly. return if session is None: - session = keras_backend.get_session() + session = get_session() if isinstance(status, NameBasedSaverStatus): raise NotImplementedError( "Streaming restore not supported from name-based checkpoints when " @@ -756,7 +764,7 @@ class CheckpointLoadStatus(_LoadStatus): if context.executing_eagerly(): return # Run eagerly if session is None: - session = keras_backend.get_session() + session = get_session() session.run(self._checkpoint.restore_ops, feed_dict=self._feed_dict) def initialize_or_restore(self, session=None): @@ -777,7 +785,7 @@ class CheckpointLoadStatus(_LoadStatus): if context.executing_eagerly(): return # Initialization and restoration ops are run eagerly if session is None: - session = keras_backend.get_session() + session = get_session() all_objects = self._graph_view.list_objects() already_initialized_objects = object_identity.ObjectIdentitySet( self._checkpoint.object_by_proto_id.values()) @@ -855,7 +863,7 @@ class InitializationOnlyStatus(_LoadStatus): if context.executing_eagerly(): return # run eagerly if session is None: - session = keras_backend.get_session() + session = get_session() trackable_objects = self._graph_view.list_objects() initializers = [ c.initializer for c in trackable_objects @@ -937,7 +945,7 @@ class NameBasedSaverStatus(_LoadStatus): if context.executing_eagerly(): return # Nothing to do, variables are restored on creation. if session is None: - session = keras_backend.get_session() + session = get_session() with ops.device("/cpu:0"): saveables = self._gather_saveable_objects() v1_saver_lib.Saver(saveables).restore( @@ -1109,7 +1117,7 @@ class TrackableSaver(object): if not use_session: session = None elif session is None: - session = keras_backend.get_session() + session = get_session() if session: return session.run(save_path, feed_dict=feed_dict) @@ -1492,7 +1500,7 @@ class CheckpointV1(tracking.AutoTrackable): "update metadata. tf.train.latest_checkpoint and related APIs will " "not see this checkpoint.") if session is None: - session = keras_backend.get_session() + session = get_session() if self._save_counter is None: # When graph building, if this is a new save counter variable then it # needs to be initialized before assign_add. This is only an issue if @@ -1822,7 +1830,7 @@ class Checkpoint(tracking.AutoTrackable): "tf.train.Checkpoint.write(), a lower-level API which does not " "update metadata. tf.train.latest_checkpoint and related APIs will " "not see this checkpoint.") - session = keras_backend.get_session() + session = get_session() if self._save_counter is None: # When graph building, if this is a new save counter variable then it # needs to be initialized before assign_add. This is only an issue if diff --git a/tensorflow/python/training/tracking/util_test.py b/tensorflow/python/training/tracking/util_test.py index 1ec5466e0ab..3969d259550 100644 --- a/tensorflow/python/training/tracking/util_test.py +++ b/tensorflow/python/training/tracking/util_test.py @@ -1412,7 +1412,9 @@ class TemplateTests(parameterized.TestCase, test.TestCase): optimizer.minimize(v1_save.read_value, var_list=[v1_save]) self.evaluate([v.initializer for v in save_template.variables]) - self.evaluate([v.initializer for v in optimizer.variables()]) + optimizer_variables = optimizer.variables() + list( + optimizer._hyper.values()) + self.evaluate([v.initializer for v in optimizer_variables]) self.evaluate(v1_save.assign([12.])) self.evaluate(v2_save.assign([14.])) checkpoint_directory = self.get_temp_dir()