Add SavedModelEstimator support for saved models with assets.
The main op execution in SavedModelEstimator did not provide a feed_dict to rewrite asset tensors with their correct filepaths. This change fixes SavedModelEstimator to mirror the logic in SavedModelLoader.run_init_ops. PiperOrigin-RevId: 270822926
This commit is contained in:
parent
af39a076c2
commit
8d7f9e8328
tensorflow
@ -113,7 +113,8 @@ class Scaffold(object):
|
|||||||
local_init_op=None,
|
local_init_op=None,
|
||||||
summary_op=None,
|
summary_op=None,
|
||||||
saver=None,
|
saver=None,
|
||||||
copy_from_scaffold=None):
|
copy_from_scaffold=None,
|
||||||
|
local_init_feed_dict=None):
|
||||||
"""Create a scaffold.
|
"""Create a scaffold.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -146,6 +147,8 @@ class Scaffold(object):
|
|||||||
match the variable through the other `Model`.
|
match the variable through the other `Model`.
|
||||||
copy_from_scaffold: Optional scaffold object to copy fields from. Its
|
copy_from_scaffold: Optional scaffold object to copy fields from. Its
|
||||||
fields will be overwritten by the provided fields in this function.
|
fields will be overwritten by the provided fields in this function.
|
||||||
|
local_init_feed_dict: Optional session feed dictionary to use when running
|
||||||
|
the local_init_op.
|
||||||
"""
|
"""
|
||||||
if copy_from_scaffold is not None:
|
if copy_from_scaffold is not None:
|
||||||
if not isinstance(copy_from_scaffold, Scaffold):
|
if not isinstance(copy_from_scaffold, Scaffold):
|
||||||
@ -162,6 +165,8 @@ class Scaffold(object):
|
|||||||
ready_for_local_init_op = coalesce(
|
ready_for_local_init_op = coalesce(
|
||||||
ready_for_local_init_op, copy_from_scaffold.ready_for_local_init_op)
|
ready_for_local_init_op, copy_from_scaffold.ready_for_local_init_op)
|
||||||
local_init_op = coalesce(local_init_op, copy_from_scaffold.local_init_op)
|
local_init_op = coalesce(local_init_op, copy_from_scaffold.local_init_op)
|
||||||
|
local_init_feed_dict = coalesce(local_init_feed_dict,
|
||||||
|
copy_from_scaffold.local_init_feed_dict)
|
||||||
summary_op = coalesce(summary_op, copy_from_scaffold.summary_op)
|
summary_op = coalesce(summary_op, copy_from_scaffold.summary_op)
|
||||||
saver = coalesce(saver, copy_from_scaffold.saver)
|
saver = coalesce(saver, copy_from_scaffold.saver)
|
||||||
|
|
||||||
@ -178,6 +183,7 @@ class Scaffold(object):
|
|||||||
self._ready_op = ready_op
|
self._ready_op = ready_op
|
||||||
self._ready_for_local_init_op = ready_for_local_init_op
|
self._ready_for_local_init_op = ready_for_local_init_op
|
||||||
self._local_init_op = local_init_op
|
self._local_init_op = local_init_op
|
||||||
|
self._local_init_feed_dict = local_init_feed_dict
|
||||||
self._summary_op = summary_op
|
self._summary_op = summary_op
|
||||||
self._saver = saver
|
self._saver = saver
|
||||||
|
|
||||||
@ -260,6 +266,10 @@ class Scaffold(object):
|
|||||||
def local_init_op(self):
|
def local_init_op(self):
|
||||||
return self._local_init_op
|
return self._local_init_op
|
||||||
|
|
||||||
|
@property
|
||||||
|
def local_init_feed_dict(self):
|
||||||
|
return self._local_init_feed_dict
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def summary_op(self):
|
def summary_op(self):
|
||||||
return self._summary_op
|
return self._summary_op
|
||||||
@ -624,11 +634,13 @@ class ChiefSessionCreator(SessionCreator):
|
|||||||
self._config = config
|
self._config = config
|
||||||
|
|
||||||
def _get_session_manager(self):
|
def _get_session_manager(self):
|
||||||
|
"""Gets or creates a SessionManager."""
|
||||||
if self._session_manager:
|
if self._session_manager:
|
||||||
return self._session_manager
|
return self._session_manager
|
||||||
|
|
||||||
self._session_manager = sm.SessionManager(
|
self._session_manager = sm.SessionManager(
|
||||||
local_init_op=self._scaffold.local_init_op,
|
local_init_op=self._scaffold.local_init_op,
|
||||||
|
local_init_feed_dict=self._scaffold.local_init_feed_dict,
|
||||||
ready_op=self._scaffold.ready_op,
|
ready_op=self._scaffold.ready_op,
|
||||||
ready_for_local_init_op=self._scaffold.ready_for_local_init_op,
|
ready_for_local_init_op=self._scaffold.ready_for_local_init_op,
|
||||||
graph=ops.get_default_graph())
|
graph=ops.get_default_graph())
|
||||||
@ -672,11 +684,13 @@ class WorkerSessionCreator(SessionCreator):
|
|||||||
self._max_wait_secs = max_wait_secs
|
self._max_wait_secs = max_wait_secs
|
||||||
|
|
||||||
def _get_session_manager(self):
|
def _get_session_manager(self):
|
||||||
|
"""Gets or creates a SessionManager."""
|
||||||
if self._session_manager:
|
if self._session_manager:
|
||||||
return self._session_manager
|
return self._session_manager
|
||||||
|
|
||||||
self._session_manager = sm.SessionManager(
|
self._session_manager = sm.SessionManager(
|
||||||
local_init_op=self._scaffold.local_init_op,
|
local_init_op=self._scaffold.local_init_op,
|
||||||
|
local_init_feed_dict=self._scaffold.local_init_feed_dict,
|
||||||
ready_op=self._scaffold.ready_op,
|
ready_op=self._scaffold.ready_op,
|
||||||
ready_for_local_init_op=self._scaffold.ready_for_local_init_op,
|
ready_for_local_init_op=self._scaffold.ready_for_local_init_op,
|
||||||
graph=ops.get_default_graph())
|
graph=ops.get_default_graph())
|
||||||
|
@ -88,6 +88,7 @@ class ScaffoldTest(test.TestCase):
|
|||||||
self.assertTrue(isinstance(scaffold.ready_op, ops.Tensor))
|
self.assertTrue(isinstance(scaffold.ready_op, ops.Tensor))
|
||||||
self.assertTrue(isinstance(scaffold.ready_for_local_init_op, ops.Tensor))
|
self.assertTrue(isinstance(scaffold.ready_for_local_init_op, ops.Tensor))
|
||||||
self.assertTrue(isinstance(scaffold.local_init_op, ops.Operation))
|
self.assertTrue(isinstance(scaffold.local_init_op, ops.Operation))
|
||||||
|
self.assertEqual(None, scaffold.local_init_feed_dict)
|
||||||
self.assertTrue(isinstance(scaffold.saver, saver_lib.Saver))
|
self.assertTrue(isinstance(scaffold.saver, saver_lib.Saver))
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
self.assertItemsEqual([b'my_var', b'my_local_var'],
|
self.assertItemsEqual([b'my_var', b'my_local_var'],
|
||||||
@ -110,6 +111,7 @@ class ScaffoldTest(test.TestCase):
|
|||||||
self.assertTrue(isinstance(scaffold.ready_op, ops.Tensor))
|
self.assertTrue(isinstance(scaffold.ready_op, ops.Tensor))
|
||||||
self.assertTrue(isinstance(scaffold.ready_for_local_init_op, ops.Tensor))
|
self.assertTrue(isinstance(scaffold.ready_for_local_init_op, ops.Tensor))
|
||||||
self.assertTrue(isinstance(scaffold.local_init_op, ops.Operation))
|
self.assertTrue(isinstance(scaffold.local_init_op, ops.Operation))
|
||||||
|
self.assertEqual(None, scaffold.local_init_feed_dict)
|
||||||
self.assertTrue(isinstance(scaffold.saver, saver_lib.Saver))
|
self.assertTrue(isinstance(scaffold.saver, saver_lib.Saver))
|
||||||
|
|
||||||
def test_caches_values(self):
|
def test_caches_values(self):
|
||||||
@ -145,6 +147,7 @@ class ScaffoldTest(test.TestCase):
|
|||||||
ready_op=5,
|
ready_op=5,
|
||||||
ready_for_local_init_op=6,
|
ready_for_local_init_op=6,
|
||||||
local_init_op=7,
|
local_init_op=7,
|
||||||
|
local_init_feed_dict=8,
|
||||||
saver=saver)
|
saver=saver)
|
||||||
scaffold.finalize()
|
scaffold.finalize()
|
||||||
self.assertEqual(2, scaffold.init_op)
|
self.assertEqual(2, scaffold.init_op)
|
||||||
@ -153,6 +156,7 @@ class ScaffoldTest(test.TestCase):
|
|||||||
self.assertEqual(5, scaffold.ready_op)
|
self.assertEqual(5, scaffold.ready_op)
|
||||||
self.assertEqual(6, scaffold.ready_for_local_init_op)
|
self.assertEqual(6, scaffold.ready_for_local_init_op)
|
||||||
self.assertEqual(7, scaffold.local_init_op)
|
self.assertEqual(7, scaffold.local_init_op)
|
||||||
|
self.assertEqual(8, scaffold.local_init_feed_dict)
|
||||||
self.assertEqual(saver, scaffold.saver)
|
self.assertEqual(saver, scaffold.saver)
|
||||||
|
|
||||||
def test_graph_is_finalized(self):
|
def test_graph_is_finalized(self):
|
||||||
@ -175,6 +179,7 @@ class ScaffoldTest(test.TestCase):
|
|||||||
ready_op=5,
|
ready_op=5,
|
||||||
ready_for_local_init_op=6,
|
ready_for_local_init_op=6,
|
||||||
local_init_op=7,
|
local_init_op=7,
|
||||||
|
local_init_feed_dict=8,
|
||||||
saver=saver,
|
saver=saver,
|
||||||
copy_from_scaffold=scaffold1)
|
copy_from_scaffold=scaffold1)
|
||||||
|
|
||||||
@ -185,6 +190,7 @@ class ScaffoldTest(test.TestCase):
|
|||||||
self.assertEqual(5, scaffold2.ready_op)
|
self.assertEqual(5, scaffold2.ready_op)
|
||||||
self.assertEqual(6, scaffold2.ready_for_local_init_op)
|
self.assertEqual(6, scaffold2.ready_for_local_init_op)
|
||||||
self.assertEqual(7, scaffold2.local_init_op)
|
self.assertEqual(7, scaffold2.local_init_op)
|
||||||
|
self.assertEqual(8, scaffold2.local_init_feed_dict)
|
||||||
self.assertEqual(saver, scaffold2.saver)
|
self.assertEqual(saver, scaffold2.saver)
|
||||||
|
|
||||||
def test_new_scaffold_from_existing_scaffold(self):
|
def test_new_scaffold_from_existing_scaffold(self):
|
||||||
@ -198,6 +204,7 @@ class ScaffoldTest(test.TestCase):
|
|||||||
ready_op=5,
|
ready_op=5,
|
||||||
ready_for_local_init_op=6,
|
ready_for_local_init_op=6,
|
||||||
local_init_op=7,
|
local_init_op=7,
|
||||||
|
local_init_feed_dict=8,
|
||||||
saver=saver)
|
saver=saver)
|
||||||
|
|
||||||
scaffold2 = monitored_session.Scaffold(
|
scaffold2 = monitored_session.Scaffold(
|
||||||
@ -207,6 +214,7 @@ class ScaffoldTest(test.TestCase):
|
|||||||
ready_op=10,
|
ready_op=10,
|
||||||
ready_for_local_init_op=12,
|
ready_for_local_init_op=12,
|
||||||
local_init_op=14,
|
local_init_op=14,
|
||||||
|
local_init_feed_dict=15,
|
||||||
saver=saver,
|
saver=saver,
|
||||||
copy_from_scaffold=scaffold1)
|
copy_from_scaffold=scaffold1)
|
||||||
|
|
||||||
@ -217,6 +225,7 @@ class ScaffoldTest(test.TestCase):
|
|||||||
self.assertEqual(10, scaffold2.ready_op)
|
self.assertEqual(10, scaffold2.ready_op)
|
||||||
self.assertEqual(12, scaffold2.ready_for_local_init_op)
|
self.assertEqual(12, scaffold2.ready_for_local_init_op)
|
||||||
self.assertEqual(14, scaffold2.local_init_op)
|
self.assertEqual(14, scaffold2.local_init_op)
|
||||||
|
self.assertEqual(15, scaffold2.local_init_feed_dict)
|
||||||
self.assertEqual(saver, scaffold2.saver)
|
self.assertEqual(saver, scaffold2.saver)
|
||||||
|
|
||||||
def test_copy_from_scaffold_is_scaffold(self):
|
def test_copy_from_scaffold_is_scaffold(self):
|
||||||
|
@ -97,7 +97,8 @@ class SessionManager(object):
|
|||||||
ready_for_local_init_op=None,
|
ready_for_local_init_op=None,
|
||||||
graph=None,
|
graph=None,
|
||||||
recovery_wait_secs=30,
|
recovery_wait_secs=30,
|
||||||
local_init_run_options=None):
|
local_init_run_options=None,
|
||||||
|
local_init_feed_dict=None):
|
||||||
"""Creates a SessionManager.
|
"""Creates a SessionManager.
|
||||||
|
|
||||||
The `local_init_op` is an `Operation` that is run always after a new session
|
The `local_init_op` is an `Operation` that is run always after a new session
|
||||||
@ -131,6 +132,8 @@ class SessionManager(object):
|
|||||||
recovery_wait_secs: Seconds between checks for the model to be ready.
|
recovery_wait_secs: Seconds between checks for the model to be ready.
|
||||||
local_init_run_options: RunOptions to be passed to session.run when
|
local_init_run_options: RunOptions to be passed to session.run when
|
||||||
executing the local_init_op.
|
executing the local_init_op.
|
||||||
|
local_init_feed_dict: Optional session feed dictionary to use when running
|
||||||
|
the local_init_op.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If ready_for_local_init_op is not None but local_init_op is
|
ValueError: If ready_for_local_init_op is not None but local_init_op is
|
||||||
@ -146,6 +149,7 @@ class SessionManager(object):
|
|||||||
self._recovery_wait_secs = recovery_wait_secs
|
self._recovery_wait_secs = recovery_wait_secs
|
||||||
self._target = None
|
self._target = None
|
||||||
self._local_init_run_options = local_init_run_options
|
self._local_init_run_options = local_init_run_options
|
||||||
|
self._local_init_feed_dict = local_init_feed_dict
|
||||||
if ready_for_local_init_op is not None and local_init_op is None:
|
if ready_for_local_init_op is not None and local_init_op is None:
|
||||||
raise ValueError("If you pass a ready_for_local_init_op "
|
raise ValueError("If you pass a ready_for_local_init_op "
|
||||||
"you must also pass a local_init_op "
|
"you must also pass a local_init_op "
|
||||||
@ -498,7 +502,8 @@ class SessionManager(object):
|
|||||||
is_ready_for_local_init, msg = self._model_ready_for_local_init(sess)
|
is_ready_for_local_init, msg = self._model_ready_for_local_init(sess)
|
||||||
if is_ready_for_local_init:
|
if is_ready_for_local_init:
|
||||||
logging.info("Running local_init_op.")
|
logging.info("Running local_init_op.")
|
||||||
sess.run(self._local_init_op, options=self._local_init_run_options)
|
sess.run(self._local_init_op, feed_dict=self._local_init_feed_dict,
|
||||||
|
options=self._local_init_run_options)
|
||||||
logging.info("Done running local_init_op.")
|
logging.info("Done running local_init_op.")
|
||||||
return True, None
|
return True, None
|
||||||
else:
|
else:
|
||||||
|
@ -69,6 +69,18 @@ class SessionManagerTest(test.TestCase):
|
|||||||
"", init_fn=lambda sess: sess.run(v.initializer))
|
"", init_fn=lambda sess: sess.run(v.initializer))
|
||||||
self.assertAllClose([125], sess.run(v))
|
self.assertAllClose([125], sess.run(v))
|
||||||
|
|
||||||
|
def testPrepareSessionSucceedsWithLocalInitFeedDict(self):
|
||||||
|
with ops.Graph().as_default():
|
||||||
|
p = array_ops.placeholder(dtypes.float32, shape=(3,))
|
||||||
|
v = variables.VariableV1(p, name="v",
|
||||||
|
collections=[ops.GraphKeys.LOCAL_VARIABLES])
|
||||||
|
sm = session_manager.SessionManager(
|
||||||
|
local_init_op=v.initializer,
|
||||||
|
local_init_feed_dict={p: [1.0, 2.0, 3.0]},
|
||||||
|
ready_op=variables.report_uninitialized_variables())
|
||||||
|
sess = sm.prepare_session("")
|
||||||
|
self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_v1_only("b/120545219")
|
||||||
def testPrepareSessionFails(self):
|
def testPrepareSessionFails(self):
|
||||||
checkpoint_dir = os.path.join(self.get_temp_dir(), "prepare_session")
|
checkpoint_dir = os.path.join(self.get_temp_dir(), "prepare_session")
|
||||||
|
@ -14,6 +14,10 @@ tf_class {
|
|||||||
name: "init_op"
|
name: "init_op"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
}
|
}
|
||||||
|
member {
|
||||||
|
name: "local_init_feed_dict"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "local_init_op"
|
name: "local_init_op"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
@ -36,7 +40,7 @@ tf_class {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "__init__"
|
name: "__init__"
|
||||||
argspec: "args=[\'self\', \'init_op\', \'init_feed_dict\', \'init_fn\', \'ready_op\', \'ready_for_local_init_op\', \'local_init_op\', \'summary_op\', \'saver\', \'copy_from_scaffold\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
|
argspec: "args=[\'self\', \'init_op\', \'init_feed_dict\', \'init_fn\', \'ready_op\', \'ready_for_local_init_op\', \'local_init_op\', \'summary_op\', \'saver\', \'copy_from_scaffold\', \'local_init_feed_dict\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "default_local_init_op"
|
name: "default_local_init_op"
|
||||||
|
@ -4,7 +4,7 @@ tf_class {
|
|||||||
is_instance: "<type \'object\'>"
|
is_instance: "<type \'object\'>"
|
||||||
member_method {
|
member_method {
|
||||||
name: "__init__"
|
name: "__init__"
|
||||||
argspec: "args=[\'self\', \'local_init_op\', \'ready_op\', \'ready_for_local_init_op\', \'graph\', \'recovery_wait_secs\', \'local_init_run_options\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'30\', \'None\'], "
|
argspec: "args=[\'self\', \'local_init_op\', \'ready_op\', \'ready_for_local_init_op\', \'graph\', \'recovery_wait_secs\', \'local_init_run_options\', \'local_init_feed_dict\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'30\', \'None\', \'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "prepare_session"
|
name: "prepare_session"
|
||||||
|
Loading…
Reference in New Issue
Block a user