Add SavedModelEstimator support for saved models with assets.

The main op execution in SavedModelEstimator did not provide a feed_dict to rewrite asset tensors with their correct filepaths. This change fixes SavedModelEstimator to mirror the logic in SavedModelLoader.run_init_ops.

PiperOrigin-RevId: 270822926
This commit is contained in:
RJ Skerry-Ryan 2019-09-23 20:58:06 -07:00 committed by TensorFlower Gardener
parent af39a076c2
commit 8d7f9e8328
6 changed files with 49 additions and 5 deletions

View File

@ -113,7 +113,8 @@ class Scaffold(object):
local_init_op=None,
summary_op=None,
saver=None,
copy_from_scaffold=None):
copy_from_scaffold=None,
local_init_feed_dict=None):
"""Create a scaffold.
Args:
@ -146,6 +147,8 @@ class Scaffold(object):
match the variable through the other `Model`.
copy_from_scaffold: Optional scaffold object to copy fields from. Its
fields will be overwritten by the provided fields in this function.
local_init_feed_dict: Optional session feed dictionary to use when running
the local_init_op.
"""
if copy_from_scaffold is not None:
if not isinstance(copy_from_scaffold, Scaffold):
@ -162,6 +165,8 @@ class Scaffold(object):
ready_for_local_init_op = coalesce(
ready_for_local_init_op, copy_from_scaffold.ready_for_local_init_op)
local_init_op = coalesce(local_init_op, copy_from_scaffold.local_init_op)
local_init_feed_dict = coalesce(local_init_feed_dict,
copy_from_scaffold.local_init_feed_dict)
summary_op = coalesce(summary_op, copy_from_scaffold.summary_op)
saver = coalesce(saver, copy_from_scaffold.saver)
@ -178,6 +183,7 @@ class Scaffold(object):
self._ready_op = ready_op
self._ready_for_local_init_op = ready_for_local_init_op
self._local_init_op = local_init_op
self._local_init_feed_dict = local_init_feed_dict
self._summary_op = summary_op
self._saver = saver
@ -260,6 +266,10 @@ class Scaffold(object):
def local_init_op(self):
return self._local_init_op
@property
def local_init_feed_dict(self):
return self._local_init_feed_dict
@property
def summary_op(self):
return self._summary_op
@ -624,11 +634,13 @@ class ChiefSessionCreator(SessionCreator):
self._config = config
def _get_session_manager(self):
"""Gets or creates a SessionManager."""
if self._session_manager:
return self._session_manager
self._session_manager = sm.SessionManager(
local_init_op=self._scaffold.local_init_op,
local_init_feed_dict=self._scaffold.local_init_feed_dict,
ready_op=self._scaffold.ready_op,
ready_for_local_init_op=self._scaffold.ready_for_local_init_op,
graph=ops.get_default_graph())
@ -672,11 +684,13 @@ class WorkerSessionCreator(SessionCreator):
self._max_wait_secs = max_wait_secs
def _get_session_manager(self):
"""Gets or creates a SessionManager."""
if self._session_manager:
return self._session_manager
self._session_manager = sm.SessionManager(
local_init_op=self._scaffold.local_init_op,
local_init_feed_dict=self._scaffold.local_init_feed_dict,
ready_op=self._scaffold.ready_op,
ready_for_local_init_op=self._scaffold.ready_for_local_init_op,
graph=ops.get_default_graph())

View File

@ -88,6 +88,7 @@ class ScaffoldTest(test.TestCase):
self.assertTrue(isinstance(scaffold.ready_op, ops.Tensor))
self.assertTrue(isinstance(scaffold.ready_for_local_init_op, ops.Tensor))
self.assertTrue(isinstance(scaffold.local_init_op, ops.Operation))
self.assertEqual(None, scaffold.local_init_feed_dict)
self.assertTrue(isinstance(scaffold.saver, saver_lib.Saver))
with self.cached_session() as sess:
self.assertItemsEqual([b'my_var', b'my_local_var'],
@ -110,6 +111,7 @@ class ScaffoldTest(test.TestCase):
self.assertTrue(isinstance(scaffold.ready_op, ops.Tensor))
self.assertTrue(isinstance(scaffold.ready_for_local_init_op, ops.Tensor))
self.assertTrue(isinstance(scaffold.local_init_op, ops.Operation))
self.assertEqual(None, scaffold.local_init_feed_dict)
self.assertTrue(isinstance(scaffold.saver, saver_lib.Saver))
def test_caches_values(self):
@ -145,6 +147,7 @@ class ScaffoldTest(test.TestCase):
ready_op=5,
ready_for_local_init_op=6,
local_init_op=7,
local_init_feed_dict=8,
saver=saver)
scaffold.finalize()
self.assertEqual(2, scaffold.init_op)
@ -153,6 +156,7 @@ class ScaffoldTest(test.TestCase):
self.assertEqual(5, scaffold.ready_op)
self.assertEqual(6, scaffold.ready_for_local_init_op)
self.assertEqual(7, scaffold.local_init_op)
self.assertEqual(8, scaffold.local_init_feed_dict)
self.assertEqual(saver, scaffold.saver)
def test_graph_is_finalized(self):
@ -175,6 +179,7 @@ class ScaffoldTest(test.TestCase):
ready_op=5,
ready_for_local_init_op=6,
local_init_op=7,
local_init_feed_dict=8,
saver=saver,
copy_from_scaffold=scaffold1)
@ -185,6 +190,7 @@ class ScaffoldTest(test.TestCase):
self.assertEqual(5, scaffold2.ready_op)
self.assertEqual(6, scaffold2.ready_for_local_init_op)
self.assertEqual(7, scaffold2.local_init_op)
self.assertEqual(8, scaffold2.local_init_feed_dict)
self.assertEqual(saver, scaffold2.saver)
def test_new_scaffold_from_existing_scaffold(self):
@ -198,6 +204,7 @@ class ScaffoldTest(test.TestCase):
ready_op=5,
ready_for_local_init_op=6,
local_init_op=7,
local_init_feed_dict=8,
saver=saver)
scaffold2 = monitored_session.Scaffold(
@ -207,6 +214,7 @@ class ScaffoldTest(test.TestCase):
ready_op=10,
ready_for_local_init_op=12,
local_init_op=14,
local_init_feed_dict=15,
saver=saver,
copy_from_scaffold=scaffold1)
@ -217,6 +225,7 @@ class ScaffoldTest(test.TestCase):
self.assertEqual(10, scaffold2.ready_op)
self.assertEqual(12, scaffold2.ready_for_local_init_op)
self.assertEqual(14, scaffold2.local_init_op)
self.assertEqual(15, scaffold2.local_init_feed_dict)
self.assertEqual(saver, scaffold2.saver)
def test_copy_from_scaffold_is_scaffold(self):

