Add SavedModelEstimator support for saved models with assets.

The main op execution in SavedModelEstimator did not provide a feed_dict to rewrite asset tensors with their correct filepaths. This change fixes SavedModelEstimator to mirror the logic in SavedModelLoader.run_init_ops.

PiperOrigin-RevId: 270822926
This commit is contained in:
RJ Skerry-Ryan 2019-09-23 20:58:06 -07:00 committed by TensorFlower Gardener
parent af39a076c2
commit 8d7f9e8328
6 changed files with 49 additions and 5 deletions

View File

@ -113,7 +113,8 @@ class Scaffold(object):
local_init_op=None, local_init_op=None,
summary_op=None, summary_op=None,
saver=None, saver=None,
copy_from_scaffold=None): copy_from_scaffold=None,
local_init_feed_dict=None):
"""Create a scaffold. """Create a scaffold.
Args: Args:
@ -146,6 +147,8 @@ class Scaffold(object):
match the variable through the other `Model`. match the variable through the other `Model`.
copy_from_scaffold: Optional scaffold object to copy fields from. Its copy_from_scaffold: Optional scaffold object to copy fields from. Its
fields will be overwritten by the provided fields in this function. fields will be overwritten by the provided fields in this function.
local_init_feed_dict: Optional session feed dictionary to use when running
the local_init_op.
""" """
if copy_from_scaffold is not None: if copy_from_scaffold is not None:
if not isinstance(copy_from_scaffold, Scaffold): if not isinstance(copy_from_scaffold, Scaffold):
@ -162,6 +165,8 @@ class Scaffold(object):
ready_for_local_init_op = coalesce( ready_for_local_init_op = coalesce(
ready_for_local_init_op, copy_from_scaffold.ready_for_local_init_op) ready_for_local_init_op, copy_from_scaffold.ready_for_local_init_op)
local_init_op = coalesce(local_init_op, copy_from_scaffold.local_init_op) local_init_op = coalesce(local_init_op, copy_from_scaffold.local_init_op)
local_init_feed_dict = coalesce(local_init_feed_dict,
copy_from_scaffold.local_init_feed_dict)
summary_op = coalesce(summary_op, copy_from_scaffold.summary_op) summary_op = coalesce(summary_op, copy_from_scaffold.summary_op)
saver = coalesce(saver, copy_from_scaffold.saver) saver = coalesce(saver, copy_from_scaffold.saver)
@ -178,6 +183,7 @@ class Scaffold(object):
self._ready_op = ready_op self._ready_op = ready_op
self._ready_for_local_init_op = ready_for_local_init_op self._ready_for_local_init_op = ready_for_local_init_op
self._local_init_op = local_init_op self._local_init_op = local_init_op
self._local_init_feed_dict = local_init_feed_dict
self._summary_op = summary_op self._summary_op = summary_op
self._saver = saver self._saver = saver
@ -260,6 +266,10 @@ class Scaffold(object):
def local_init_op(self): def local_init_op(self):
return self._local_init_op return self._local_init_op
@property
def local_init_feed_dict(self):
return self._local_init_feed_dict
@property @property
def summary_op(self): def summary_op(self):
return self._summary_op return self._summary_op
@ -624,11 +634,13 @@ class ChiefSessionCreator(SessionCreator):
self._config = config self._config = config
def _get_session_manager(self): def _get_session_manager(self):
"""Gets or creates a SessionManager."""
if self._session_manager: if self._session_manager:
return self._session_manager return self._session_manager
self._session_manager = sm.SessionManager( self._session_manager = sm.SessionManager(
local_init_op=self._scaffold.local_init_op, local_init_op=self._scaffold.local_init_op,
local_init_feed_dict=self._scaffold.local_init_feed_dict,
ready_op=self._scaffold.ready_op, ready_op=self._scaffold.ready_op,
ready_for_local_init_op=self._scaffold.ready_for_local_init_op, ready_for_local_init_op=self._scaffold.ready_for_local_init_op,
graph=ops.get_default_graph()) graph=ops.get_default_graph())
@ -672,11 +684,13 @@ class WorkerSessionCreator(SessionCreator):
self._max_wait_secs = max_wait_secs self._max_wait_secs = max_wait_secs
def _get_session_manager(self): def _get_session_manager(self):
"""Gets or creates a SessionManager."""
if self._session_manager: if self._session_manager:
return self._session_manager return self._session_manager
self._session_manager = sm.SessionManager( self._session_manager = sm.SessionManager(
local_init_op=self._scaffold.local_init_op, local_init_op=self._scaffold.local_init_op,
local_init_feed_dict=self._scaffold.local_init_feed_dict,
ready_op=self._scaffold.ready_op, ready_op=self._scaffold.ready_op,
ready_for_local_init_op=self._scaffold.ready_for_local_init_op, ready_for_local_init_op=self._scaffold.ready_for_local_init_op,
graph=ops.get_default_graph()) graph=ops.get_default_graph())

