Checkpointing: Don't use Keras get_session if we can avoid it since it has side effects
Some people rely on their Keras session being used, but the automatic initialization performed by Keras' get_session is causing device placement issues. So if a regular (default) TensorFlow session is already registered, we use it and avoid the Keras initialization.

PiperOrigin-RevId: 251353716
This commit is contained in:
parent
2720abdb86
commit
eabe66e503
@ -64,6 +64,14 @@ keras_backend = lazy_loader.LazyLoader(
|
|||||||
"tensorflow.python.keras.backend")
|
"tensorflow.python.keras.backend")
|
||||||
|
|
||||||
|
|
||||||
|
def get_session():
  """Return the session that checkpointing ops should run in.

  TF's registered default session is used whenever one exists. The Keras
  backend's `get_session` is consulted only as a fallback, because calling
  it has side effects.

  Returns:
    The result of `ops.get_default_session()` if it is not None, otherwise
    the result of `keras_backend.get_session()`.
  """
  default_session = ops.get_default_session()
  if default_session is not None:
    # A regular session is registered; avoid touching Keras at all.
    return default_session
  return keras_backend.get_session()
|
||||||
|
|
||||||
|
|
||||||
class _ObjectGraphProtoPrettyPrinter(object):
|
class _ObjectGraphProtoPrettyPrinter(object):
|
||||||
"""Lazily traverses an object graph proto to pretty print names.
|
"""Lazily traverses an object graph proto to pretty print names.
|
||||||
|
|
||||||
@ -613,7 +621,7 @@ def streaming_restore(status, session=None):
|
|||||||
# Streaming restore is the default/only behavior when executing eagerly.
|
# Streaming restore is the default/only behavior when executing eagerly.
|
||||||
return
|
return
|
||||||
if session is None:
|
if session is None:
|
||||||
session = keras_backend.get_session()
|
session = get_session()
|
||||||
if isinstance(status, NameBasedSaverStatus):
|
if isinstance(status, NameBasedSaverStatus):
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Streaming restore not supported from name-based checkpoints when "
|
"Streaming restore not supported from name-based checkpoints when "
|
||||||
@ -756,7 +764,7 @@ class CheckpointLoadStatus(_LoadStatus):
|
|||||||
if context.executing_eagerly():
|
if context.executing_eagerly():
|
||||||
return # Run eagerly
|
return # Run eagerly
|
||||||
if session is None:
|
if session is None:
|
||||||
session = keras_backend.get_session()
|
session = get_session()
|
||||||
session.run(self._checkpoint.restore_ops, feed_dict=self._feed_dict)
|
session.run(self._checkpoint.restore_ops, feed_dict=self._feed_dict)
|
||||||
|
|
||||||
def initialize_or_restore(self, session=None):
|
def initialize_or_restore(self, session=None):
|
||||||
@ -777,7 +785,7 @@ class CheckpointLoadStatus(_LoadStatus):
|
|||||||
if context.executing_eagerly():
|
if context.executing_eagerly():
|
||||||
return # Initialization and restoration ops are run eagerly
|
return # Initialization and restoration ops are run eagerly
|
||||||
if session is None:
|
if session is None:
|
||||||
session = keras_backend.get_session()
|
session = get_session()
|
||||||
all_objects = self._graph_view.list_objects()
|
all_objects = self._graph_view.list_objects()
|
||||||
already_initialized_objects = object_identity.ObjectIdentitySet(
|
already_initialized_objects = object_identity.ObjectIdentitySet(
|
||||||
self._checkpoint.object_by_proto_id.values())
|
self._checkpoint.object_by_proto_id.values())
|
||||||
@ -855,7 +863,7 @@ class InitializationOnlyStatus(_LoadStatus):
|
|||||||
if context.executing_eagerly():
|
if context.executing_eagerly():
|
||||||
return # run eagerly
|
return # run eagerly
|
||||||
if session is None:
|
if session is None:
|
||||||
session = keras_backend.get_session()
|
session = get_session()
|
||||||
trackable_objects = self._graph_view.list_objects()
|
trackable_objects = self._graph_view.list_objects()
|
||||||
initializers = [
|
initializers = [
|
||||||
c.initializer for c in trackable_objects
|
c.initializer for c in trackable_objects
|
||||||
@ -937,7 +945,7 @@ class NameBasedSaverStatus(_LoadStatus):
|
|||||||
if context.executing_eagerly():
|
if context.executing_eagerly():
|
||||||
return # Nothing to do, variables are restored on creation.
|
return # Nothing to do, variables are restored on creation.
|
||||||
if session is None:
|
if session is None:
|
||||||
session = keras_backend.get_session()
|
session = get_session()
|
||||||
with ops.device("/cpu:0"):
|
with ops.device("/cpu:0"):
|
||||||
saveables = self._gather_saveable_objects()
|
saveables = self._gather_saveable_objects()
|
||||||
v1_saver_lib.Saver(saveables).restore(
|
v1_saver_lib.Saver(saveables).restore(
|
||||||
@ -1109,7 +1117,7 @@ class TrackableSaver(object):
|
|||||||
if not use_session:
|
if not use_session:
|
||||||
session = None
|
session = None
|
||||||
elif session is None:
|
elif session is None:
|
||||||
session = keras_backend.get_session()
|
session = get_session()
|
||||||
|
|
||||||
if session:
|
if session:
|
||||||
return session.run(save_path, feed_dict=feed_dict)
|
return session.run(save_path, feed_dict=feed_dict)
|
||||||
@ -1492,7 +1500,7 @@ class CheckpointV1(tracking.AutoTrackable):
|
|||||||
"update metadata. tf.train.latest_checkpoint and related APIs will "
|
"update metadata. tf.train.latest_checkpoint and related APIs will "
|
||||||
"not see this checkpoint.")
|
"not see this checkpoint.")
|
||||||
if session is None:
|
if session is None:
|
||||||
session = keras_backend.get_session()
|
session = get_session()
|
||||||
if self._save_counter is None:
|
if self._save_counter is None:
|
||||||
# When graph building, if this is a new save counter variable then it
|
# When graph building, if this is a new save counter variable then it
|
||||||
# needs to be initialized before assign_add. This is only an issue if
|
# needs to be initialized before assign_add. This is only an issue if
|
||||||
@ -1822,7 +1830,7 @@ class Checkpoint(tracking.AutoTrackable):
|
|||||||
"tf.train.Checkpoint.write(), a lower-level API which does not "
|
"tf.train.Checkpoint.write(), a lower-level API which does not "
|
||||||
"update metadata. tf.train.latest_checkpoint and related APIs will "
|
"update metadata. tf.train.latest_checkpoint and related APIs will "
|
||||||
"not see this checkpoint.")
|
"not see this checkpoint.")
|
||||||
session = keras_backend.get_session()
|
session = get_session()
|
||||||
if self._save_counter is None:
|
if self._save_counter is None:
|
||||||
# When graph building, if this is a new save counter variable then it
|
# When graph building, if this is a new save counter variable then it
|
||||||
# needs to be initialized before assign_add. This is only an issue if
|
# needs to be initialized before assign_add. This is only an issue if
|
||||||
|
@ -1412,7 +1412,9 @@ class TemplateTests(parameterized.TestCase, test.TestCase):
|
|||||||
optimizer.minimize(v1_save.read_value,
|
optimizer.minimize(v1_save.read_value,
|
||||||
var_list=[v1_save])
|
var_list=[v1_save])
|
||||||
self.evaluate([v.initializer for v in save_template.variables])
|
self.evaluate([v.initializer for v in save_template.variables])
|
||||||
self.evaluate([v.initializer for v in optimizer.variables()])
|
optimizer_variables = optimizer.variables() + list(
|
||||||
|
optimizer._hyper.values())
|
||||||
|
self.evaluate([v.initializer for v in optimizer_variables])
|
||||||
self.evaluate(v1_save.assign([12.]))
|
self.evaluate(v1_save.assign([12.]))
|
||||||
self.evaluate(v2_save.assign([14.]))
|
self.evaluate(v2_save.assign([14.]))
|
||||||
checkpoint_directory = self.get_temp_dir()
|
checkpoint_directory = self.get_temp_dir()
|
||||||
|
Loading…
Reference in New Issue
Block a user