Load networks that override get_config
PiperOrigin-RevId: 327524410 Change-Id: I0e277bb6b1e56d42803e26ed15a98025e11fd4be
This commit is contained in:
parent
345a0e60d0
commit
311c8d233b
@ -380,8 +380,11 @@ class KerasObjectLoader(tf_load.Loader):
|
||||
metadata['class_name'] == 'Sequential' or
|
||||
metadata['class_name'] == 'Functional')
|
||||
if not (generic_utils.validate_config(config) and
|
||||
model_is_functional_or_sequential):
|
||||
return None # Revive as custom model.
|
||||
model_is_functional_or_sequential
|
||||
) or generic_utils.get_registered_object(class_name) is not None:
|
||||
# Model should not be revived as a graph network. Try reviving directly
|
||||
# from config or as a custom model.
|
||||
return None
|
||||
|
||||
# Revive functional and sequential models as blank model objects for now (
|
||||
# must be initialized to enable setattr tracking and attribute caching).
|
||||
|
@ -27,6 +27,7 @@ from __future__ import print_function
|
||||
import os
|
||||
import shutil
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python import keras
|
||||
@ -115,6 +116,36 @@ class CustomLayerWithConfig(CustomLayerNoConfig):
|
||||
'name': self.name}
|
||||
|
||||
|
||||
class CustomNetworkDefaultConfig(keras.Model):
|
||||
|
||||
def __init__(self, num_classes, name=None):
|
||||
inputs = keras.Input((2, 3), name='inputs')
|
||||
x = keras.layers.Flatten(name='flatten')(inputs)
|
||||
y = keras.layers.Dense(num_classes, name='outputs')(x)
|
||||
super(CustomNetworkDefaultConfig, self).__init__(inputs, y, name=name)
|
||||
|
||||
|
||||
class CustomNetworkWithConfig(CustomNetworkDefaultConfig):
|
||||
|
||||
def __init__(self, num_classes, name=None):
|
||||
super(CustomNetworkWithConfig, self).__init__(num_classes, name=name)
|
||||
self._config_dict = dict(num_classes=num_classes)
|
||||
|
||||
def get_config(self):
|
||||
return self._config_dict
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config):
|
||||
return cls(config['num_classes'], name=config.get('name'))
|
||||
|
||||
|
||||
class CustomNetworkWithConfigName(CustomNetworkWithConfig):
|
||||
|
||||
def __init__(self, num_classes, name=None):
|
||||
super(CustomNetworkWithConfigName, self).__init__(num_classes, name=name)
|
||||
self._config_dict['name'] = self.name
|
||||
|
||||
|
||||
class TestModelRevive(keras_parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
@ -244,17 +275,31 @@ class TestModelRevive(keras_parameterized.TestCase):
|
||||
self._assert_revived_correctness(model, revived)
|
||||
|
||||
def test_revive_sequential_inputs(self):
|
||||
model = keras.models.Sequential(
|
||||
[keras.Input((None,), dtype=dtypes.string),
|
||||
keras.layers.Lambda(string_ops.string_lower)])
|
||||
model = keras.models.Sequential([
|
||||
keras.Input((None,), dtype=dtypes.string),
|
||||
keras.layers.Lambda(string_ops.string_lower)
|
||||
])
|
||||
model.save(self.path, save_format='tf')
|
||||
revived = keras_load.load(self.path)
|
||||
self.assertEqual(dtypes.string, revived._layers[0].dtype)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('default_config', CustomNetworkDefaultConfig),
|
||||
('with_config', CustomNetworkWithConfig),
|
||||
('with_config_name', CustomNetworkWithConfigName))
|
||||
def test_revive_network(self, model_cls):
|
||||
model = model_cls(8)
|
||||
model.save(self.path, include_optimizer=False, save_format='tf')
|
||||
revived = keras_load.load(self.path, compile=False)
|
||||
self._assert_revived_correctness(model, revived)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
ops.enable_eager_execution()
|
||||
with generic_utils.CustomObjectScope({
|
||||
'CustomLayerWithConfig': CustomLayerWithConfig,
|
||||
'SubclassedModelWithConfig': SubclassedModelWithConfig}):
|
||||
'CustomNetworkWithConfig': CustomNetworkWithConfig,
|
||||
'CustomNetworkWithConfigName': CustomNetworkWithConfigName,
|
||||
'SubclassedModelWithConfig': SubclassedModelWithConfig
|
||||
}):
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user