View File

@ -88,6 +88,7 @@ class ScaffoldTest(test.TestCase):
self.assertTrue(isinstance(scaffold.ready_op, ops.Tensor)) self.assertTrue(isinstance(scaffold.ready_op, ops.Tensor))
self.assertTrue(isinstance(scaffold.ready_for_local_init_op, ops.Tensor)) self.assertTrue(isinstance(scaffold.ready_for_local_init_op, ops.Tensor))
self.assertTrue(isinstance(scaffold.local_init_op, ops.Operation)) self.assertTrue(isinstance(scaffold.local_init_op, ops.Operation))
self.assertEqual(None, scaffold.local_init_feed_dict)
self.assertTrue(isinstance(scaffold.saver, saver_lib.Saver)) self.assertTrue(isinstance(scaffold.saver, saver_lib.Saver))
with self.cached_session() as sess: with self.cached_session() as sess:
self.assertItemsEqual([b'my_var', b'my_local_var'], self.assertItemsEqual([b'my_var', b'my_local_var'],
@ -110,6 +111,7 @@ class ScaffoldTest(test.TestCase):
self.assertTrue(isinstance(scaffold.ready_op, ops.Tensor)) self.assertTrue(isinstance(scaffold.ready_op, ops.Tensor))
self.assertTrue(isinstance(scaffold.ready_for_local_init_op, ops.Tensor)) self.assertTrue(isinstance(scaffold.ready_for_local_init_op, ops.Tensor))
self.assertTrue(isinstance(scaffold.local_init_op, ops.Operation)) self.assertTrue(isinstance(scaffold.local_init_op, ops.Operation))
self.assertEqual(None, scaffold.local_init_feed_dict)
self.assertTrue(isinstance(scaffold.saver, saver_lib.Saver)) self.assertTrue(isinstance(scaffold.saver, saver_lib.Saver))
def test_caches_values(self): def test_caches_values(self):
@ -145,6 +147,7 @@ class ScaffoldTest(test.TestCase):
ready_op=5, ready_op=5,
ready_for_local_init_op=6, ready_for_local_init_op=6,
local_init_op=7, local_init_op=7,
local_init_feed_dict=8,
saver=saver) saver=saver)
scaffold.finalize() scaffold.finalize()
self.assertEqual(2, scaffold.init_op) self.assertEqual(2, scaffold.init_op)
@ -153,6 +156,7 @@ class ScaffoldTest(test.TestCase):
self.assertEqual(5, scaffold.ready_op) self.assertEqual(5, scaffold.ready_op)
self.assertEqual(6, scaffold.ready_for_local_init_op) self.assertEqual(6, scaffold.ready_for_local_init_op)
self.assertEqual(7, scaffold.local_init_op) self.assertEqual(7, scaffold.local_init_op)
self.assertEqual(8, scaffold.local_init_feed_dict)
self.assertEqual(saver, scaffold.saver) self.assertEqual(saver, scaffold.saver)
def test_graph_is_finalized(self): def test_graph_is_finalized(self):
@ -175,6 +179,7 @@ class ScaffoldTest(test.TestCase):
ready_op=5, ready_op=5,
ready_for_local_init_op=6, ready_for_local_init_op=6,
local_init_op=7, local_init_op=7,
local_init_feed_dict=8,
saver=saver, saver=saver,
copy_from_scaffold=scaffold1) copy_from_scaffold=scaffold1)
@ -185,6 +190,7 @@ class ScaffoldTest(test.TestCase):
self.assertEqual(5, scaffold2.ready_op) self.assertEqual(5, scaffold2.ready_op)
self.assertEqual(6, scaffold2.ready_for_local_init_op) self.assertEqual(6, scaffold2.ready_for_local_init_op)
self.assertEqual(7, scaffold2.local_init_op) self.assertEqual(7, scaffold2.local_init_op)
self.assertEqual(8, scaffold2.local_init_feed_dict)
self.assertEqual(saver, scaffold2.saver) self.assertEqual(saver, scaffold2.saver)
def test_new_scaffold_from_existing_scaffold(self): def test_new_scaffold_from_existing_scaffold(self):
@ -198,6 +204,7 @@ class ScaffoldTest(test.TestCase):
ready_op=5, ready_op=5,
ready_for_local_init_op=6, ready_for_local_init_op=6,
local_init_op=7, local_init_op=7,
local_init_feed_dict=8,
saver=saver) saver=saver)
scaffold2 = monitored_session.Scaffold( scaffold2 = monitored_session.Scaffold(
@ -207,6 +214,7 @@ class ScaffoldTest(test.TestCase):
ready_op=10, ready_op=10,
ready_for_local_init_op=12, ready_for_local_init_op=12,
local_init_op=14, local_init_op=14,
local_init_feed_dict=15,
saver=saver, saver=saver,
copy_from_scaffold=scaffold1) copy_from_scaffold=scaffold1)
@ -217,6 +225,7 @@ class ScaffoldTest(test.TestCase):
self.assertEqual(10, scaffold2.ready_op) self.assertEqual(10, scaffold2.ready_op)
self.assertEqual(12, scaffold2.ready_for_local_init_op) self.assertEqual(12, scaffold2.ready_for_local_init_op)
self.assertEqual(14, scaffold2.local_init_op) self.assertEqual(14, scaffold2.local_init_op)
self.assertEqual(15, scaffold2.local_init_feed_dict)
self.assertEqual(saver, scaffold2.saver) self.assertEqual(saver, scaffold2.saver)
def test_copy_from_scaffold_is_scaffold(self): def test_copy_from_scaffold_is_scaffold(self):

