Merge pull request #45278 from lgeiger:fix-custom-metric-saved-model

PiperOrigin-RevId: 347069453
Change-Id: Ie438b2f2279ae5fea0d4b17bc67bb3478690dd0c
This commit is contained in:
TensorFlower Gardener 2020-12-11 14:01:24 -08:00
commit 8144f8af37
2 changed files with 32 additions and 5 deletions

View File

@ -146,7 +146,7 @@ def load(path, compile=True, options=None): # pylint: disable=redefined-builtin
# Recreate layers and metrics using the info stored in the metadata.
keras_loader = KerasObjectLoader(metadata, object_graph_def)
keras_loader.load_layers()
keras_loader.load_layers(compile=compile)
# Generate a dictionary of all loaded nodes.
nodes_to_load = {'root': None}
@ -371,7 +371,7 @@ class KerasObjectLoader(object):
obj_child, child_proto, child_id)
self.loaded_nodes[child_id] = obj_child, setter
def load_layers(self):
def load_layers(self, compile=True): # pylint: disable=redefined-builtin
"""Load all layer nodes from the metadata."""
# Load metrics after models and layers, since it's likely that models
# and layers will create the metric when initialized (this avoids wasting
@ -387,9 +387,20 @@ class KerasObjectLoader(object):
node_metadata.metadata)
for node_metadata in metric_list:
try:
self.loaded_nodes[node_metadata.node_id] = self._load_layer(
node_metadata.node_id, node_metadata.identifier,
node_metadata.metadata)
except ValueError:
# Metrics are only needed when the model is compiled later. We ignore
# errors when trying to load custom metrics when `compile=False` until
# custom metrics are serialized properly (b/135550038).
if compile:
raise
logging.warning('Unable to restore custom metric. Please ensure that '
'the layer implements `get_config` and `from_config` '
'when saving. In addition, please use the '
'`custom_objects` arg when calling `load_model()`.')
def _load_layer(self, node_id, identifier, metadata):
"""Load a single layer from a SavedUserObject proto."""

View File

@ -1159,6 +1159,22 @@ class MetricTest(test.TestCase, parameterized.TestCase):
self._test_metric_save_and_load(
metric, self._save_model_dir(), 1, test_sample_weight=False)
@keras_parameterized.run_with_all_model_types
def test_custom_metric_model(self):
class CustomMetric(keras.metrics.MeanSquaredError):
pass
model = testing_utils.get_small_mlp(1, 4, input_dim=3)
model.compile(loss='mse', optimizer='rmsprop', metrics=[CustomMetric()])
saved_model_dir = self._save_model_dir()
tf_save.save(model, saved_model_dir)
with self.assertRaisesRegex(ValueError, 'custom_objects'):
keras_load.load(saved_model_dir)
keras_load.load(saved_model_dir, compile=False)
if __name__ == '__main__':
test.main()