Checkpointing: Don't use Keras get_session if we can avoid it since it has side effects

Some people rely on their Keras session being used, but automatic initialization is causing device placement issues. So if there's a regular session registered we'll avoid the Keras initialization.

PiperOrigin-RevId: 251353716
This commit is contained in:
Allen Lavoie 2019-06-03 18:47:25 -07:00 committed by TensorFlower Gardener
parent 2720abdb86
commit eabe66e503
2 changed files with 19 additions and 9 deletions

View File

@ -64,6 +64,14 @@ keras_backend = lazy_loader.LazyLoader(
"tensorflow.python.keras.backend")
def get_session():
# Prefer TF's default session since get_session from Keras has side-effects.
session = ops.get_default_session()
if session is None:
session = keras_backend.get_session()
return session
class _ObjectGraphProtoPrettyPrinter(object):
"""Lazily traverses an object graph proto to pretty print names.
@ -613,7 +621,7 @@ def streaming_restore(status, session=None):
# Streaming restore is the default/only behavior when executing eagerly.
return
if session is None:
session = keras_backend.get_session()
session = get_session()
if isinstance(status, NameBasedSaverStatus):
raise NotImplementedError(
"Streaming restore not supported from name-based checkpoints when "
@ -756,7 +764,7 @@ class CheckpointLoadStatus(_LoadStatus):
if context.executing_eagerly():
return # Run eagerly
if session is None:
session = keras_backend.get_session()
session = get_session()
session.run(self._checkpoint.restore_ops, feed_dict=self._feed_dict)
def initialize_or_restore(self, session=None):
@ -777,7 +785,7 @@ class CheckpointLoadStatus(_LoadStatus):
if context.executing_eagerly():
return # Initialization and restoration ops are run eagerly
if session is None:
session = keras_backend.get_session()
session = get_session()
all_objects = self._graph_view.list_objects()
already_initialized_objects = object_identity.ObjectIdentitySet(
self._checkpoint.object_by_proto_id.values())
@ -855,7 +863,7 @@ class InitializationOnlyStatus(_LoadStatus):
if context.executing_eagerly():
return # run eagerly
if session is None:
session = keras_backend.get_session()
session = get_session()
trackable_objects = self._graph_view.list_objects()
initializers = [
c.initializer for c in trackable_objects
@ -937,7 +945,7 @@ class NameBasedSaverStatus(_LoadStatus):
if context.executing_eagerly():
return # Nothing to do, variables are restored on creation.
if session is None:
session = keras_backend.get_session()
session = get_session()
with ops.device("/cpu:0"):
saveables = self._gather_saveable_objects()
v1_saver_lib.Saver(saveables).restore(
@ -1109,7 +1117,7 @@ class TrackableSaver(object):
if not use_session:
session = None
elif session is None:
session = keras_backend.get_session()
session = get_session()
if session:
return session.run(save_path, feed_dict=feed_dict)
@ -1492,7 +1500,7 @@ class CheckpointV1(tracking.AutoTrackable):
"update metadata. tf.train.latest_checkpoint and related APIs will "
"not see this checkpoint.")
if session is None:
session = keras_backend.get_session()
session = get_session()
if self._save_counter is None:
# When graph building, if this is a new save counter variable then it
# needs to be initialized before assign_add. This is only an issue if
@ -1822,7 +1830,7 @@ class Checkpoint(tracking.AutoTrackable):
"tf.train.Checkpoint.write(), a lower-level API which does not "
"update metadata. tf.train.latest_checkpoint and related APIs will "
"not see this checkpoint.")
session = keras_backend.get_session()
session = get_session()
if self._save_counter is None:
# When graph building, if this is a new save counter variable then it
# needs to be initialized before assign_add. This is only an issue if

View File

@ -1412,7 +1412,9 @@ class TemplateTests(parameterized.TestCase, test.TestCase):
optimizer.minimize(v1_save.read_value,
var_list=[v1_save])
self.evaluate([v.initializer for v in save_template.variables])
self.evaluate([v.initializer for v in optimizer.variables()])
optimizer_variables = optimizer.variables() + list(
optimizer._hyper.values())
self.evaluate([v.initializer for v in optimizer_variables])
self.evaluate(v1_save.assign([12.]))
self.evaluate(v2_save.assign([14.]))
checkpoint_directory = self.get_temp_dir()