Keras SavedModel: Ignore custom metrics failure when compile=False
This commit is contained in:
parent
df05368807
commit
c4e6c635de
@ -135,7 +135,7 @@ def load(path, compile=True, options=None): # pylint: disable=redefined-builtin
|
||||
|
||||
# Recreate layers and metrics using the info stored in the metadata.
|
||||
keras_loader = KerasObjectLoader(metadata, object_graph_def)
|
||||
keras_loader.load_layers()
|
||||
keras_loader.load_layers(compile=compile)
|
||||
|
||||
# Generate a dictionary of all loaded nodes.
|
||||
nodes_to_load = {'root': None}
|
||||
@ -360,7 +360,7 @@ class KerasObjectLoader(object):
|
||||
obj_child, child_proto, child_id)
|
||||
self.loaded_nodes[child_id] = obj_child, setter
|
||||
|
||||
def load_layers(self):
|
||||
def load_layers(self, compile=True): # pylint: disable=redefined-builtin
|
||||
"""Load all layer nodes from the metadata."""
|
||||
# Load metrics after models and layers, since it's likely that models
|
||||
# and layers will create the metric when initialized (this avoids wasting
|
||||
@ -376,9 +376,21 @@ class KerasObjectLoader(object):
|
||||
node_metadata.metadata)
|
||||
|
||||
for node_metadata in metric_list:
|
||||
self.loaded_nodes[node_metadata.node_id] = self._load_layer(
|
||||
node_metadata.node_id, node_metadata.identifier,
|
||||
node_metadata.metadata)
|
||||
try:
|
||||
self.loaded_nodes[node_metadata.node_id] = self._load_layer(
|
||||
node_metadata.node_id, node_metadata.identifier,
|
||||
node_metadata.metadata)
|
||||
except ValueError:
|
||||
# Metrics are only needed when the model is compiled later. We ignore
|
||||
# errors when trying to load custom metrics when `compile=False` until
|
||||
# custom metrics are serialized properly (b/135550038).
|
||||
if compile:
|
||||
raise
|
||||
logging.warning('Unable to restore custom metric. Please ensure that '
|
||||
'the layer implements `get_config` and `from_config` '
|
||||
'when saving. In addition, please use the '
|
||||
'`custom_objects` arg when calling `load_model()`.')
|
||||
|
||||
|
||||
def _load_layer(self, node_id, identifier, metadata):
|
||||
"""Load a single layer from a SavedUserObject proto."""
|
||||
|
@ -1147,6 +1147,26 @@ class MetricTest(test.TestCase, parameterized.TestCase):
|
||||
self._test_metric_save_and_load(
|
||||
metric, self._save_model_dir(), 1, test_sample_weight=False)
|
||||
|
||||
@keras_parameterized.run_with_all_model_types
|
||||
def test_custom_metric_model(self):
|
||||
|
||||
class CustomMetric(keras.metrics.MeanSquaredError):
|
||||
pass
|
||||
|
||||
model = testing_utils.get_small_mlp(1, 4, input_dim=3)
|
||||
model.compile(
|
||||
loss='mse',
|
||||
optimizer='rmsprop',
|
||||
metrics=[CustomMetric()])
|
||||
|
||||
saved_model_dir = self._save_model_dir()
|
||||
tf_save.save(model, saved_model_dir)
|
||||
with self.assertRaisesRegex(ValueError, 'custom_objects'):
|
||||
keras_load.load(saved_model_dir)
|
||||
|
||||
keras_load.load(saved_model_dir, compile=False)
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user