Keras SavedModel: Ignore custom metrics failure when compile=False

This commit is contained in:
Lukas Geiger 2020-11-30 18:37:00 +00:00
parent df05368807
commit c4e6c635de
2 changed files with 37 additions and 5 deletions

View File

@ -135,7 +135,7 @@ def load(path, compile=True, options=None): # pylint: disable=redefined-builtin
# Recreate layers and metrics using the info stored in the metadata.
keras_loader = KerasObjectLoader(metadata, object_graph_def)
keras_loader.load_layers()
keras_loader.load_layers(compile=compile)
# Generate a dictionary of all loaded nodes.
nodes_to_load = {'root': None}
@ -360,7 +360,7 @@ class KerasObjectLoader(object):
obj_child, child_proto, child_id)
self.loaded_nodes[child_id] = obj_child, setter
def load_layers(self):
def load_layers(self, compile=True): # pylint: disable=redefined-builtin
"""Load all layer nodes from the metadata."""
# Load metrics after models and layers, since it's likely that models
# and layers will create the metric when initialized (this avoids wasting
@ -376,9 +376,21 @@ class KerasObjectLoader(object):
node_metadata.metadata)
for node_metadata in metric_list:
self.loaded_nodes[node_metadata.node_id] = self._load_layer(
node_metadata.node_id, node_metadata.identifier,
node_metadata.metadata)
try:
self.loaded_nodes[node_metadata.node_id] = self._load_layer(
node_metadata.node_id, node_metadata.identifier,
node_metadata.metadata)
except ValueError:
# Metrics are only needed when the model is compiled later. We ignore
# errors when trying to load custom metrics when `compile=False` until
# custom metrics are serialized properly (b/135550038).
if compile:
raise
logging.warning('Unable to restore custom metric. Please ensure that '
'the layer implements `get_config` and `from_config` '
'when saving. In addition, please use the '
'`custom_objects` arg when calling `load_model()`.')
def _load_layer(self, node_id, identifier, metadata):
"""Load a single layer from a SavedUserObject proto."""

View File

@ -1147,6 +1147,26 @@ class MetricTest(test.TestCase, parameterized.TestCase):
self._test_metric_save_and_load(
metric, self._save_model_dir(), 1, test_sample_weight=False)
@keras_parameterized.run_with_all_model_types
def test_custom_metric_model(self):
class CustomMetric(keras.metrics.MeanSquaredError):
pass
model = testing_utils.get_small_mlp(1, 4, input_dim=3)
model.compile(
loss='mse',
optimizer='rmsprop',
metrics=[CustomMetric()])
saved_model_dir = self._save_model_dir()
tf_save.save(model, saved_model_dir)
with self.assertRaisesRegex(ValueError, 'custom_objects'):
keras_load.load(saved_model_dir)
keras_load.load(saved_model_dir, compile=False)
if __name__ == '__main__':
test.main()