Checkpointing: Don't use Keras get_session if we can avoid it since it has side effects

Some people rely on their Keras session being used, but automatic initialization is causing device placement issues. So if there's a regular session registered we'll avoid the Keras initialization.

PiperOrigin-RevId: 251353716
This commit is contained in:
Allen Lavoie 2019-06-03 18:47:25 -07:00 committed by TensorFlower Gardener
parent 2720abdb86
commit eabe66e503
2 changed files with 19 additions and 9 deletions

View File

@ -64,6 +64,14 @@ keras_backend = lazy_loader.LazyLoader(
"tensorflow.python.keras.backend") "tensorflow.python.keras.backend")
def get_session():
# Prefer TF's default session since get_session from Keras has side-effects.
session = ops.get_default_session()
if session is None:
session = keras_backend.get_session()
return session
class _ObjectGraphProtoPrettyPrinter(object): class _ObjectGraphProtoPrettyPrinter(object):
"""Lazily traverses an object graph proto to pretty print names. """Lazily traverses an object graph proto to pretty print names.
@ -613,7 +621,7 @@ def streaming_restore(status, session=None):
# Streaming restore is the default/only behavior when executing eagerly. # Streaming restore is the default/only behavior when executing eagerly.
return return
if session is None: if session is None:
session = keras_backend.get_session() session = get_session()
if isinstance(status, NameBasedSaverStatus): if isinstance(status, NameBasedSaverStatus):
raise NotImplementedError( raise NotImplementedError(
"Streaming restore not supported from name-based checkpoints when " "Streaming restore not supported from name-based checkpoints when "
@ -756,7 +764,7 @@ class CheckpointLoadStatus(_LoadStatus):
if context.executing_eagerly(): if context.executing_eagerly():
return # Run eagerly return # Run eagerly
if session is None: if session is None:
session = keras_backend.get_session() session = get_session()
session.run(self._checkpoint.restore_ops, feed_dict=self._feed_dict) session.run(self._checkpoint.restore_ops, feed_dict=self._feed_dict)
def initialize_or_restore(self, session=None): def initialize_or_restore(self, session=None):
@ -777,7 +785,7 @@ class CheckpointLoadStatus(_LoadStatus):
if context.executing_eagerly(): if context.executing_eagerly():
return # Initialization and restoration ops are run eagerly return # Initialization and restoration ops are run eagerly
if session is None: if session is None:
session = keras_backend.get_session() session = get_session()
all_objects = self._graph_view.list_objects() all_objects = self._graph_view.list_objects()
already_initialized_objects = object_identity.ObjectIdentitySet( already_initialized_objects = object_identity.ObjectIdentitySet(
self._checkpoint.object_by_proto_id.values()) self._checkpoint.object_by_proto_id.values())
@ -855,7 +863,7 @@ class InitializationOnlyStatus(_LoadStatus):
if context.executing_eagerly(): if context.executing_eagerly():
return # run eagerly return # run eagerly
if session is None: if session is None:
session = keras_backend.get_session() session = get_session()
trackable_objects = self._graph_view.list_objects() trackable_objects = self._graph_view.list_objects()
initializers = [ initializers = [
c.initializer for c in trackable_objects c.initializer for c in trackable_objects
@ -937,7 +945,7 @@ class NameBasedSaverStatus(_LoadStatus):
if context.executing_eagerly(): if context.executing_eagerly():
return # Nothing to do, variables are restored on creation. return # Nothing to do, variables are restored on creation.
if session is None: if session is None:
session = keras_backend.get_session() session = get_session()
with ops.device("/cpu:0"): with ops.device("/cpu:0"):
saveables = self._gather_saveable_objects() saveables = self._gather_saveable_objects()
v1_saver_lib.Saver(saveables).restore( v1_saver_lib.Saver(saveables).restore(
@ -1109,7 +1117,7 @@ class TrackableSaver(object):
if not use_session: if not use_session:
session = None session = None
elif session is None: elif session is None:
session = keras_backend.get_session() session = get_session()
if session: if session:
return session.run(save_path, feed_dict=feed_dict) return session.run(save_path, feed_dict=feed_dict)
@ -1492,7 +1500,7 @@ class CheckpointV1(tracking.AutoTrackable):
"update metadata. tf.train.latest_checkpoint and related APIs will " "update metadata. tf.train.latest_checkpoint and related APIs will "
"not see this checkpoint.") "not see this checkpoint.")
if session is None: if session is None:
session = keras_backend.get_session() session = get_session()
if self._save_counter is None: if self._save_counter is None:
# When graph building, if this is a new save counter variable then it # When graph building, if this is a new save counter variable then it
# needs to be initialized before assign_add. This is only an issue if # needs to be initialized before assign_add. This is only an issue if
@ -1822,7 +1830,7 @@ class Checkpoint(tracking.AutoTrackable):
"tf.train.Checkpoint.write(), a lower-level API which does not " "tf.train.Checkpoint.write(), a lower-level API which does not "
"update metadata. tf.train.latest_checkpoint and related APIs will " "update metadata. tf.train.latest_checkpoint and related APIs will "
"not see this checkpoint.") "not see this checkpoint.")
session = keras_backend.get_session() session = get_session()
if self._save_counter is None: if self._save_counter is None:
# When graph building, if this is a new save counter variable then it # When graph building, if this is a new save counter variable then it
# needs to be initialized before assign_add. This is only an issue if # needs to be initialized before assign_add. This is only an issue if

View File

@ -1412,7 +1412,9 @@ class TemplateTests(parameterized.TestCase, test.TestCase):
optimizer.minimize(v1_save.read_value, optimizer.minimize(v1_save.read_value,
var_list=[v1_save]) var_list=[v1_save])
self.evaluate([v.initializer for v in save_template.variables]) self.evaluate([v.initializer for v in save_template.variables])
self.evaluate([v.initializer for v in optimizer.variables()]) optimizer_variables = optimizer.variables() + list(
optimizer._hyper.values())
self.evaluate([v.initializer for v in optimizer_variables])
self.evaluate(v1_save.assign([12.])) self.evaluate(v1_save.assign([12.]))
self.evaluate(v2_save.assign([14.])) self.evaluate(v2_save.assign([14.]))
checkpoint_directory = self.get_temp_dir() checkpoint_directory = self.get_temp_dir()