From 8d7f9e83289e1d188036d6ab385792c3e0b898ac Mon Sep 17 00:00:00 2001 From: RJ Skerry-Ryan Date: Mon, 23 Sep 2019 20:58:06 -0700 Subject: [PATCH] Add SavedModelEstimator support for saved models with assets. The main op execution in SavedModelEstimator did not provide a feed_dict to rewrite asset tensors with their correct filepaths. This change fixes SavedModelEstimator to mirror the logic in SavedModelLoader.run_init_ops. PiperOrigin-RevId: 270822926 --- tensorflow/python/training/monitored_session.py | 16 +++++++++++++++- .../python/training/monitored_session_test.py | 9 +++++++++ tensorflow/python/training/session_manager.py | 9 +++++++-- .../python/training/session_manager_test.py | 12 ++++++++++++ .../golden/v1/tensorflow.train.-scaffold.pbtxt | 6 +++++- .../v1/tensorflow.train.-session-manager.pbtxt | 2 +- 6 files changed, 49 insertions(+), 5 deletions(-) diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py index ac08da6704c..3e1c3e9f73f 100644 --- a/tensorflow/python/training/monitored_session.py +++ b/tensorflow/python/training/monitored_session.py @@ -113,7 +113,8 @@ class Scaffold(object): local_init_op=None, summary_op=None, saver=None, - copy_from_scaffold=None): + copy_from_scaffold=None, + local_init_feed_dict=None): """Create a scaffold. Args: @@ -146,6 +147,8 @@ class Scaffold(object): match the variable through the other `Model`. copy_from_scaffold: Optional scaffold object to copy fields from. Its fields will be overwritten by the provided fields in this function. + local_init_feed_dict: Optional session feed dictionary to use when running + the local_init_op. """ if copy_from_scaffold is not None: if not isinstance(copy_from_scaffold, Scaffold): @@ -162,6 +165,8 @@ class Scaffold(object): ready_for_local_init_op = coalesce( ready_for_local_init_op, copy_from_scaffold.ready_for_local_init_op) local_init_op = coalesce(local_init_op, copy_from_scaffold.local_init_op) + local_init_feed_dict = coalesce(local_init_feed_dict, + copy_from_scaffold.local_init_feed_dict) summary_op = coalesce(summary_op, copy_from_scaffold.summary_op) saver = coalesce(saver, copy_from_scaffold.saver) @@ -178,6 +183,7 @@ class Scaffold(object): self._ready_op = ready_op self._ready_for_local_init_op = ready_for_local_init_op self._local_init_op = local_init_op + self._local_init_feed_dict = local_init_feed_dict self._summary_op = summary_op self._saver = saver @@ -260,6 +266,10 @@ class Scaffold(object): def local_init_op(self): return self._local_init_op + @property + def local_init_feed_dict(self): + return self._local_init_feed_dict + @property def summary_op(self): return self._summary_op @@ -624,11 +634,13 @@ class ChiefSessionCreator(SessionCreator): self._config = config def _get_session_manager(self): + """Gets or creates a SessionManager.""" if self._session_manager: return self._session_manager self._session_manager = sm.SessionManager( local_init_op=self._scaffold.local_init_op, + local_init_feed_dict=self._scaffold.local_init_feed_dict, ready_op=self._scaffold.ready_op, ready_for_local_init_op=self._scaffold.ready_for_local_init_op, graph=ops.get_default_graph()) @@ -672,11 +684,13 @@ class WorkerSessionCreator(SessionCreator): self._max_wait_secs = max_wait_secs def _get_session_manager(self): + """Gets or creates a SessionManager.""" if self._session_manager: return self._session_manager self._session_manager = sm.SessionManager( local_init_op=self._scaffold.local_init_op, + local_init_feed_dict=self._scaffold.local_init_feed_dict, ready_op=self._scaffold.ready_op, ready_for_local_init_op=self._scaffold.ready_for_local_init_op, graph=ops.get_default_graph()) diff --git a/tensorflow/python/training/monitored_session_test.py b/tensorflow/python/training/monitored_session_test.py index 095a524cb42..bf9d3a616c4 100644 --- a/tensorflow/python/training/monitored_session_test.py +++ b/tensorflow/python/training/monitored_session_test.py @@ -88,6 +88,7 @@ class ScaffoldTest(test.TestCase): self.assertTrue(isinstance(scaffold.ready_op, ops.Tensor)) self.assertTrue(isinstance(scaffold.ready_for_local_init_op, ops.Tensor)) self.assertTrue(isinstance(scaffold.local_init_op, ops.Operation)) + self.assertEqual(None, scaffold.local_init_feed_dict) self.assertTrue(isinstance(scaffold.saver, saver_lib.Saver)) with self.cached_session() as sess: self.assertItemsEqual([b'my_var', b'my_local_var'], @@ -110,6 +111,7 @@ class ScaffoldTest(test.TestCase): self.assertTrue(isinstance(scaffold.ready_op, ops.Tensor)) self.assertTrue(isinstance(scaffold.ready_for_local_init_op, ops.Tensor)) self.assertTrue(isinstance(scaffold.local_init_op, ops.Operation)) + self.assertEqual(None, scaffold.local_init_feed_dict) self.assertTrue(isinstance(scaffold.saver, saver_lib.Saver)) def test_caches_values(self): @@ -145,6 +147,7 @@ class ScaffoldTest(test.TestCase): ready_op=5, ready_for_local_init_op=6, local_init_op=7, + local_init_feed_dict=8, saver=saver) scaffold.finalize() self.assertEqual(2, scaffold.init_op) @@ -153,6 +156,7 @@ class ScaffoldTest(test.TestCase): self.assertEqual(5, scaffold.ready_op) self.assertEqual(6, scaffold.ready_for_local_init_op) self.assertEqual(7, scaffold.local_init_op) + self.assertEqual(8, scaffold.local_init_feed_dict) self.assertEqual(saver, scaffold.saver) def test_graph_is_finalized(self): @@ -175,6 +179,7 @@ class ScaffoldTest(test.TestCase): ready_op=5, ready_for_local_init_op=6, local_init_op=7, + local_init_feed_dict=8, saver=saver, copy_from_scaffold=scaffold1) @@ -185,6 +190,7 @@ class ScaffoldTest(test.TestCase): self.assertEqual(5, scaffold2.ready_op) self.assertEqual(6, scaffold2.ready_for_local_init_op) self.assertEqual(7, scaffold2.local_init_op) + self.assertEqual(8, scaffold2.local_init_feed_dict) self.assertEqual(saver, scaffold2.saver) def test_new_scaffold_from_existing_scaffold(self): @@ -198,6 +204,7 @@ class ScaffoldTest(test.TestCase): ready_op=5, ready_for_local_init_op=6, local_init_op=7, + local_init_feed_dict=8, saver=saver) scaffold2 = monitored_session.Scaffold( @@ -207,6 +214,7 @@ class ScaffoldTest(test.TestCase): ready_op=10, ready_for_local_init_op=12, local_init_op=14, + local_init_feed_dict=15, saver=saver, copy_from_scaffold=scaffold1) @@ -217,6 +225,7 @@ class ScaffoldTest(test.TestCase): self.assertEqual(10, scaffold2.ready_op) self.assertEqual(12, scaffold2.ready_for_local_init_op) self.assertEqual(14, scaffold2.local_init_op) + self.assertEqual(15, scaffold2.local_init_feed_dict) self.assertEqual(saver, scaffold2.saver) def test_copy_from_scaffold_is_scaffold(self): diff --git a/tensorflow/python/training/session_manager.py b/tensorflow/python/training/session_manager.py index 104247e60ec..9c2db27af2c 100644 --- a/tensorflow/python/training/session_manager.py +++ b/tensorflow/python/training/session_manager.py @@ -97,7 +97,8 @@ class SessionManager(object): ready_for_local_init_op=None, graph=None, recovery_wait_secs=30, - local_init_run_options=None): + local_init_run_options=None, + local_init_feed_dict=None): """Creates a SessionManager. The `local_init_op` is an `Operation` that is run always after a new session @@ -131,6 +132,8 @@ class SessionManager(object): recovery_wait_secs: Seconds between checks for the model to be ready. local_init_run_options: RunOptions to be passed to session.run when executing the local_init_op. + local_init_feed_dict: Optional session feed dictionary to use when running + the local_init_op. Raises: ValueError: If ready_for_local_init_op is not None but local_init_op is @@ -146,6 +149,7 @@ class SessionManager(object): self._recovery_wait_secs = recovery_wait_secs self._target = None self._local_init_run_options = local_init_run_options + self._local_init_feed_dict = local_init_feed_dict if ready_for_local_init_op is not None and local_init_op is None: raise ValueError("If you pass a ready_for_local_init_op " "you must also pass a local_init_op " @@ -498,7 +502,8 @@ class SessionManager(object): is_ready_for_local_init, msg = self._model_ready_for_local_init(sess) if is_ready_for_local_init: logging.info("Running local_init_op.") - sess.run(self._local_init_op, options=self._local_init_run_options) + sess.run(self._local_init_op, feed_dict=self._local_init_feed_dict, + options=self._local_init_run_options) logging.info("Done running local_init_op.") return True, None else: diff --git a/tensorflow/python/training/session_manager_test.py b/tensorflow/python/training/session_manager_test.py index 1ceddf7a170..9d7381d08e0 100644 --- a/tensorflow/python/training/session_manager_test.py +++ b/tensorflow/python/training/session_manager_test.py @@ -69,6 +69,18 @@ class SessionManagerTest(test.TestCase): "", init_fn=lambda sess: sess.run(v.initializer)) self.assertAllClose([125], sess.run(v)) + def testPrepareSessionSucceedsWithLocalInitFeedDict(self): + with ops.Graph().as_default(): + p = array_ops.placeholder(dtypes.float32, shape=(3,)) + v = variables.VariableV1(p, name="v", + collections=[ops.GraphKeys.LOCAL_VARIABLES]) + sm = session_manager.SessionManager( + local_init_op=v.initializer, + local_init_feed_dict={p: [1.0, 2.0, 3.0]}, + ready_op=variables.report_uninitialized_variables()) + sess = sm.prepare_session("") + self.assertAllClose([1.0, 2.0, 3.0], sess.run(v)) + @test_util.run_v1_only("b/120545219") def testPrepareSessionFails(self): checkpoint_dir = os.path.join(self.get_temp_dir(), "prepare_session") diff --git a/tensorflow/tools/api/golden/v1/tensorflow.train.-scaffold.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-scaffold.pbtxt index 38cc98b48e7..028e4a3f2bb 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.train.-scaffold.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.train.-scaffold.pbtxt @@ -14,6 +14,10 @@ tf_class { name: "init_op" mtype: "" } + member { + name: "local_init_feed_dict" + mtype: "" + } member { name: "local_init_op" mtype: "" @@ -36,7 +40,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'init_op\', \'init_feed_dict\', \'init_fn\', \'ready_op\', \'ready_for_local_init_op\', \'local_init_op\', \'summary_op\', \'saver\', \'copy_from_scaffold\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'init_op\', \'init_feed_dict\', \'init_fn\', \'ready_op\', \'ready_for_local_init_op\', \'local_init_op\', \'summary_op\', \'saver\', \'copy_from_scaffold\', \'local_init_feed_dict\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "default_local_init_op" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.train.-session-manager.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-session-manager.pbtxt index 448764fe081..4335b30775d 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.train.-session-manager.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.train.-session-manager.pbtxt @@ -4,7 +4,7 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'local_init_op\', \'ready_op\', \'ready_for_local_init_op\', \'graph\', \'recovery_wait_secs\', \'local_init_run_options\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'30\', \'None\'], " + argspec: "args=[\'self\', \'local_init_op\', \'ready_op\', \'ready_for_local_init_op\', \'graph\', \'recovery_wait_secs\', \'local_init_run_options\', \'local_init_feed_dict\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'30\', \'None\', \'None\'], " } member_method { name: "prepare_session"