Load networks that override get_config

PiperOrigin-RevId: 327524410
Change-Id: I0e277bb6b1e56d42803e26ed15a98025e11fd4be
This commit is contained in:
Philip Pham 2020-08-19 15:59:24 -07:00 committed by TensorFlower Gardener
parent 345a0e60d0
commit 311c8d233b
2 changed files with 54 additions and 6 deletions

View File

@ -380,8 +380,11 @@ class KerasObjectLoader(tf_load.Loader):
metadata['class_name'] == 'Sequential' or
metadata['class_name'] == 'Functional')
if not (generic_utils.validate_config(config) and
model_is_functional_or_sequential):
return None # Revive as custom model.
model_is_functional_or_sequential
) or generic_utils.get_registered_object(class_name) is not None:
# Model should not be revived as a graph network. Try reviving directly
# from config or as a custom model.
return None
# Revive functional and sequential models as blank model objects for now (
# must be initialized to enable setattr tracking and attribute caching).

View File

@ -27,6 +27,7 @@ from __future__ import print_function
import os
import shutil
from absl.testing import parameterized
import numpy as np
from tensorflow.python import keras
@ -115,6 +116,36 @@ class CustomLayerWithConfig(CustomLayerNoConfig):
'name': self.name}
class CustomNetworkDefaultConfig(keras.Model):
def __init__(self, num_classes, name=None):
inputs = keras.Input((2, 3), name='inputs')
x = keras.layers.Flatten(name='flatten')(inputs)
y = keras.layers.Dense(num_classes, name='outputs')(x)
super(CustomNetworkDefaultConfig, self).__init__(inputs, y, name=name)
class CustomNetworkWithConfig(CustomNetworkDefaultConfig):
def __init__(self, num_classes, name=None):
super(CustomNetworkWithConfig, self).__init__(num_classes, name=name)
self._config_dict = dict(num_classes=num_classes)
def get_config(self):
return self._config_dict
@classmethod
def from_config(cls, config):
return cls(config['num_classes'], name=config.get('name'))
class CustomNetworkWithConfigName(CustomNetworkWithConfig):
def __init__(self, num_classes, name=None):
super(CustomNetworkWithConfigName, self).__init__(num_classes, name=name)
self._config_dict['name'] = self.name
class TestModelRevive(keras_parameterized.TestCase):
def setUp(self):
@ -244,17 +275,31 @@ class TestModelRevive(keras_parameterized.TestCase):
self._assert_revived_correctness(model, revived)
def test_revive_sequential_inputs(self):
model = keras.models.Sequential(
[keras.Input((None,), dtype=dtypes.string),
keras.layers.Lambda(string_ops.string_lower)])
model = keras.models.Sequential([
keras.Input((None,), dtype=dtypes.string),
keras.layers.Lambda(string_ops.string_lower)
])
model.save(self.path, save_format='tf')
revived = keras_load.load(self.path)
self.assertEqual(dtypes.string, revived._layers[0].dtype)
@parameterized.named_parameters(
('default_config', CustomNetworkDefaultConfig),
('with_config', CustomNetworkWithConfig),
('with_config_name', CustomNetworkWithConfigName))
def test_revive_network(self, model_cls):
model = model_cls(8)
model.save(self.path, include_optimizer=False, save_format='tf')
revived = keras_load.load(self.path, compile=False)
self._assert_revived_correctness(model, revived)
if __name__ == '__main__':
ops.enable_eager_execution()
with generic_utils.CustomObjectScope({
'CustomLayerWithConfig': CustomLayerWithConfig,
'SubclassedModelWithConfig': SubclassedModelWithConfig}):
'CustomNetworkWithConfig': CustomNetworkWithConfig,
'CustomNetworkWithConfigName': CustomNetworkWithConfigName,
'SubclassedModelWithConfig': SubclassedModelWithConfig
}):
test.main()