(1/2) Add option to not export Graph protos when training the Estimator.

Exporting the GraphDef and MetaGraphDef adds a considerable slowdown when working with large graphs. Often, the user is only concerned with the checkpoints, so adding this option will help speed up training.

PiperOrigin-RevId: 274015240
This commit is contained in:
Katherine Wu 2019-10-10 12:32:53 -07:00 committed by TensorFlower Gardener
parent 090c30918e
commit 349e97ed6f
8 changed files with 108 additions and 18 deletions

View File

@@ -520,7 +520,8 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook):
saver=None,
checkpoint_basename="model.ckpt",
scaffold=None,
listeners=None):
listeners=None,
save_graph_def=True):
"""Initializes a `CheckpointSaverHook`.
Args:
@@ -533,6 +534,10 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook):
listeners: List of `CheckpointSaverListener` subclass instances. Used for
callbacks that run immediately before or after this hook saves the
checkpoint.
save_graph_def: Whether to save the GraphDef and MetaGraphDef to
`checkpoint_dir`. The GraphDef is saved after the session is created as
`graph.pbtxt`. MetaGraphDefs are saved out for every checkpoint as
`model.ckpt-*.meta`.
Raises:
ValueError: One of `save_steps` or `save_secs` should be set.
@@ -549,6 +554,7 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook):
every_secs=save_secs, every_steps=save_steps)
self._listeners = listeners or []
self._steps_per_run = 1
self._save_graph_def = save_graph_def
def _set_steps_per_run(self, steps_per_run):
self._steps_per_run = steps_per_run
@@ -564,12 +570,13 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook):
def after_create_session(self, session, coord):
global_step = session.run(self._global_step_tensor)
# We do write graph and saver_def at the first call of before_run.
# We cannot do this in begin, since we let other hooks to change graph and
# add variables in begin. Graph is finalized after all begin calls.
training_util.write_graph(
ops.get_default_graph().as_graph_def(add_shapes=True),
self._checkpoint_dir, "graph.pbtxt")
if self._save_graph_def:
# We do write graph and saver_def at the first call of before_run.
# We cannot do this in begin, since we let other hooks to change graph and
# add variables in begin. Graph is finalized after all begin calls.
training_util.write_graph(
ops.get_default_graph().as_graph_def(add_shapes=True),
self._checkpoint_dir, "graph.pbtxt")
saver_def = self._get_saver().saver_def if self._get_saver() else None
graph = ops.get_default_graph()
meta_graph_def = meta_graph.create_meta_graph_def(
@@ -608,7 +615,8 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook):
for l in self._listeners:
l.before_save(session, step)
self._get_saver().save(session, self._save_path, global_step=step)
self._get_saver().save(session, self._save_path, global_step=step,
write_meta_graph=self._save_graph_def)
self._summary_writer.add_session_log(
SessionLog(
status=SessionLog.CHECKPOINT, checkpoint_path=self._save_path),

View File

@@ -776,6 +776,48 @@ class CheckpointSaverHookTest(test.TestCase):
checkpoint_utils.load_variable(self.model_dir,
self.global_step.name))
def test_save_graph_def(self):
with self.graph.as_default():
hook = basic_session_run_hooks.CheckpointSaverHook(
self.model_dir, save_steps=1, scaffold=self.scaffold,
save_graph_def=True)
hook.begin()
self.scaffold.finalize()
with session_lib.Session() as sess:
sess.run(self.scaffold.init_op)
mon_sess = monitored_session._HookedSession(sess, [hook])
sess.run(self.scaffold.init_op)
hook.after_create_session(sess, None)
self.assertIn('graph.pbtxt', os.listdir(self.model_dir))
# Should have a single .meta file for step 0
self.assertLen(gfile.Glob(os.path.join(self.model_dir, '*.meta')), 1)
mon_sess.run(self.train_op)
self.assertLen(gfile.Glob(os.path.join(self.model_dir, '*.meta')), 2)
def test_save_graph_def_false(self):
with self.graph.as_default():
hook = basic_session_run_hooks.CheckpointSaverHook(
self.model_dir, save_steps=1, scaffold=self.scaffold,
save_graph_def=False)
hook.begin()
self.scaffold.finalize()
with session_lib.Session() as sess:
sess.run(self.scaffold.init_op)
mon_sess = monitored_session._HookedSession(sess, [hook])
sess.run(self.scaffold.init_op)
hook.after_create_session(sess, None)
self.assertNotIn('graph.pbtxt', os.listdir(self.model_dir))
# Should have a single .meta file for step 0
self.assertEmpty(gfile.Glob(os.path.join(self.model_dir, '*.meta')))
mon_sess.run(self.train_op)
self.assertEmpty(gfile.Glob(os.path.join(self.model_dir, '*.meta')))
class CheckpointSaverHookMultiStepTest(test.TestCase):

View File

@@ -332,7 +332,8 @@ def _create_monitored_session_with_worker_context(
log_step_count_steps=100,
max_wait_secs=7200,
save_checkpoint_steps=None,
summary_dir=None):
summary_dir=None,
save_graph_def=True):
all_hooks = []
if hooks:
all_hooks.extend(hooks)
@@ -406,14 +407,16 @@ def _create_monitored_session_with_worker_context(
checkpoint_dir,
save_steps=save_checkpoint_steps,
save_secs=save_checkpoint_secs,
scaffold=scaffold))
scaffold=scaffold,
save_graph_def=save_graph_def))
elif tmpdir:
all_hooks.append(
basic_session_run_hooks.CheckpointSaverHook(
os.path.join(checkpoint_dir, tmpdir),
save_steps=save_checkpoint_steps,
save_secs=save_checkpoint_secs,
scaffold=scaffold))
scaffold=scaffold,
save_graph_def=save_graph_def))
logging.info('all_hooks %r', all_hooks)
session_creator = worker_context.session_creator(
@@ -443,7 +446,8 @@ def MonitoredTrainingSession(
log_step_count_steps=100,
max_wait_secs=7200,
save_checkpoint_steps=USE_DEFAULT,
summary_dir=None):
summary_dir=None,
save_graph_def=True):
"""Creates a `MonitoredSession` for training.
For a chief, this utility sets proper session initializer/restorer. It also
@@ -497,6 +501,10 @@ def MonitoredTrainingSession(
`save_checkpoint_secs` is used. Default not enabled.
summary_dir: A string. Optional path to a directory where to save
summaries. If None, checkpoint_dir is used instead.
save_graph_def: Whether to save the GraphDef and MetaGraphDef to
`checkpoint_dir`. The GraphDef is saved after the session is created as
`graph.pbtxt`. MetaGraphDefs are saved out for every checkpoint as
`model.ckpt-*.meta`.
Returns:
A `MonitoredSession` object.
@@ -536,7 +544,8 @@ def MonitoredTrainingSession(
log_step_count_steps=log_step_count_steps,
max_wait_secs=max_wait_secs,
save_checkpoint_steps=save_checkpoint_steps,
summary_dir=summary_dir)
summary_dir=summary_dir,
save_graph_def=save_graph_def)
if not is_chief:
session_creator = WorkerSessionCreator(
@@ -584,7 +593,8 @@ def MonitoredTrainingSession(
checkpoint_dir,
save_steps=save_checkpoint_steps,
save_secs=save_checkpoint_secs,
scaffold=scaffold))
scaffold=scaffold,
save_graph_def=save_graph_def))
if hooks:
all_hooks.extend(hooks)

