Add recreated assets to ASSET_FILEPATHS collection if invoked in non-eager context.

PiperOrigin-RevId: 347084961
Change-Id: I3b4ed01a047b04e04f01a0aad140f24e25e78b2c
This commit is contained in:
A. Unique TensorFlower 2020-12-11 15:16:38 -08:00 committed by TensorFlower Gardener
parent d6eacc2cd3
commit 6d8366afc4
2 changed files with 5 additions and 1 deletions

View File

@ -573,7 +573,10 @@ class Loader(object):
filename = os.path.join( filename = os.path.join(
saved_model_utils.get_assets_dir(self._export_dir), saved_model_utils.get_assets_dir(self._export_dir),
self._asset_file_def[proto.asset_file_def_index].filename) self._asset_file_def[proto.asset_file_def_index].filename)
return tracking.Asset(filename), setattr asset = tracking.Asset(filename)
if not context.executing_eagerly():
ops.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS, asset.asset_path)
return asset, setattr
def _recreate_function(self, proto): def _recreate_function(self, proto):
return function_deserialization.recreate_function( return function_deserialization.recreate_function(

View File

@ -292,6 +292,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
imported_tensor = imported.f() imported_tensor = imported.f()
with monitored_session.MonitoredSession() as sess: with monitored_session.MonitoredSession() as sess:
imported_output = sess.run(imported_tensor) imported_output = sess.run(imported_tensor)
self.assertLen(ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS), 1)
self.assertNotEqual(original_output, imported_output) self.assertNotEqual(original_output, imported_output)
with open(imported_output, "r") as f: with open(imported_output, "r") as f:
self.assertEqual("contents", f.read()) self.assertEqual("contents", f.read())