Add SavedModelEstimator support for saved models with assets.
The main op execution in SavedModelEstimator did not provide a feed_dict to rewrite asset tensors with their correct filepaths. This change fixes SavedModelEstimator to mirror the logic in SavedModelLoader.run_init_ops. PiperOrigin-RevId: 270822926
This commit is contained in:
parent
af39a076c2
commit
8d7f9e8328
@ -113,7 +113,8 @@ class Scaffold(object):
|
||||
local_init_op=None,
|
||||
summary_op=None,
|
||||
saver=None,
|
||||
copy_from_scaffold=None):
|
||||
copy_from_scaffold=None,
|
||||
local_init_feed_dict=None):
|
||||
"""Create a scaffold.
|
||||
|
||||
Args:
|
||||
@ -146,6 +147,8 @@ class Scaffold(object):
|
||||
match the variable through the other `Model`.
|
||||
copy_from_scaffold: Optional scaffold object to copy fields from. Its
|
||||
fields will be overwritten by the provided fields in this function.
|
||||
local_init_feed_dict: Optional session feed dictionary to use when running
|
||||
the local_init_op.
|
||||
"""
|
||||
if copy_from_scaffold is not None:
|
||||
if not isinstance(copy_from_scaffold, Scaffold):
|
||||
@ -162,6 +165,8 @@ class Scaffold(object):
|
||||
ready_for_local_init_op = coalesce(
|
||||
ready_for_local_init_op, copy_from_scaffold.ready_for_local_init_op)
|
||||
local_init_op = coalesce(local_init_op, copy_from_scaffold.local_init_op)
|
||||
local_init_feed_dict = coalesce(local_init_feed_dict,
|
||||
copy_from_scaffold.local_init_feed_dict)
|
||||
summary_op = coalesce(summary_op, copy_from_scaffold.summary_op)
|
||||
saver = coalesce(saver, copy_from_scaffold.saver)
|
||||
|
||||
@ -178,6 +183,7 @@ class Scaffold(object):
|
||||
self._ready_op = ready_op
|
||||
self._ready_for_local_init_op = ready_for_local_init_op
|
||||
self._local_init_op = local_init_op
|
||||
self._local_init_feed_dict = local_init_feed_dict
|
||||
self._summary_op = summary_op
|
||||
self._saver = saver
|
||||
|
||||
@ -260,6 +266,10 @@ class Scaffold(object):
|
||||
def local_init_op(self):
|
||||
return self._local_init_op
|
||||
|
||||
@property
|
||||
def local_init_feed_dict(self):
|
||||
return self._local_init_feed_dict
|
||||
|
||||
@property
|
||||
def summary_op(self):
|
||||
return self._summary_op
|
||||
@ -624,11 +634,13 @@ class ChiefSessionCreator(SessionCreator):
|
||||
self._config = config
|
||||
|
||||
def _get_session_manager(self):
|
||||
"""Gets or creates a SessionManager."""
|
||||
if self._session_manager:
|
||||
return self._session_manager
|
||||
|
||||
self._session_manager = sm.SessionManager(
|
||||
local_init_op=self._scaffold.local_init_op,
|
||||
local_init_feed_dict=self._scaffold.local_init_feed_dict,
|
||||
ready_op=self._scaffold.ready_op,
|
||||
ready_for_local_init_op=self._scaffold.ready_for_local_init_op,
|
||||
graph=ops.get_default_graph())
|
||||
@ -672,11 +684,13 @@ class WorkerSessionCreator(SessionCreator):
|
||||
self._max_wait_secs = max_wait_secs
|
||||
|
||||
def _get_session_manager(self):
|
||||
"""Gets or creates a SessionManager."""
|
||||
if self._session_manager:
|
||||
return self._session_manager
|
||||
|
||||
self._session_manager = sm.SessionManager(
|
||||
local_init_op=self._scaffold.local_init_op,
|
||||
local_init_feed_dict=self._scaffold.local_init_feed_dict,
|
||||
ready_op=self._scaffold.ready_op,
|
||||
ready_for_local_init_op=self._scaffold.ready_for_local_init_op,
|
||||
graph=ops.get_default_graph())
|
||||
|
@ -88,6 +88,7 @@ class ScaffoldTest(test.TestCase):
|
||||
self.assertTrue(isinstance(scaffold.ready_op, ops.Tensor))
|
||||
self.assertTrue(isinstance(scaffold.ready_for_local_init_op, ops.Tensor))
|
||||
self.assertTrue(isinstance(scaffold.local_init_op, ops.Operation))
|
||||
self.assertEqual(None, scaffold.local_init_feed_dict)
|
||||
self.assertTrue(isinstance(scaffold.saver, saver_lib.Saver))
|
||||
with self.cached_session() as sess:
|
||||
self.assertItemsEqual([b'my_var', b'my_local_var'],
|
||||
@ -110,6 +111,7 @@ class ScaffoldTest(test.TestCase):
|
||||
self.assertTrue(isinstance(scaffold.ready_op, ops.Tensor))
|
||||
self.assertTrue(isinstance(scaffold.ready_for_local_init_op, ops.Tensor))
|
||||
self.assertTrue(isinstance(scaffold.local_init_op, ops.Operation))
|
||||
self.assertEqual(None, scaffold.local_init_feed_dict)
|
||||
self.assertTrue(isinstance(scaffold.saver, saver_lib.Saver))
|
||||
|
||||
def test_caches_values(self):
|
||||
@ -145,6 +147,7 @@ class ScaffoldTest(test.TestCase):
|
||||
ready_op=5,
|
||||
ready_for_local_init_op=6,
|
||||
local_init_op=7,
|
||||
local_init_feed_dict=8,
|
||||
saver=saver)
|
||||
scaffold.finalize()
|
||||
self.assertEqual(2, scaffold.init_op)
|
||||
@ -153,6 +156,7 @@ class ScaffoldTest(test.TestCase):
|
||||
self.assertEqual(5, scaffold.ready_op)
|
||||
self.assertEqual(6, scaffold.ready_for_local_init_op)
|
||||
self.assertEqual(7, scaffold.local_init_op)
|
||||
self.assertEqual(8, scaffold.local_init_feed_dict)
|
||||
self.assertEqual(saver, scaffold.saver)
|
||||
|
||||
def test_graph_is_finalized(self):
|
||||
@ -175,6 +179,7 @@ class ScaffoldTest(test.TestCase):
|
||||
ready_op=5,
|
||||
ready_for_local_init_op=6,
|
||||
local_init_op=7,
|
||||
local_init_feed_dict=8,
|
||||
saver=saver,
|
||||
copy_from_scaffold=scaffold1)
|
||||
|
||||
@ -185,6 +190,7 @@ class ScaffoldTest(test.TestCase):
|
||||
self.assertEqual(5, scaffold2.ready_op)
|
||||
self.assertEqual(6, scaffold2.ready_for_local_init_op)
|
||||
self.assertEqual(7, scaffold2.local_init_op)
|
||||
self.assertEqual(8, scaffold2.local_init_feed_dict)
|
||||
self.assertEqual(saver, scaffold2.saver)
|
||||
|
||||
def test_new_scaffold_from_existing_scaffold(self):
|
||||
@ -198,6 +204,7 @@ class ScaffoldTest(test.TestCase):
|
||||
ready_op=5,
|
||||
ready_for_local_init_op=6,
|
||||
local_init_op=7,
|
||||
local_init_feed_dict=8,
|
||||
saver=saver)
|
||||
|
||||
scaffold2 = monitored_session.Scaffold(
|
||||
@ -207,6 +214,7 @@ class ScaffoldTest(test.TestCase):
|
||||
ready_op=10,
|
||||
ready_for_local_init_op=12,
|
||||
local_init_op=14,
|
||||
local_init_feed_dict=15,
|
||||
saver=saver,
|
||||
copy_from_scaffold=scaffold1)
|
||||
|
||||
@ -217,6 +225,7 @@ class ScaffoldTest(test.TestCase):
|
||||
self.assertEqual(10, scaffold2.ready_op)
|
||||
self.assertEqual(12, scaffold2.ready_for_local_init_op)
|
||||
self.assertEqual(14, scaffold2.local_init_op)
|
||||
self.assertEqual(15, scaffold2.local_init_feed_dict)
|
||||
self.assertEqual(saver, scaffold2.saver)
|
||||
|
||||
def test_copy_from_scaffold_is_scaffold(self):
|
||||
|
@ -97,7 +97,8 @@ class SessionManager(object):
|
||||
ready_for_local_init_op=None,
|
||||
graph=None,
|
||||
recovery_wait_secs=30,
|
||||
local_init_run_options=None):
|
||||
local_init_run_options=None,
|
||||
local_init_feed_dict=None):
|
||||
"""Creates a SessionManager.
|
||||
|
||||
The `local_init_op` is an `Operation` that is run always after a new session
|
||||
@ -131,6 +132,8 @@ class SessionManager(object):
|
||||
recovery_wait_secs: Seconds between checks for the model to be ready.
|
||||
local_init_run_options: RunOptions to be passed to session.run when
|
||||
executing the local_init_op.
|
||||
local_init_feed_dict: Optional session feed dictionary to use when running
|
||||
the local_init_op.
|
||||
|
||||
Raises:
|
||||
ValueError: If ready_for_local_init_op is not None but local_init_op is
|
||||
@ -146,6 +149,7 @@ class SessionManager(object):
|
||||
self._recovery_wait_secs = recovery_wait_secs
|
||||
self._target = None
|
||||
self._local_init_run_options = local_init_run_options
|
||||
self._local_init_feed_dict = local_init_feed_dict
|
||||
if ready_for_local_init_op is not None and local_init_op is None:
|
||||
raise ValueError("If you pass a ready_for_local_init_op "
|
||||
"you must also pass a local_init_op "
|
||||
@ -498,7 +502,8 @@ class SessionManager(object):
|
||||
is_ready_for_local_init, msg = self._model_ready_for_local_init(sess)
|
||||
if is_ready_for_local_init:
|
||||
logging.info("Running local_init_op.")
|
||||
sess.run(self._local_init_op, options=self._local_init_run_options)
|
||||
sess.run(self._local_init_op, feed_dict=self._local_init_feed_dict,
|
||||
options=self._local_init_run_options)
|
||||
logging.info("Done running local_init_op.")
|
||||
return True, None
|
||||
else:
|
||||
|
@ -69,6 +69,18 @@ class SessionManagerTest(test.TestCase):
|
||||
"", init_fn=lambda sess: sess.run(v.initializer))
|
||||
self.assertAllClose([125], sess.run(v))
|
||||
|
||||
def testPrepareSessionSucceedsWithLocalInitFeedDict(self):
|
||||
with ops.Graph().as_default():
|
||||
p = array_ops.placeholder(dtypes.float32, shape=(3,))
|
||||
v = variables.VariableV1(p, name="v",
|
||||
collections=[ops.GraphKeys.LOCAL_VARIABLES])
|
||||
sm = session_manager.SessionManager(
|
||||
local_init_op=v.initializer,
|
||||
local_init_feed_dict={p: [1.0, 2.0, 3.0]},
|
||||
ready_op=variables.report_uninitialized_variables())
|
||||
sess = sm.prepare_session("")
|
||||
self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testPrepareSessionFails(self):
|
||||
checkpoint_dir = os.path.join(self.get_temp_dir(), "prepare_session")
|
||||
|
@ -14,6 +14,10 @@ tf_class {
|
||||
name: "init_op"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "local_init_feed_dict"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "local_init_op"
|
||||
mtype: "<type \'property\'>"
|
||||
@ -36,7 +40,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'init_op\', \'init_feed_dict\', \'init_fn\', \'ready_op\', \'ready_for_local_init_op\', \'local_init_op\', \'summary_op\', \'saver\', \'copy_from_scaffold\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'init_op\', \'init_feed_dict\', \'init_fn\', \'ready_op\', \'ready_for_local_init_op\', \'local_init_op\', \'summary_op\', \'saver\', \'copy_from_scaffold\', \'local_init_feed_dict\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "default_local_init_op"
|
||||
|
@ -4,7 +4,7 @@ tf_class {
|
||||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'local_init_op\', \'ready_op\', \'ready_for_local_init_op\', \'graph\', \'recovery_wait_secs\', \'local_init_run_options\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'30\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'local_init_op\', \'ready_op\', \'ready_for_local_init_op\', \'graph\', \'recovery_wait_secs\', \'local_init_run_options\', \'local_init_feed_dict\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'30\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "prepare_session"
|
||||
|
Loading…
Reference in New Issue
Block a user