View File

@@ -399,6 +399,36 @@ class MonitoredTrainingSessionTest(test.TestCase):
is_chief=True, checkpoint_dir=logdir) as session:
self.assertEqual(0, session.run(gstep))
def test_save_graph_def(self):
logdir = _test_dir(self.get_temp_dir(), 'test_save_graph_def')
with ops.Graph().as_default():
gstep = training_util.get_or_create_global_step()
new_gstep = state_ops.assign_add(gstep, 1)
with monitored_session.MonitoredTrainingSession(
is_chief=True,
checkpoint_dir=logdir,
save_checkpoint_steps=1,
save_graph_def=True) as session:
self.assertIn('graph.pbtxt', os.listdir(logdir))
self.assertLen(glob.glob(os.path.join(logdir, '*.meta')), 1)
session.run(new_gstep)
self.assertLen(glob.glob(os.path.join(logdir, '*.meta')), 2)
def test_save_graph_def_false(self):
logdir = _test_dir(self.get_temp_dir(), 'test_save_graph_def')
with ops.Graph().as_default():
gstep = training_util.get_or_create_global_step()
new_gstep = state_ops.assign_add(gstep, 1)
with monitored_session.MonitoredTrainingSession(
is_chief=True,
checkpoint_dir=logdir,
save_checkpoint_steps=1,
save_graph_def=False) as session:
self.assertNotIn('graph.pbtxt', os.listdir(logdir))
self.assertEmpty(glob.glob(os.path.join(logdir, '*.meta')))
session.run(new_gstep)
self.assertEmpty(glob.glob(os.path.join(logdir, '*.meta')))
class MockExtended(object):