View File

@ -97,7 +97,8 @@ class SessionManager(object):
ready_for_local_init_op=None,
graph=None,
recovery_wait_secs=30,
local_init_run_options=None):
local_init_run_options=None,
local_init_feed_dict=None):
"""Creates a SessionManager.
The `local_init_op` is an `Operation` that is run always after a new session
@ -131,6 +132,8 @@ class SessionManager(object):
recovery_wait_secs: Seconds between checks for the model to be ready.
local_init_run_options: RunOptions to be passed to session.run when
executing the local_init_op.
local_init_feed_dict: Optional session feed dictionary to use when running
the local_init_op.
Raises:
ValueError: If ready_for_local_init_op is not None but local_init_op is
@ -146,6 +149,7 @@ class SessionManager(object):
self._recovery_wait_secs = recovery_wait_secs
self._target = None
self._local_init_run_options = local_init_run_options
self._local_init_feed_dict = local_init_feed_dict
if ready_for_local_init_op is not None and local_init_op is None:
raise ValueError("If you pass a ready_for_local_init_op "
"you must also pass a local_init_op "
@ -498,7 +502,8 @@ class SessionManager(object):
is_ready_for_local_init, msg = self._model_ready_for_local_init(sess)
if is_ready_for_local_init:
logging.info("Running local_init_op.")
sess.run(self._local_init_op, options=self._local_init_run_options)
sess.run(self._local_init_op, feed_dict=self._local_init_feed_dict,
options=self._local_init_run_options)
logging.info("Done running local_init_op.")
return True, None
else:

View File

@ -69,6 +69,18 @@ class SessionManagerTest(test.TestCase):
"", init_fn=lambda sess: sess.run(v.initializer))
self.assertAllClose([125], sess.run(v))
def testPrepareSessionSucceedsWithLocalInitFeedDict(self):
with ops.Graph().as_default():
p = array_ops.placeholder(dtypes.float32, shape=(3,))
v = variables.VariableV1(p, name="v",
collections=[ops.GraphKeys.LOCAL_VARIABLES])
sm = session_manager.SessionManager(
local_init_op=v.initializer,
local_init_feed_dict={p: [1.0, 2.0, 3.0]},
ready_op=variables.report_uninitialized_variables())
sess = sm.prepare_session("")
self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
@test_util.run_v1_only("b/120545219")
def testPrepareSessionFails(self):
checkpoint_dir = os.path.join(self.get_temp_dir(), "prepare_session")

View File

@ -14,6 +14,10 @@ tf_class {
name: "init_op"
mtype: "<type \'property\'>"
}
member {
name: "local_init_feed_dict"
mtype: "<type \'property\'>"
}
member {
name: "local_init_op"
mtype: "<type \'property\'>"
@ -36,7 +40,7 @@ tf_class {
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'init_op\', \'init_feed_dict\', \'init_fn\', \'ready_op\', \'ready_for_local_init_op\', \'local_init_op\', \'summary_op\', \'saver\', \'copy_from_scaffold\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
argspec: "args=[\'self\', \'init_op\', \'init_feed_dict\', \'init_fn\', \'ready_op\', \'ready_for_local_init_op\', \'local_init_op\', \'summary_op\', \'saver\', \'copy_from_scaffold\', \'local_init_feed_dict\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "default_local_init_op"

View File

@ -4,7 +4,7 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'local_init_op\', \'ready_op\', \'ready_for_local_init_op\', \'graph\', \'recovery_wait_secs\', \'local_init_run_options\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'30\', \'None\'], "
argspec: "args=[\'self\', \'local_init_op\', \'ready_op\', \'ready_for_local_init_op\', \'graph\', \'recovery_wait_secs\', \'local_init_run_options\', \'local_init_feed_dict\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'30\', \'None\', \'None\'], "
}
member_method {
name: "prepare_session"