Ignore get_configs that aren't JSON-serializable when saving checkpoints
This is extra "nice to have" metadata, and otherwise it looks like a checkpointing error. Not worth bothering people about. PiperOrigin-RevId: 235247436
This commit is contained in:
parent
fbe459fdbe
commit
5528b56647
@ -851,10 +851,14 @@ class Trackable(object):
|
||||
"""Serializes `self.get_config()` for saving."""
|
||||
dereferenced_self = weak_self()
|
||||
if dereferenced_self:
|
||||
return json.dumps(
|
||||
dereferenced_self,
|
||||
default=serialization.get_json_type,
|
||||
sort_keys=True).encode("utf8")
|
||||
try:
|
||||
return json.dumps(
|
||||
dereferenced_self,
|
||||
default=serialization.get_json_type,
|
||||
sort_keys=True).encode("utf8")
|
||||
except TypeError:
|
||||
# Even if get_config worked objects may have produced garbage.
|
||||
return ""
|
||||
else:
|
||||
return ""
|
||||
return {OBJECT_CONFIG_JSON_KEY: functools.partial(
|
||||
|
@ -83,6 +83,19 @@ class InterfaceTests(test.TestCase):
|
||||
with self.assertRaisesRegexp(AssertionError, "foo_attr"):
|
||||
status.assert_consumed()
|
||||
|
||||
def testBuggyGetConfig(self):
|
||||
|
||||
class NotSerializable(object):
|
||||
pass
|
||||
|
||||
class GetConfigRaisesError(base.Trackable):
|
||||
|
||||
def get_config(self):
|
||||
return NotSerializable()
|
||||
|
||||
util.Checkpoint(obj=GetConfigRaisesError()).save(
|
||||
os.path.join(self.get_temp_dir(), "ckpt"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
ops.enable_eager_execution()
|
||||
|
Loading…
Reference in New Issue
Block a user