View File

@ -97,7 +97,8 @@ class SessionManager(object):
ready_for_local_init_op=None, ready_for_local_init_op=None,
graph=None, graph=None,
recovery_wait_secs=30, recovery_wait_secs=30,
local_init_run_options=None): local_init_run_options=None,
local_init_feed_dict=None):
"""Creates a SessionManager. """Creates a SessionManager.
The `local_init_op` is an `Operation` that is run always after a new session The `local_init_op` is an `Operation` that is run always after a new session
@ -131,6 +132,8 @@ class SessionManager(object):
recovery_wait_secs: Seconds between checks for the model to be ready. recovery_wait_secs: Seconds between checks for the model to be ready.
local_init_run_options: RunOptions to be passed to session.run when local_init_run_options: RunOptions to be passed to session.run when
executing the local_init_op. executing the local_init_op.
local_init_feed_dict: Optional session feed dictionary to use when running
the local_init_op.
Raises: Raises:
ValueError: If ready_for_local_init_op is not None but local_init_op is ValueError: If ready_for_local_init_op is not None but local_init_op is
@ -146,6 +149,7 @@ class SessionManager(object):
self._recovery_wait_secs = recovery_wait_secs self._recovery_wait_secs = recovery_wait_secs
self._target = None self._target = None
self._local_init_run_options = local_init_run_options self._local_init_run_options = local_init_run_options
self._local_init_feed_dict = local_init_feed_dict
if ready_for_local_init_op is not None and local_init_op is None: if ready_for_local_init_op is not None and local_init_op is None:
raise ValueError("If you pass a ready_for_local_init_op " raise ValueError("If you pass a ready_for_local_init_op "
"you must also pass a local_init_op " "you must also pass a local_init_op "
@ -498,7 +502,8 @@ class SessionManager(object):
is_ready_for_local_init, msg = self._model_ready_for_local_init(sess) is_ready_for_local_init, msg = self._model_ready_for_local_init(sess)
if is_ready_for_local_init: if is_ready_for_local_init:
logging.info("Running local_init_op.") logging.info("Running local_init_op.")
sess.run(self._local_init_op, options=self._local_init_run_options) sess.run(self._local_init_op, feed_dict=self._local_init_feed_dict,
options=self._local_init_run_options)
logging.info("Done running local_init_op.") logging.info("Done running local_init_op.")
return True, None return True, None
else: else:

View File

@ -69,6 +69,18 @@ class SessionManagerTest(test.TestCase):
"", init_fn=lambda sess: sess.run(v.initializer)) "", init_fn=lambda sess: sess.run(v.initializer))
self.assertAllClose([125], sess.run(v)) self.assertAllClose([125], sess.run(v))
def testPrepareSessionSucceedsWithLocalInitFeedDict(self):
with ops.Graph().as_default():
p = array_ops.placeholder(dtypes.float32, shape=(3,))
v = variables.VariableV1(p, name="v",
collections=[ops.GraphKeys.LOCAL_VARIABLES])
sm = session_manager.SessionManager(
local_init_op=v.initializer,
local_init_feed_dict={p: [1.0, 2.0, 3.0]},
ready_op=variables.report_uninitialized_variables())
sess = sm.prepare_session("")
self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
@test_util.run_v1_only("b/120545219") @test_util.run_v1_only("b/120545219")
def testPrepareSessionFails(self): def testPrepareSessionFails(self):
checkpoint_dir = os.path.join(self.get_temp_dir(), "prepare_session") checkpoint_dir = os.path.join(self.get_temp_dir(), "prepare_session")

