Add recreated assets to ASSET_FILEPATHS collection if invoked in non-eager context.
PiperOrigin-RevId: 347084961 Change-Id: I3b4ed01a047b04e04f01a0aad140f24e25e78b2c
This commit is contained in:
parent
d6eacc2cd3
commit
6d8366afc4
@ -573,7 +573,10 @@ class Loader(object):
|
|||||||
filename = os.path.join(
|
filename = os.path.join(
|
||||||
saved_model_utils.get_assets_dir(self._export_dir),
|
saved_model_utils.get_assets_dir(self._export_dir),
|
||||||
self._asset_file_def[proto.asset_file_def_index].filename)
|
self._asset_file_def[proto.asset_file_def_index].filename)
|
||||||
return tracking.Asset(filename), setattr
|
asset = tracking.Asset(filename)
|
||||||
|
if not context.executing_eagerly():
|
||||||
|
ops.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS, asset.asset_path)
|
||||||
|
return asset, setattr
|
||||||
|
|
||||||
def _recreate_function(self, proto):
|
def _recreate_function(self, proto):
|
||||||
return function_deserialization.recreate_function(
|
return function_deserialization.recreate_function(
|
||||||
|
@ -292,6 +292,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
|
|||||||
imported_tensor = imported.f()
|
imported_tensor = imported.f()
|
||||||
with monitored_session.MonitoredSession() as sess:
|
with monitored_session.MonitoredSession() as sess:
|
||||||
imported_output = sess.run(imported_tensor)
|
imported_output = sess.run(imported_tensor)
|
||||||
|
self.assertLen(ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS), 1)
|
||||||
self.assertNotEqual(original_output, imported_output)
|
self.assertNotEqual(original_output, imported_output)
|
||||||
with open(imported_output, "r") as f:
|
with open(imported_output, "r") as f:
|
||||||
self.assertEqual("contents", f.read())
|
self.assertEqual("contents", f.read())
|
||||||
|
Loading…
Reference in New Issue
Block a user