diff --git a/tensorflow/python/training/basic_session_run_hooks.py b/tensorflow/python/training/basic_session_run_hooks.py index de1aa4921d8..f458b43381e 100644 --- a/tensorflow/python/training/basic_session_run_hooks.py +++ b/tensorflow/python/training/basic_session_run_hooks.py @@ -520,7 +520,8 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook): saver=None, checkpoint_basename="model.ckpt", scaffold=None, - listeners=None): + listeners=None, + save_graph_def=True): """Initializes a `CheckpointSaverHook`. Args: @@ -533,6 +534,10 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook): listeners: List of `CheckpointSaverListener` subclass instances. Used for callbacks that run immediately before or after this hook saves the checkpoint. + save_graph_def: Whether to save the GraphDef and MetaGraphDef to + `checkpoint_dir`. The GraphDef is saved after the session is created as + `graph.pbtxt`. MetaGraphDefs are saved out for every checkpoint as + `model.ckpt-*.meta`. Raises: ValueError: One of `save_steps` or `save_secs` should be set. @@ -549,6 +554,7 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook): every_secs=save_secs, every_steps=save_steps) self._listeners = listeners or [] self._steps_per_run = 1 + self._save_graph_def = save_graph_def def _set_steps_per_run(self, steps_per_run): self._steps_per_run = steps_per_run @@ -564,12 +570,13 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook): def after_create_session(self, session, coord): global_step = session.run(self._global_step_tensor) - # We do write graph and saver_def at the first call of before_run. - # We cannot do this in begin, since we let other hooks to change graph and - # add variables in begin. Graph is finalized after all begin calls. - training_util.write_graph( - ops.get_default_graph().as_graph_def(add_shapes=True), - self._checkpoint_dir, "graph.pbtxt") + if self._save_graph_def: + # We do write graph and saver_def at the first call of before_run. 
+ # We cannot do this in begin, since we let other hooks to change graph and + # add variables in begin. Graph is finalized after all begin calls. + training_util.write_graph( + ops.get_default_graph().as_graph_def(add_shapes=True), + self._checkpoint_dir, "graph.pbtxt") saver_def = self._get_saver().saver_def if self._get_saver() else None graph = ops.get_default_graph() meta_graph_def = meta_graph.create_meta_graph_def( @@ -608,7 +615,8 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook): for l in self._listeners: l.before_save(session, step) - self._get_saver().save(session, self._save_path, global_step=step) + self._get_saver().save(session, self._save_path, global_step=step, + write_meta_graph=self._save_graph_def) self._summary_writer.add_session_log( SessionLog( status=SessionLog.CHECKPOINT, checkpoint_path=self._save_path), diff --git a/tensorflow/python/training/basic_session_run_hooks_test.py b/tensorflow/python/training/basic_session_run_hooks_test.py index 3e1ccfed0dc..678fea89f9e 100644 --- a/tensorflow/python/training/basic_session_run_hooks_test.py +++ b/tensorflow/python/training/basic_session_run_hooks_test.py @@ -776,6 +776,48 @@ class CheckpointSaverHookTest(test.TestCase): checkpoint_utils.load_variable(self.model_dir, self.global_step.name)) + def test_save_graph_def(self): + with self.graph.as_default(): + hook = basic_session_run_hooks.CheckpointSaverHook( + self.model_dir, save_steps=1, scaffold=self.scaffold, + save_graph_def=True) + hook.begin() + self.scaffold.finalize() + with session_lib.Session() as sess: + sess.run(self.scaffold.init_op) + mon_sess = monitored_session._HookedSession(sess, [hook]) + sess.run(self.scaffold.init_op) + hook.after_create_session(sess, None) + + self.assertIn('graph.pbtxt', os.listdir(self.model_dir)) + # Should have a single .meta file for step 0 + self.assertLen(gfile.Glob(os.path.join(self.model_dir, '*.meta')), 1) + + mon_sess.run(self.train_op) + 
self.assertLen(gfile.Glob(os.path.join(self.model_dir, '*.meta')), 2) + + def test_save_graph_def_false(self): + with self.graph.as_default(): + hook = basic_session_run_hooks.CheckpointSaverHook( + self.model_dir, save_steps=1, scaffold=self.scaffold, + save_graph_def=False) + hook.begin() + self.scaffold.finalize() + with session_lib.Session() as sess: + sess.run(self.scaffold.init_op) + mon_sess = monitored_session._HookedSession(sess, [hook]) + sess.run(self.scaffold.init_op) + hook.after_create_session(sess, None) + + self.assertNotIn('graph.pbtxt', os.listdir(self.model_dir)) + # No .meta file should be written when save_graph_def is False + self.assertEmpty(gfile.Glob(os.path.join(self.model_dir, '*.meta'))) + + mon_sess.run(self.train_op) + self.assertEmpty(gfile.Glob(os.path.join(self.model_dir, '*.meta'))) + + + class CheckpointSaverHookMultiStepTest(test.TestCase): diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py index 3e1c3e9f73f..d77278e98f4 100644 --- a/tensorflow/python/training/monitored_session.py +++ b/tensorflow/python/training/monitored_session.py @@ -332,7 +332,8 @@ def _create_monitored_session_with_worker_context( log_step_count_steps=100, max_wait_secs=7200, save_checkpoint_steps=None, - summary_dir=None): + summary_dir=None, + save_graph_def=True): all_hooks = [] if hooks: all_hooks.extend(hooks) @@ -406,14 +407,16 @@ def _create_monitored_session_with_worker_context( checkpoint_dir, save_steps=save_checkpoint_steps, save_secs=save_checkpoint_secs, - scaffold=scaffold)) + scaffold=scaffold, + save_graph_def=save_graph_def)) elif tmpdir: all_hooks.append( basic_session_run_hooks.CheckpointSaverHook( os.path.join(checkpoint_dir, tmpdir), save_steps=save_checkpoint_steps, save_secs=save_checkpoint_secs, - scaffold=scaffold)) + scaffold=scaffold, + save_graph_def=save_graph_def)) logging.info('all_hooks %r', all_hooks) session_creator = worker_context.session_creator( @@ -443,7 +446,8 @@ def 
MonitoredTrainingSession( log_step_count_steps=100, max_wait_secs=7200, save_checkpoint_steps=USE_DEFAULT, - summary_dir=None): + summary_dir=None, + save_graph_def=True): """Creates a `MonitoredSession` for training. For a chief, this utility sets proper session initializer/restorer. It also @@ -497,6 +501,10 @@ def MonitoredTrainingSession( `save_checkpoint_secs` is used. Default not enabled. summary_dir: A string. Optional path to a directory where to save summaries. If None, checkpoint_dir is used instead. + save_graph_def: Whether to save the GraphDef and MetaGraphDef to + `checkpoint_dir`. The GraphDef is saved after the session is created as + `graph.pbtxt`. MetaGraphDefs are saved out for every checkpoint as + `model.ckpt-*.meta`. Returns: A `MonitoredSession` object. @@ -536,7 +544,8 @@ def MonitoredTrainingSession( log_step_count_steps=log_step_count_steps, max_wait_secs=max_wait_secs, save_checkpoint_steps=save_checkpoint_steps, - summary_dir=summary_dir) + summary_dir=summary_dir, + save_graph_def=save_graph_def) if not is_chief: session_creator = WorkerSessionCreator( @@ -584,7 +593,8 @@ def MonitoredTrainingSession( checkpoint_dir, save_steps=save_checkpoint_steps, save_secs=save_checkpoint_secs, - scaffold=scaffold)) + scaffold=scaffold, + save_graph_def=save_graph_def)) if hooks: all_hooks.extend(hooks) diff --git a/tensorflow/python/training/monitored_session_test.py b/tensorflow/python/training/monitored_session_test.py index bf9d3a616c4..ee4105299ef 100644 --- a/tensorflow/python/training/monitored_session_test.py +++ b/tensorflow/python/training/monitored_session_test.py @@ -399,6 +399,36 @@ class MonitoredTrainingSessionTest(test.TestCase): is_chief=True, checkpoint_dir=logdir) as session: self.assertEqual(0, session.run(gstep)) + def test_save_graph_def(self): + logdir = _test_dir(self.get_temp_dir(), 'test_save_graph_def') + with ops.Graph().as_default(): + gstep = training_util.get_or_create_global_step() + new_gstep = 
state_ops.assign_add(gstep, 1) + with monitored_session.MonitoredTrainingSession( + is_chief=True, + checkpoint_dir=logdir, + save_checkpoint_steps=1, + save_graph_def=True) as session: + self.assertIn('graph.pbtxt', os.listdir(logdir)) + self.assertLen(glob.glob(os.path.join(logdir, '*.meta')), 1) + session.run(new_gstep) + self.assertLen(glob.glob(os.path.join(logdir, '*.meta')), 2) + + def test_save_graph_def_false(self): + logdir = _test_dir(self.get_temp_dir(), 'test_save_graph_def_false') + with ops.Graph().as_default(): + gstep = training_util.get_or_create_global_step() + new_gstep = state_ops.assign_add(gstep, 1) + with monitored_session.MonitoredTrainingSession( + is_chief=True, + checkpoint_dir=logdir, + save_checkpoint_steps=1, + save_graph_def=False) as session: + self.assertNotIn('graph.pbtxt', os.listdir(logdir)) + self.assertEmpty(glob.glob(os.path.join(logdir, '*.meta'))) + session.run(new_gstep) + self.assertEmpty(glob.glob(os.path.join(logdir, '*.meta'))) + class MockExtended(object): diff --git a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-checkpoint-saver-hook.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-checkpoint-saver-hook.pbtxt index f9e1504b494..5c87f49e5a2 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-checkpoint-saver-hook.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-checkpoint-saver-hook.pbtxt @@ -5,7 +5,7 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'checkpoint_dir\', \'save_secs\', \'save_steps\', \'saver\', \'checkpoint_basename\', \'scaffold\', \'listeners\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'model.ckpt\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'checkpoint_dir\', \'save_secs\', \'save_steps\', \'saver\', \'checkpoint_basename\', \'scaffold\', \'listeners\', \'save_graph_def\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'model.ckpt\', \'None\', 
\'None\', \'True\'], " } member_method { name: "after_create_session" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.train.-checkpoint-saver-hook.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-checkpoint-saver-hook.pbtxt index c3037baa8c9..7ca7e218382 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.train.-checkpoint-saver-hook.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.train.-checkpoint-saver-hook.pbtxt @@ -5,7 +5,7 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'checkpoint_dir\', \'save_secs\', \'save_steps\', \'saver\', \'checkpoint_basename\', \'scaffold\', \'listeners\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'model.ckpt\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'checkpoint_dir\', \'save_secs\', \'save_steps\', \'saver\', \'checkpoint_basename\', \'scaffold\', \'listeners\', \'save_graph_def\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'model.ckpt\', \'None\', \'None\', \'True\'], " } member_method { name: "after_create_session" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.train.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.pbtxt index 3527de0bf30..c71bc4af3ec 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.train.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.train.pbtxt @@ -250,7 +250,7 @@ tf_module { } member_method { name: "MonitoredTrainingSession" - argspec: "args=[\'master\', \'is_chief\', \'checkpoint_dir\', \'scaffold\', \'hooks\', \'chief_only_hooks\', \'save_checkpoint_secs\', \'save_summaries_steps\', \'save_summaries_secs\', \'config\', \'stop_grace_period_secs\', \'log_step_count_steps\', \'max_wait_secs\', \'save_checkpoint_steps\', \'summary_dir\'], varargs=None, keywords=None, defaults=[\'\', \'True\', \'None\', \'None\', \'None\', \'None\', \'\', \'\', \'\', \'None\', \'120\', \'100\', \'7200\', \'\', \'None\'], " + argspec: "args=[\'master\', \'is_chief\', 
\'checkpoint_dir\', \'scaffold\', \'hooks\', \'chief_only_hooks\', \'save_checkpoint_secs\', \'save_summaries_steps\', \'save_summaries_secs\', \'config\', \'stop_grace_period_secs\', \'log_step_count_steps\', \'max_wait_secs\', \'save_checkpoint_steps\', \'summary_dir\', \'save_graph_def\'], varargs=None, keywords=None, defaults=[\'\', \'True\', \'None\', \'None\', \'None\', \'None\', \'\', \'\', \'\', \'None\', \'120\', \'100\', \'7200\', \'\', \'None\', \'True\'], " } member_method { name: "NewCheckpointReader" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-checkpoint-saver-hook.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-checkpoint-saver-hook.pbtxt index f9e1504b494..5c87f49e5a2 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-checkpoint-saver-hook.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-checkpoint-saver-hook.pbtxt @@ -5,7 +5,7 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'checkpoint_dir\', \'save_secs\', \'save_steps\', \'saver\', \'checkpoint_basename\', \'scaffold\', \'listeners\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'model.ckpt\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'checkpoint_dir\', \'save_secs\', \'save_steps\', \'saver\', \'checkpoint_basename\', \'scaffold\', \'listeners\', \'save_graph_def\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'model.ckpt\', \'None\', \'None\', \'True\'], " } member_method { name: "after_create_session"