View File

@ -14,6 +14,10 @@ tf_class {
name: "init_op" name: "init_op"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "local_init_feed_dict"
mtype: "<type \'property\'>"
}
member { member {
name: "local_init_op" name: "local_init_op"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
@ -36,7 +40,7 @@ tf_class {
} }
member_method { member_method {
name: "__init__" name: "__init__"
argspec: "args=[\'self\', \'init_op\', \'init_feed_dict\', \'init_fn\', \'ready_op\', \'ready_for_local_init_op\', \'local_init_op\', \'summary_op\', \'saver\', \'copy_from_scaffold\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " argspec: "args=[\'self\', \'init_op\', \'init_feed_dict\', \'init_fn\', \'ready_op\', \'ready_for_local_init_op\', \'local_init_op\', \'summary_op\', \'saver\', \'copy_from_scaffold\', \'local_init_feed_dict\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
} }
member_method { member_method {
name: "default_local_init_op" name: "default_local_init_op"

View File

@ -4,7 +4,7 @@ tf_class {
is_instance: "<type \'object\'>" is_instance: "<type \'object\'>"
member_method { member_method {
name: "__init__" name: "__init__"
argspec: "args=[\'self\', \'local_init_op\', \'ready_op\', \'ready_for_local_init_op\', \'graph\', \'recovery_wait_secs\', \'local_init_run_options\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'30\', \'None\'], " argspec: "args=[\'self\', \'local_init_op\', \'ready_op\', \'ready_for_local_init_op\', \'graph\', \'recovery_wait_secs\', \'local_init_run_options\', \'local_init_feed_dict\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'30\', \'None\', \'None\'], "
} }
member_method { member_method {
name: "prepare_session" name: "prepare_session"