Load networks that override get_config
PiperOrigin-RevId: 327524410 Change-Id: I0e277bb6b1e56d42803e26ed15a98025e11fd4be
This commit is contained in:
parent
345a0e60d0
commit
311c8d233b
@ -380,8 +380,11 @@ class KerasObjectLoader(tf_load.Loader):
|
|||||||
metadata['class_name'] == 'Sequential' or
|
metadata['class_name'] == 'Sequential' or
|
||||||
metadata['class_name'] == 'Functional')
|
metadata['class_name'] == 'Functional')
|
||||||
if not (generic_utils.validate_config(config) and
|
if not (generic_utils.validate_config(config) and
|
||||||
model_is_functional_or_sequential):
|
model_is_functional_or_sequential
|
||||||
return None # Revive as custom model.
|
) or generic_utils.get_registered_object(class_name) is not None:
|
||||||
|
# Model should not be revived as a graph network. Try reviving directly
|
||||||
|
# from config or as a custom model.
|
||||||
|
return None
|
||||||
|
|
||||||
# Revive functional and sequential models as blank model objects for now (
|
# Revive functional and sequential models as blank model objects for now (
|
||||||
# must be initialized to enable setattr tracking and attribute caching).
|
# must be initialized to enable setattr tracking and attribute caching).
|
||||||
|
@ -27,6 +27,7 @@ from __future__ import print_function
|
|||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
|
from absl.testing import parameterized
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.python import keras
|
from tensorflow.python import keras
|
||||||
@ -115,6 +116,36 @@ class CustomLayerWithConfig(CustomLayerNoConfig):
|
|||||||
'name': self.name}
|
'name': self.name}
|
||||||
|
|
||||||
|
|
||||||
|
class CustomNetworkDefaultConfig(keras.Model):
|
||||||
|
|
||||||
|
def __init__(self, num_classes, name=None):
|
||||||
|
inputs = keras.Input((2, 3), name='inputs')
|
||||||
|
x = keras.layers.Flatten(name='flatten')(inputs)
|
||||||
|
y = keras.layers.Dense(num_classes, name='outputs')(x)
|
||||||
|
super(CustomNetworkDefaultConfig, self).__init__(inputs, y, name=name)
|
||||||
|
|
||||||
|
|
||||||
|
class CustomNetworkWithConfig(CustomNetworkDefaultConfig):
|
||||||
|
|
||||||
|
def __init__(self, num_classes, name=None):
|
||||||
|
super(CustomNetworkWithConfig, self).__init__(num_classes, name=name)
|
||||||
|
self._config_dict = dict(num_classes=num_classes)
|
||||||
|
|
||||||
|
def get_config(self):
|
||||||
|
return self._config_dict
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, config):
|
||||||
|
return cls(config['num_classes'], name=config.get('name'))
|
||||||
|
|
||||||
|
|
||||||
|
class CustomNetworkWithConfigName(CustomNetworkWithConfig):
|
||||||
|
|
||||||
|
def __init__(self, num_classes, name=None):
|
||||||
|
super(CustomNetworkWithConfigName, self).__init__(num_classes, name=name)
|
||||||
|
self._config_dict['name'] = self.name
|
||||||
|
|
||||||
|
|
||||||
class TestModelRevive(keras_parameterized.TestCase):
|
class TestModelRevive(keras_parameterized.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
@ -244,17 +275,31 @@ class TestModelRevive(keras_parameterized.TestCase):
|
|||||||
self._assert_revived_correctness(model, revived)
|
self._assert_revived_correctness(model, revived)
|
||||||
|
|
||||||
def test_revive_sequential_inputs(self):
|
def test_revive_sequential_inputs(self):
|
||||||
model = keras.models.Sequential(
|
model = keras.models.Sequential([
|
||||||
[keras.Input((None,), dtype=dtypes.string),
|
keras.Input((None,), dtype=dtypes.string),
|
||||||
keras.layers.Lambda(string_ops.string_lower)])
|
keras.layers.Lambda(string_ops.string_lower)
|
||||||
|
])
|
||||||
model.save(self.path, save_format='tf')
|
model.save(self.path, save_format='tf')
|
||||||
revived = keras_load.load(self.path)
|
revived = keras_load.load(self.path)
|
||||||
self.assertEqual(dtypes.string, revived._layers[0].dtype)
|
self.assertEqual(dtypes.string, revived._layers[0].dtype)
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
('default_config', CustomNetworkDefaultConfig),
|
||||||
|
('with_config', CustomNetworkWithConfig),
|
||||||
|
('with_config_name', CustomNetworkWithConfigName))
|
||||||
|
def test_revive_network(self, model_cls):
|
||||||
|
model = model_cls(8)
|
||||||
|
model.save(self.path, include_optimizer=False, save_format='tf')
|
||||||
|
revived = keras_load.load(self.path, compile=False)
|
||||||
|
self._assert_revived_correctness(model, revived)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
ops.enable_eager_execution()
|
ops.enable_eager_execution()
|
||||||
with generic_utils.CustomObjectScope({
|
with generic_utils.CustomObjectScope({
|
||||||
'CustomLayerWithConfig': CustomLayerWithConfig,
|
'CustomLayerWithConfig': CustomLayerWithConfig,
|
||||||
'SubclassedModelWithConfig': SubclassedModelWithConfig}):
|
'CustomNetworkWithConfig': CustomNetworkWithConfig,
|
||||||
|
'CustomNetworkWithConfigName': CustomNetworkWithConfigName,
|
||||||
|
'SubclassedModelWithConfig': SubclassedModelWithConfig
|
||||||
|
}):
|
||||||
test.main()
|
test.main()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user