View File

@@ -5,7 +5,7 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'checkpoint_dir\', \'save_secs\', \'save_steps\', \'saver\', \'checkpoint_basename\', \'scaffold\', \'listeners\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'model.ckpt\', \'None\', \'None\'], "
argspec: "args=[\'self\', \'checkpoint_dir\', \'save_secs\', \'save_steps\', \'saver\', \'checkpoint_basename\', \'scaffold\', \'listeners\', \'save_graph_def\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'model.ckpt\', \'None\', \'None\', \'True\'], "
}
member_method {
name: "after_create_session"

View File

@@ -5,7 +5,7 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'checkpoint_dir\', \'save_secs\', \'save_steps\', \'saver\', \'checkpoint_basename\', \'scaffold\', \'listeners\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'model.ckpt\', \'None\', \'None\'], "
argspec: "args=[\'self\', \'checkpoint_dir\', \'save_secs\', \'save_steps\', \'saver\', \'checkpoint_basename\', \'scaffold\', \'listeners\', \'save_graph_def\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'model.ckpt\', \'None\', \'None\', \'True\'], "
}
member_method {
name: "after_create_session"

View File

@@ -250,7 +250,7 @@ tf_module {
}
member_method {
name: "MonitoredTrainingSession"
argspec: "args=[\'master\', \'is_chief\', \'checkpoint_dir\', \'scaffold\', \'hooks\', \'chief_only_hooks\', \'save_checkpoint_secs\', \'save_summaries_steps\', \'save_summaries_secs\', \'config\', \'stop_grace_period_secs\', \'log_step_count_steps\', \'max_wait_secs\', \'save_checkpoint_steps\', \'summary_dir\'], varargs=None, keywords=None, defaults=[\'\', \'True\', \'None\', \'None\', \'None\', \'None\', \'<object object instance>\', \'<object object instance>\', \'<object object instance>\', \'None\', \'120\', \'100\', \'7200\', \'<object object instance>\', \'None\'], "
argspec: "args=[\'master\', \'is_chief\', \'checkpoint_dir\', \'scaffold\', \'hooks\', \'chief_only_hooks\', \'save_checkpoint_secs\', \'save_summaries_steps\', \'save_summaries_secs\', \'config\', \'stop_grace_period_secs\', \'log_step_count_steps\', \'max_wait_secs\', \'save_checkpoint_steps\', \'summary_dir\', \'save_graph_def\'], varargs=None, keywords=None, defaults=[\'\', \'True\', \'None\', \'None\', \'None\', \'None\', \'<object object instance>\', \'<object object instance>\', \'<object object instance>\', \'None\', \'120\', \'100\', \'7200\', \'<object object instance>\', \'None\', \'True\'], "
}
member_method {
name: "NewCheckpointReader"

View File

@@ -5,7 +5,7 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'checkpoint_dir\', \'save_secs\', \'save_steps\', \'saver\', \'checkpoint_basename\', \'scaffold\', \'listeners\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'model.ckpt\', \'None\', \'None\'], "
argspec: "args=[\'self\', \'checkpoint_dir\', \'save_secs\', \'save_steps\', \'saver\', \'checkpoint_basename\', \'scaffold\', \'listeners\', \'save_graph_def\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'model.ckpt\', \'None\', \'None\', \'True\'], "
}
member_method {
name: "after_create_session"