(1/2) Add option to not export Graph protos when training the Estimator.
Exporting the GraphDef and MetaGraphDef causes a considerable slowdown when working with large graphs. Often the user is only concerned with the checkpoints, so this option helps speed up training. PiperOrigin-RevId: 274015240
This commit is contained in:
parent
090c30918e
commit
349e97ed6f
tensorflow
python/training
basic_session_run_hooks.py, basic_session_run_hooks_test.py, monitored_session.py, monitored_session_test.py
tools/api/golden
@ -520,7 +520,8 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook):
|
||||
saver=None,
|
||||
checkpoint_basename="model.ckpt",
|
||||
scaffold=None,
|
||||
listeners=None):
|
||||
listeners=None,
|
||||
save_graph_def=True):
|
||||
"""Initializes a `CheckpointSaverHook`.
|
||||
|
||||
Args:
|
||||
@ -533,6 +534,10 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook):
|
||||
listeners: List of `CheckpointSaverListener` subclass instances. Used for
|
||||
callbacks that run immediately before or after this hook saves the
|
||||
checkpoint.
|
||||
save_graph_def: Whether to save the GraphDef and MetaGraphDef to
|
||||
`checkpoint_dir`. The GraphDef is saved after the session is created as
|
||||
`graph.pbtxt`. MetaGraphDefs are saved out for every checkpoint as
|
||||
`model.ckpt-*.meta`.
|
||||
|
||||
Raises:
|
||||
ValueError: One of `save_steps` or `save_secs` should be set.
|
||||
@ -549,6 +554,7 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook):
|
||||
every_secs=save_secs, every_steps=save_steps)
|
||||
self._listeners = listeners or []
|
||||
self._steps_per_run = 1
|
||||
self._save_graph_def = save_graph_def
|
||||
|
||||
def _set_steps_per_run(self, steps_per_run):
|
||||
self._steps_per_run = steps_per_run
|
||||
@ -564,12 +570,13 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook):
|
||||
|
||||
def after_create_session(self, session, coord):
|
||||
global_step = session.run(self._global_step_tensor)
|
||||
# We do write graph and saver_def at the first call of before_run.
|
||||
# We cannot do this in begin, since we let other hooks to change graph and
|
||||
# add variables in begin. Graph is finalized after all begin calls.
|
||||
training_util.write_graph(
|
||||
ops.get_default_graph().as_graph_def(add_shapes=True),
|
||||
self._checkpoint_dir, "graph.pbtxt")
|
||||
if self._save_graph_def:
|
||||
# We write the graph and saver_def here, right after the session is created.
|
||||
# We cannot do this in begin, since we let other hooks change the graph and
|
||||
# add variables in begin. Graph is finalized after all begin calls.
|
||||
training_util.write_graph(
|
||||
ops.get_default_graph().as_graph_def(add_shapes=True),
|
||||
self._checkpoint_dir, "graph.pbtxt")
|
||||
saver_def = self._get_saver().saver_def if self._get_saver() else None
|
||||
graph = ops.get_default_graph()
|
||||
meta_graph_def = meta_graph.create_meta_graph_def(
|
||||
@ -608,7 +615,8 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook):
|
||||
for l in self._listeners:
|
||||
l.before_save(session, step)
|
||||
|
||||
self._get_saver().save(session, self._save_path, global_step=step)
|
||||
self._get_saver().save(session, self._save_path, global_step=step,
|
||||
write_meta_graph=self._save_graph_def)
|
||||
self._summary_writer.add_session_log(
|
||||
SessionLog(
|
||||
status=SessionLog.CHECKPOINT, checkpoint_path=self._save_path),
|
||||
|
@ -776,6 +776,48 @@ class CheckpointSaverHookTest(test.TestCase):
|
||||
checkpoint_utils.load_variable(self.model_dir,
|
||||
self.global_step.name))
|
||||
|
||||
def test_save_graph_def(self):
|
||||
with self.graph.as_default():
|
||||
hook = basic_session_run_hooks.CheckpointSaverHook(
|
||||
self.model_dir, save_steps=1, scaffold=self.scaffold,
|
||||
save_graph_def=True)
|
||||
hook.begin()
|
||||
self.scaffold.finalize()
|
||||
with session_lib.Session() as sess:
|
||||
sess.run(self.scaffold.init_op)
|
||||
mon_sess = monitored_session._HookedSession(sess, [hook])
|
||||
sess.run(self.scaffold.init_op)
|
||||
hook.after_create_session(sess, None)
|
||||
|
||||
self.assertIn('graph.pbtxt', os.listdir(self.model_dir))
|
||||
# Should have a single .meta file for step 0
|
||||
self.assertLen(gfile.Glob(os.path.join(self.model_dir, '*.meta')), 1)
|
||||
|
||||
mon_sess.run(self.train_op)
|
||||
self.assertLen(gfile.Glob(os.path.join(self.model_dir, '*.meta')), 2)
|
||||
|
||||
def test_save_graph_def_false(self):
|
||||
with self.graph.as_default():
|
||||
hook = basic_session_run_hooks.CheckpointSaverHook(
|
||||
self.model_dir, save_steps=1, scaffold=self.scaffold,
|
||||
save_graph_def=False)
|
||||
hook.begin()
|
||||
self.scaffold.finalize()
|
||||
with session_lib.Session() as sess:
|
||||
sess.run(self.scaffold.init_op)
|
||||
mon_sess = monitored_session._HookedSession(sess, [hook])
|
||||
sess.run(self.scaffold.init_op)
|
||||
hook.after_create_session(sess, None)
|
||||
|
||||
self.assertNotIn('graph.pbtxt', os.listdir(self.model_dir))
|
||||
# Should have no .meta files, since save_graph_def is False
|
||||
self.assertEmpty(gfile.Glob(os.path.join(self.model_dir, '*.meta')))
|
||||
|
||||
mon_sess.run(self.train_op)
|
||||
self.assertEmpty(gfile.Glob(os.path.join(self.model_dir, '*.meta')))
|
||||
|
||||
|
||||
|
||||
|
||||
class CheckpointSaverHookMultiStepTest(test.TestCase):
|
||||
|
||||
|
@ -332,7 +332,8 @@ def _create_monitored_session_with_worker_context(
|
||||
log_step_count_steps=100,
|
||||
max_wait_secs=7200,
|
||||
save_checkpoint_steps=None,
|
||||
summary_dir=None):
|
||||
summary_dir=None,
|
||||
save_graph_def=True):
|
||||
all_hooks = []
|
||||
if hooks:
|
||||
all_hooks.extend(hooks)
|
||||
@ -406,14 +407,16 @@ def _create_monitored_session_with_worker_context(
|
||||
checkpoint_dir,
|
||||
save_steps=save_checkpoint_steps,
|
||||
save_secs=save_checkpoint_secs,
|
||||
scaffold=scaffold))
|
||||
scaffold=scaffold,
|
||||
save_graph_def=save_graph_def))
|
||||
elif tmpdir:
|
||||
all_hooks.append(
|
||||
basic_session_run_hooks.CheckpointSaverHook(
|
||||
os.path.join(checkpoint_dir, tmpdir),
|
||||
save_steps=save_checkpoint_steps,
|
||||
save_secs=save_checkpoint_secs,
|
||||
scaffold=scaffold))
|
||||
scaffold=scaffold,
|
||||
save_graph_def=save_graph_def))
|
||||
|
||||
logging.info('all_hooks %r', all_hooks)
|
||||
session_creator = worker_context.session_creator(
|
||||
@ -443,7 +446,8 @@ def MonitoredTrainingSession(
|
||||
log_step_count_steps=100,
|
||||
max_wait_secs=7200,
|
||||
save_checkpoint_steps=USE_DEFAULT,
|
||||
summary_dir=None):
|
||||
summary_dir=None,
|
||||
save_graph_def=True):
|
||||
"""Creates a `MonitoredSession` for training.
|
||||
|
||||
For a chief, this utility sets proper session initializer/restorer. It also
|
||||
@ -497,6 +501,10 @@ def MonitoredTrainingSession(
|
||||
`save_checkpoint_secs` is used. Default not enabled.
|
||||
summary_dir: A string. Optional path to a directory where to save
|
||||
summaries. If None, checkpoint_dir is used instead.
|
||||
save_graph_def: Whether to save the GraphDef and MetaGraphDef to
|
||||
`checkpoint_dir`. The GraphDef is saved after the session is created as
|
||||
`graph.pbtxt`. MetaGraphDefs are saved out for every checkpoint as
|
||||
`model.ckpt-*.meta`.
|
||||
|
||||
Returns:
|
||||
A `MonitoredSession` object.
|
||||
@ -536,7 +544,8 @@ def MonitoredTrainingSession(
|
||||
log_step_count_steps=log_step_count_steps,
|
||||
max_wait_secs=max_wait_secs,
|
||||
save_checkpoint_steps=save_checkpoint_steps,
|
||||
summary_dir=summary_dir)
|
||||
summary_dir=summary_dir,
|
||||
save_graph_def=save_graph_def)
|
||||
|
||||
if not is_chief:
|
||||
session_creator = WorkerSessionCreator(
|
||||
@ -584,7 +593,8 @@ def MonitoredTrainingSession(
|
||||
checkpoint_dir,
|
||||
save_steps=save_checkpoint_steps,
|
||||
save_secs=save_checkpoint_secs,
|
||||
scaffold=scaffold))
|
||||
scaffold=scaffold,
|
||||
save_graph_def=save_graph_def))
|
||||
|
||||
if hooks:
|
||||
all_hooks.extend(hooks)
|
||||
|
@ -399,6 +399,36 @@ class MonitoredTrainingSessionTest(test.TestCase):
|
||||
is_chief=True, checkpoint_dir=logdir) as session:
|
||||
self.assertEqual(0, session.run(gstep))
|
||||
|
||||
def test_save_graph_def(self):
|
||||
logdir = _test_dir(self.get_temp_dir(), 'test_save_graph_def')
|
||||
with ops.Graph().as_default():
|
||||
gstep = training_util.get_or_create_global_step()
|
||||
new_gstep = state_ops.assign_add(gstep, 1)
|
||||
with monitored_session.MonitoredTrainingSession(
|
||||
is_chief=True,
|
||||
checkpoint_dir=logdir,
|
||||
save_checkpoint_steps=1,
|
||||
save_graph_def=True) as session:
|
||||
self.assertIn('graph.pbtxt', os.listdir(logdir))
|
||||
self.assertLen(glob.glob(os.path.join(logdir, '*.meta')), 1)
|
||||
session.run(new_gstep)
|
||||
self.assertLen(glob.glob(os.path.join(logdir, '*.meta')), 2)
|
||||
|
||||
def test_save_graph_def_false(self):
|
||||
logdir = _test_dir(self.get_temp_dir(), 'test_save_graph_def')
|
||||
with ops.Graph().as_default():
|
||||
gstep = training_util.get_or_create_global_step()
|
||||
new_gstep = state_ops.assign_add(gstep, 1)
|
||||
with monitored_session.MonitoredTrainingSession(
|
||||
is_chief=True,
|
||||
checkpoint_dir=logdir,
|
||||
save_checkpoint_steps=1,
|
||||
save_graph_def=False) as session:
|
||||
self.assertNotIn('graph.pbtxt', os.listdir(logdir))
|
||||
self.assertEmpty(glob.glob(os.path.join(logdir, '*.meta')))
|
||||
session.run(new_gstep)
|
||||
self.assertEmpty(glob.glob(os.path.join(logdir, '*.meta')))
|
||||
|
||||
|
||||
class MockExtended(object):
|
||||
|
||||
|
@ -5,7 +5,7 @@ tf_class {
|
||||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'checkpoint_dir\', \'save_secs\', \'save_steps\', \'saver\', \'checkpoint_basename\', \'scaffold\', \'listeners\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'model.ckpt\', \'None\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'checkpoint_dir\', \'save_secs\', \'save_steps\', \'saver\', \'checkpoint_basename\', \'scaffold\', \'listeners\', \'save_graph_def\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'model.ckpt\', \'None\', \'None\', \'True\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "after_create_session"
|
||||
|
@ -5,7 +5,7 @@ tf_class {
|
||||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'checkpoint_dir\', \'save_secs\', \'save_steps\', \'saver\', \'checkpoint_basename\', \'scaffold\', \'listeners\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'model.ckpt\', \'None\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'checkpoint_dir\', \'save_secs\', \'save_steps\', \'saver\', \'checkpoint_basename\', \'scaffold\', \'listeners\', \'save_graph_def\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'model.ckpt\', \'None\', \'None\', \'True\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "after_create_session"
|
||||
|
@ -250,7 +250,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "MonitoredTrainingSession"
|
||||
argspec: "args=[\'master\', \'is_chief\', \'checkpoint_dir\', \'scaffold\', \'hooks\', \'chief_only_hooks\', \'save_checkpoint_secs\', \'save_summaries_steps\', \'save_summaries_secs\', \'config\', \'stop_grace_period_secs\', \'log_step_count_steps\', \'max_wait_secs\', \'save_checkpoint_steps\', \'summary_dir\'], varargs=None, keywords=None, defaults=[\'\', \'True\', \'None\', \'None\', \'None\', \'None\', \'<object object instance>\', \'<object object instance>\', \'<object object instance>\', \'None\', \'120\', \'100\', \'7200\', \'<object object instance>\', \'None\'], "
|
||||
argspec: "args=[\'master\', \'is_chief\', \'checkpoint_dir\', \'scaffold\', \'hooks\', \'chief_only_hooks\', \'save_checkpoint_secs\', \'save_summaries_steps\', \'save_summaries_secs\', \'config\', \'stop_grace_period_secs\', \'log_step_count_steps\', \'max_wait_secs\', \'save_checkpoint_steps\', \'summary_dir\', \'save_graph_def\'], varargs=None, keywords=None, defaults=[\'\', \'True\', \'None\', \'None\', \'None\', \'None\', \'<object object instance>\', \'<object object instance>\', \'<object object instance>\', \'None\', \'120\', \'100\', \'7200\', \'<object object instance>\', \'None\', \'True\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "NewCheckpointReader"
|
||||
|
@ -5,7 +5,7 @@ tf_class {
|
||||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'checkpoint_dir\', \'save_secs\', \'save_steps\', \'saver\', \'checkpoint_basename\', \'scaffold\', \'listeners\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'model.ckpt\', \'None\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'checkpoint_dir\', \'save_secs\', \'save_steps\', \'saver\', \'checkpoint_basename\', \'scaffold\', \'listeners\', \'save_graph_def\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'model.ckpt\', \'None\', \'None\', \'True\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "after_create_session"
|
||||
|
Loading…
Reference in New Issue
Block a user