Checkpointing: Don't use Keras get_session if we can avoid it since it has side effects
Some people rely on their Keras session being used, but its automatic initialization causes device placement issues. So if a regular default session is registered, we avoid the Keras initialization.

PiperOrigin-RevId: 251353716
parent 2720abdb86
commit eabe66e503
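For context, a minimal graph-mode sketch of the behavior the commit message describes: when a user has registered their own default session, object-based save/restore should run through that session rather than through Keras' global one. This is an illustration only, not part of the change; TF 1.x-style graph execution is assumed, and the variable and checkpoint names are made up.

    # Illustration only: a user-registered default session takes precedence.
    import tensorflow.compat.v1 as tf
    tf.disable_eager_execution()

    v = tf.Variable(1.0)
    ckpt = tf.train.Checkpoint(v=v)        # object-based checkpointing

    with tf.Session() as sess:             # becomes the default session
      sess.run(v.initializer)
      path = ckpt.save("/tmp/example_ckpt")  # hypothetical path
      status = ckpt.restore(path)
      # With this change, the restore ops run in `sess` (the default session)
      # instead of triggering Keras' get_session() and its auto-initialization.
      status.run_restore_ops()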
@@ -64,6 +64,14 @@ keras_backend = lazy_loader.LazyLoader(
     "tensorflow.python.keras.backend")
 
 
+def get_session():
+  # Prefer TF's default session since get_session from Keras has side-effects.
+  session = ops.get_default_session()
+  if session is None:
+    session = keras_backend.get_session()
+  return session
+
+
 class _ObjectGraphProtoPrettyPrinter(object):
   """Lazily traverses an object graph proto to pretty print names.
 
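As a self-contained illustration (not part of the diff) of the fallback order the new helper implements, using the same module paths the code above references; the assertions are only a sketch and TF 1.x-style graph mode is assumed:

    # Sketch: a registered default session wins; otherwise fall back to Keras.
    import tensorflow.compat.v1 as tf
    from tensorflow.python.framework import ops
    from tensorflow.python.keras import backend as keras_backend

    tf.disable_eager_execution()

    def get_session():
      session = ops.get_default_session()
      if session is None:
        session = keras_backend.get_session()
      return session

    with tf.Session() as sess:
      assert get_session() is sess  # the user's default session is preferred

    # With no default session registered, the helper falls back to Keras'
    # cached global session (which still carries its usual side effects).
    assert get_session() is keras_backend.get_session()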
@@ -613,7 +621,7 @@ def streaming_restore(status, session=None):
     # Streaming restore is the default/only behavior when executing eagerly.
     return
   if session is None:
-    session = keras_backend.get_session()
+    session = get_session()
   if isinstance(status, NameBasedSaverStatus):
     raise NotImplementedError(
         "Streaming restore not supported from name-based checkpoints when "
@@ -756,7 +764,7 @@ class CheckpointLoadStatus(_LoadStatus):
     if context.executing_eagerly():
       return  # Run eagerly
     if session is None:
-      session = keras_backend.get_session()
+      session = get_session()
     session.run(self._checkpoint.restore_ops, feed_dict=self._feed_dict)
 
   def initialize_or_restore(self, session=None):
@@ -777,7 +785,7 @@ class CheckpointLoadStatus(_LoadStatus):
     if context.executing_eagerly():
       return  # Initialization and restoration ops are run eagerly
     if session is None:
-      session = keras_backend.get_session()
+      session = get_session()
     all_objects = self._graph_view.list_objects()
     already_initialized_objects = object_identity.ObjectIdentitySet(
         self._checkpoint.object_by_proto_id.values())
@@ -855,7 +863,7 @@ class InitializationOnlyStatus(_LoadStatus):
     if context.executing_eagerly():
       return  # run eagerly
     if session is None:
-      session = keras_backend.get_session()
+      session = get_session()
     trackable_objects = self._graph_view.list_objects()
     initializers = [
         c.initializer for c in trackable_objects
@@ -937,7 +945,7 @@ class NameBasedSaverStatus(_LoadStatus):
     if context.executing_eagerly():
       return  # Nothing to do, variables are restored on creation.
     if session is None:
-      session = keras_backend.get_session()
+      session = get_session()
     with ops.device("/cpu:0"):
       saveables = self._gather_saveable_objects()
       v1_saver_lib.Saver(saveables).restore(
@@ -1109,7 +1117,7 @@ class TrackableSaver(object):
     if not use_session:
       session = None
     elif session is None:
-      session = keras_backend.get_session()
+      session = get_session()
 
     if session:
       return session.run(save_path, feed_dict=feed_dict)
@@ -1492,7 +1500,7 @@ class CheckpointV1(tracking.AutoTrackable):
             "update metadata. tf.train.latest_checkpoint and related APIs will "
             "not see this checkpoint.")
       if session is None:
-        session = keras_backend.get_session()
+        session = get_session()
       if self._save_counter is None:
         # When graph building, if this is a new save counter variable then it
         # needs to be initialized before assign_add. This is only an issue if
@@ -1822,7 +1830,7 @@ class Checkpoint(tracking.AutoTrackable):
             "tf.train.Checkpoint.write(), a lower-level API which does not "
             "update metadata. tf.train.latest_checkpoint and related APIs will "
             "not see this checkpoint.")
-      session = keras_backend.get_session()
+      session = get_session()
       if self._save_counter is None:
         # When graph building, if this is a new save counter variable then it
         # needs to be initialized before assign_add. This is only an issue if
@@ -1412,7 +1412,9 @@ class TemplateTests(parameterized.TestCase, test.TestCase):
       optimizer.minimize(v1_save.read_value,
                          var_list=[v1_save])
       self.evaluate([v.initializer for v in save_template.variables])
-      self.evaluate([v.initializer for v in optimizer.variables()])
+      optimizer_variables = optimizer.variables() + list(
+          optimizer._hyper.values())
+      self.evaluate([v.initializer for v in optimizer_variables])
       self.evaluate(v1_save.assign([12.]))
       self.evaluate(v2_save.assign([14.]))
       checkpoint_directory = self.get_temp_dir()