Merge pull request #45278 from lgeiger:fix-custom-metric-saved-model
PiperOrigin-RevId: 347069453 Change-Id: Ie438b2f2279ae5fea0d4b17bc67bb3478690dd0c
This commit is contained in:
commit
8144f8af37
@ -146,7 +146,7 @@ def load(path, compile=True, options=None): # pylint: disable=redefined-builtin
|
||||
|
||||
# Recreate layers and metrics using the info stored in the metadata.
|
||||
keras_loader = KerasObjectLoader(metadata, object_graph_def)
|
||||
keras_loader.load_layers()
|
||||
keras_loader.load_layers(compile=compile)
|
||||
|
||||
# Generate a dictionary of all loaded nodes.
|
||||
nodes_to_load = {'root': None}
|
||||
@ -371,7 +371,7 @@ class KerasObjectLoader(object):
|
||||
obj_child, child_proto, child_id)
|
||||
self.loaded_nodes[child_id] = obj_child, setter
|
||||
|
||||
def load_layers(self):
|
||||
def load_layers(self, compile=True): # pylint: disable=redefined-builtin
|
||||
"""Load all layer nodes from the metadata."""
|
||||
# Load metrics after models and layers, since it's likely that models
|
||||
# and layers will create the metric when initialized (this avoids wasting
|
||||
@ -387,9 +387,20 @@ class KerasObjectLoader(object):
|
||||
node_metadata.metadata)
|
||||
|
||||
for node_metadata in metric_list:
|
||||
try:
|
||||
self.loaded_nodes[node_metadata.node_id] = self._load_layer(
|
||||
node_metadata.node_id, node_metadata.identifier,
|
||||
node_metadata.metadata)
|
||||
except ValueError:
|
||||
# Metrics are only needed when the model is compiled later. We ignore
|
||||
# errors when trying to load custom metrics when `compile=False` until
|
||||
# custom metrics are serialized properly (b/135550038).
|
||||
if compile:
|
||||
raise
|
||||
logging.warning('Unable to restore custom metric. Please ensure that '
|
||||
'the layer implements `get_config` and `from_config` '
|
||||
'when saving. In addition, please use the '
|
||||
'`custom_objects` arg when calling `load_model()`.')
|
||||
|
||||
def _load_layer(self, node_id, identifier, metadata):
|
||||
"""Load a single layer from a SavedUserObject proto."""
|
||||
|
@ -1159,6 +1159,22 @@ class MetricTest(test.TestCase, parameterized.TestCase):
|
||||
self._test_metric_save_and_load(
|
||||
metric, self._save_model_dir(), 1, test_sample_weight=False)
|
||||
|
||||
@keras_parameterized.run_with_all_model_types
|
||||
def test_custom_metric_model(self):
|
||||
|
||||
class CustomMetric(keras.metrics.MeanSquaredError):
|
||||
pass
|
||||
|
||||
model = testing_utils.get_small_mlp(1, 4, input_dim=3)
|
||||
model.compile(loss='mse', optimizer='rmsprop', metrics=[CustomMetric()])
|
||||
|
||||
saved_model_dir = self._save_model_dir()
|
||||
tf_save.save(model, saved_model_dir)
|
||||
with self.assertRaisesRegex(ValueError, 'custom_objects'):
|
||||
keras_load.load(saved_model_dir)
|
||||
|
||||
keras_load.load(saved_model_dir, compile=False)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user