Roll forward of cl/316247127: Make sure compiled metrics are accessible after loading from H5 or SavedModel.
PiperOrigin-RevId: 317329395 Change-Id: I33578515f36aa0ba227e75bda52966d493a4bebb
This commit is contained in:
parent
7bd8698f34
commit
4d751f9da4
|
@ -37,7 +37,7 @@ class Container(object):
|
|||
def __init__(self, output_names=None):
|
||||
self._output_names = output_names
|
||||
|
||||
def _build(self, y_pred):
|
||||
def build(self, y_pred):
|
||||
if self._output_names is None:
|
||||
# In Subclass API, output names like 'output_1' are used for
|
||||
# `Metric` names.
|
||||
|
@ -131,9 +131,9 @@ class LossesContainer(Container):
|
|||
]
|
||||
return [self._loss_metric] + per_output_metrics
|
||||
|
||||
def _build(self, y_pred):
|
||||
def build(self, y_pred):
|
||||
"""One-time setup of loss objects."""
|
||||
super(LossesContainer, self)._build(y_pred)
|
||||
super(LossesContainer, self).build(y_pred)
|
||||
|
||||
self._losses = self._maybe_broadcast_to_outputs(y_pred, self._losses)
|
||||
self._losses = self._conform_to_outputs(y_pred, self._losses)
|
||||
|
@ -184,7 +184,7 @@ class LossesContainer(Container):
|
|||
sample_weight = self._conform_to_outputs(y_pred, sample_weight)
|
||||
|
||||
if not self._built:
|
||||
self._build(y_pred)
|
||||
self.build(y_pred)
|
||||
|
||||
y_pred = nest.flatten(y_pred)
|
||||
y_true = nest.flatten(y_true)
|
||||
|
@ -295,9 +295,9 @@ class MetricsContainer(Container):
|
|||
return []
|
||||
return self._metrics_in_order
|
||||
|
||||
def _build(self, y_pred, y_true):
|
||||
def build(self, y_pred, y_true):
|
||||
"""One-time setup of metric objects."""
|
||||
super(MetricsContainer, self)._build(y_pred)
|
||||
super(MetricsContainer, self).build(y_pred)
|
||||
|
||||
self._metrics = self._maybe_broadcast_to_outputs(y_pred, self._metrics)
|
||||
self._metrics = self._conform_to_outputs(y_pred, self._metrics)
|
||||
|
@ -385,7 +385,7 @@ class MetricsContainer(Container):
|
|||
sample_weight = self._conform_to_outputs(y_pred, sample_weight)
|
||||
|
||||
if not self._built:
|
||||
self._build(y_pred, y_true)
|
||||
self.build(y_pred, y_true)
|
||||
|
||||
y_pred = nest.flatten(y_pred)
|
||||
y_true = nest.flatten(y_true) if y_true is not None else []
|
||||
|
|
|
@ -1007,10 +1007,12 @@ def _map_subgraph_network(inputs, outputs):
|
|||
|
||||
def _should_skip_first_node(layer):
|
||||
"""Returns True if the first layer node should not be saved or loaded."""
|
||||
# Networks start with a pre-existing node linking their input to output.
|
||||
# For a sequential model, it is first created with _is_graph_network = False,
|
||||
# we have to keep the _is_graph_network check here.
|
||||
return isinstance(layer, Functional) and layer._is_graph_network
|
||||
# Networks that are constructed with an Input layer/shape start with a
|
||||
# pre-existing node linking their input to output. This node is excluded from
|
||||
# the network config.
|
||||
return (isinstance(layer, Functional) and
|
||||
# Filter out Sequential models without an input shape.
|
||||
isinstance(layer._layers[0], input_layer_module.InputLayer))
|
||||
|
||||
|
||||
def _deserialize_keras_tensors(kwargs, layer_map):
|
||||
|
|
|
@ -436,7 +436,6 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
|
|||
'Instead, in order to instantiate and build your '
|
||||
'model, `call` your model on real tensor data (of '
|
||||
'the correct dtype).')
|
||||
|
||||
super(Model, self).build(input_shape)
|
||||
|
||||
def call(self, inputs, training=None, mask=None):
|
||||
|
@ -2417,6 +2416,12 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
|
|||
|
||||
self._saved_model_inputs_spec = specs
|
||||
|
||||
# Store the input shapes
|
||||
if (self.__class__.__name__ == 'Sequential' and
|
||||
self._build_input_shape is None):
|
||||
self._build_input_shape = nest.map_structure(
|
||||
lambda x: None if x is None else x.shape, specs)
|
||||
|
||||
def _assert_weights_created(self):
|
||||
"""Asserts that all the weights for the model have been created.
|
||||
|
||||
|
|
|
@ -630,10 +630,9 @@ class MeanMetricWrapper(Mean):
|
|||
def from_config(cls, config):
|
||||
# Note that while MeanMetricWrapper itself isn't public, objects of this
|
||||
# class may be created and added to the model by calling model.compile.
|
||||
fn = config.pop('fn', None)
|
||||
if cls is MeanMetricWrapper:
|
||||
fn = get(config.pop('fn'))
|
||||
return cls(fn, **config)
|
||||
|
||||
return cls(get(fn), **config)
|
||||
return super(MeanMetricWrapper, cls).from_config(config)
|
||||
|
||||
|
||||
|
|
|
@ -192,6 +192,7 @@ def load_model_from_hdf5(filepath, custom_objects=None, compile=True): # pylint
|
|||
# Compile model.
|
||||
model.compile(**saving_utils.compile_args_from_training_config(
|
||||
training_config, custom_objects))
|
||||
saving_utils.try_build_compiled_arguments(model)
|
||||
|
||||
# Set optimizer weights.
|
||||
if 'optimizer_weights' in f:
|
||||
|
|
|
@ -26,10 +26,12 @@ from absl.testing import parameterized
|
|||
import numpy as np
|
||||
|
||||
from tensorflow.python import keras
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.keras import combinations
|
||||
from tensorflow.python.keras import keras_parameterized
|
||||
from tensorflow.python.keras import optimizers
|
||||
|
@ -368,48 +370,54 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
|
|||
|
||||
|
||||
@keras_parameterized.run_with_all_saved_model_formats
|
||||
class TestWholeModelSaving(test.TestCase, parameterized.TestCase):
|
||||
class TestWholeModelSaving(keras_parameterized.TestCase):
|
||||
|
||||
def _save_model_dir(self, dirname='saved_model'):
|
||||
temp_dir = self.get_temp_dir()
|
||||
self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
|
||||
return os.path.join(temp_dir, dirname)
|
||||
|
||||
def _assert_same_weights(self, model, loaded_model,
|
||||
original_optimizer_has_iterations_variable=True):
|
||||
"""Checks that the loaded weights are the same as the original weights.
|
||||
def _assert_same_weights_and_metrics(self, model, loaded_model):
|
||||
"""Checks that the loaded weights and metrics are the same as the original.
|
||||
|
||||
Args:
|
||||
model: original model
|
||||
loaded_model: loaded model
|
||||
original_optimizer_has_iterations_variable: If the original optimizer
|
||||
uses an iterations variable. The loaded model will have a v2
|
||||
optimizer, which always contains an iterations variable. So when
|
||||
comparing the weights, the first variable in the loaded optimizer
|
||||
weights may need to be ignored.
|
||||
"""
|
||||
self.assertAllClose(model.weights, loaded_model.weights)
|
||||
|
||||
if loaded_model.optimizer:
|
||||
if testing_utils.get_save_format() == 'tf':
|
||||
# TODO(b/153110928): Keras TF format doesn't restore optimizer weights
|
||||
# currently.
|
||||
return
|
||||
if original_optimizer_has_iterations_variable:
|
||||
self.assertAllClose(model.optimizer.weights,
|
||||
loaded_model.optimizer.weights)
|
||||
else:
|
||||
self.assertAllClose(model.optimizer.weights,
|
||||
loaded_model.optimizer.weights[1:])
|
||||
self.assertAllClose(model.optimizer.weights,
|
||||
loaded_model.optimizer.weights)
|
||||
|
||||
def test_sequential_model_saving(self):
|
||||
# In V1/Graph mode, the model isn't built, so the metrics are not loaded
|
||||
# immediately (requires model to be called on some data before building
|
||||
# metrics).
|
||||
check_metrics = tf2.enabled() and context.executing_eagerly()
|
||||
|
||||
if check_metrics:
|
||||
self.assertAllEqual([m.name for m in model.metrics],
|
||||
[m.name for m in loaded_model.metrics])
|
||||
|
||||
@keras_parameterized.run_with_all_model_types
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
def test_save_and_load(self):
|
||||
saved_model_dir = self._save_model_dir()
|
||||
save_format = testing_utils.get_save_format()
|
||||
|
||||
if save_format == 'h5' and testing_utils.get_model_type() == 'subclass':
|
||||
return # HDF5 format currently does not allow saving subclassed models.
|
||||
|
||||
with self.cached_session():
|
||||
model = keras.models.Sequential()
|
||||
model.add(keras.layers.Dense(2, input_shape=(3,)))
|
||||
model.add(keras.layers.RepeatVector(3))
|
||||
model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
|
||||
model = testing_utils.get_model_from_layers(
|
||||
[keras.layers.Dense(2),
|
||||
keras.layers.RepeatVector(3),
|
||||
keras.layers.TimeDistributed(keras.layers.Dense(3))],
|
||||
input_shape=(3,))
|
||||
model.compile(
|
||||
loss=keras.losses.MSE,
|
||||
optimizer=keras.optimizer_v2.rmsprop.RMSprop(lr=0.0001),
|
||||
|
@ -432,43 +440,35 @@ class TestWholeModelSaving(test.TestCase, parameterized.TestCase):
|
|||
out = model.predict(x)
|
||||
keras.models.save_model(model, saved_model_dir, save_format=save_format)
|
||||
|
||||
new_model = keras.models.load_model(saved_model_dir)
|
||||
self._assert_same_weights(model, new_model)
|
||||
loaded_model = keras.models.load_model(saved_model_dir)
|
||||
self._assert_same_weights_and_metrics(model, loaded_model)
|
||||
|
||||
out2 = new_model.predict(x)
|
||||
out2 = loaded_model.predict(x)
|
||||
self.assertAllClose(out, out2, atol=1e-05)
|
||||
|
||||
# test that new updates are the same with both models
|
||||
model.train_on_batch(x, y)
|
||||
new_model.train_on_batch(x, y)
|
||||
|
||||
eval_out = model.evaluate(x, y)
|
||||
eval_out2 = new_model.evaluate(x, y)
|
||||
eval_out2 = loaded_model.evaluate(x, y)
|
||||
self.assertArrayNear(eval_out, eval_out2, 0.001)
|
||||
|
||||
out = model.predict(x)
|
||||
out2 = new_model.predict(x)
|
||||
# The model has been trained on two batches. So the tolerance is larger.
|
||||
self.assertAllClose(out, out2, atol=0.01)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_sequential_model_saving_without_input_shape(self):
|
||||
saved_model_dir = self._save_model_dir()
|
||||
save_format = testing_utils.get_save_format()
|
||||
with ops.Graph().as_default(), self.cached_session():
|
||||
with self.cached_session():
|
||||
model = keras.models.Sequential()
|
||||
model.add(keras.layers.Dense(2))
|
||||
model.add(keras.layers.RepeatVector(3))
|
||||
model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
|
||||
model.compile(
|
||||
loss=keras.losses.MSE,
|
||||
optimizer=keras.optimizers.RMSprop(lr=0.0001),
|
||||
optimizer='rmsprop',
|
||||
metrics=[
|
||||
keras.metrics.categorical_accuracy,
|
||||
keras.metrics.CategoricalAccuracy()
|
||||
keras.metrics.CategoricalAccuracy(name='cat_acc')
|
||||
],
|
||||
weighted_metrics=[
|
||||
keras.metrics.categorical_accuracy,
|
||||
keras.metrics.CategoricalAccuracy()
|
||||
keras.metrics.CategoricalAccuracy(name='cat_acc2')
|
||||
],
|
||||
sample_weight_mode='temporal')
|
||||
x = np.random.random((1, 3))
|
||||
|
@ -479,12 +479,13 @@ class TestWholeModelSaving(test.TestCase, parameterized.TestCase):
|
|||
model.save(saved_model_dir, save_format=save_format)
|
||||
|
||||
new_model = keras.models.load_model(saved_model_dir)
|
||||
self._assert_same_weights(
|
||||
model, new_model, original_optimizer_has_iterations_variable=False)
|
||||
|
||||
self._assert_same_weights_and_metrics(model, new_model)
|
||||
|
||||
out2 = new_model.predict(x)
|
||||
self.assertAllClose(out, out2, atol=1e-05)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_sequential_model_saving_without_compile(self):
|
||||
saved_model_dir = self._save_model_dir()
|
||||
save_format = testing_utils.get_save_format()
|
||||
|
@ -501,7 +502,7 @@ class TestWholeModelSaving(test.TestCase, parameterized.TestCase):
|
|||
keras.models.save_model(model, saved_model_dir, save_format=save_format)
|
||||
|
||||
new_model = keras.models.load_model(saved_model_dir)
|
||||
self._assert_same_weights(model, new_model)
|
||||
self._assert_same_weights_and_metrics(model, new_model)
|
||||
|
||||
out2 = new_model.predict(x)
|
||||
self.assertAllClose(out, out2, atol=1e-05)
|
||||
|
@ -535,42 +536,11 @@ class TestWholeModelSaving(test.TestCase, parameterized.TestCase):
|
|||
saved_model_dir,
|
||||
custom_objects={'CustomOp': CustomOp,
|
||||
'custom_loss': custom_loss})
|
||||
self._assert_same_weights(model, new_model)
|
||||
self._assert_same_weights_and_metrics(model, new_model)
|
||||
|
||||
out2 = new_model.predict(x)
|
||||
self.assertAllClose(out, out2, atol=1e-05)
|
||||
|
||||
def test_functional_model_saving(self):
|
||||
saved_model_dir = self._save_model_dir()
|
||||
save_format = testing_utils.get_save_format()
|
||||
with ops.Graph().as_default(), self.cached_session():
|
||||
inputs = keras.layers.Input(shape=(3,))
|
||||
x = keras.layers.Dense(2)(inputs)
|
||||
output = keras.layers.Dense(3)(x)
|
||||
|
||||
model = keras.models.Model(inputs, output)
|
||||
model.compile(
|
||||
loss=keras.losses.MSE,
|
||||
optimizer=keras.optimizers.RMSprop(lr=0.0001),
|
||||
metrics=[
|
||||
keras.metrics.categorical_accuracy,
|
||||
keras.metrics.CategoricalAccuracy()
|
||||
],
|
||||
weighted_metrics=[
|
||||
keras.metrics.categorical_accuracy,
|
||||
keras.metrics.CategoricalAccuracy()
|
||||
])
|
||||
x = np.random.random((1, 3))
|
||||
y = np.random.random((1, 3))
|
||||
model.train_on_batch(x, y)
|
||||
|
||||
out = model.predict(x)
|
||||
keras.models.save_model(model, saved_model_dir, save_format=save_format)
|
||||
model = keras.models.load_model(saved_model_dir)
|
||||
|
||||
out2 = model.predict(x)
|
||||
self.assertAllClose(out, out2, atol=1e-05)
|
||||
|
||||
def test_saving_without_compilation(self):
|
||||
saved_model_dir = self._save_model_dir()
|
||||
save_format = testing_utils.get_save_format()
|
||||
|
|
|
@ -129,6 +129,7 @@ def load(path, compile=True, options=None): # pylint: disable=redefined-builtin
|
|||
if training_config is not None:
|
||||
model.compile(**saving_utils.compile_args_from_training_config(
|
||||
training_config))
|
||||
saving_utils.try_build_compiled_arguments(model)
|
||||
else:
|
||||
logging.warning('No training configuration found in save file, so the '
|
||||
'model was *not* compiled. Compile it manually.')
|
||||
|
|
|
@ -27,6 +27,7 @@ from tensorflow.python.keras import losses
|
|||
from tensorflow.python.keras import optimizers
|
||||
from tensorflow.python.keras.engine import base_layer_utils
|
||||
from tensorflow.python.keras.utils import generic_utils
|
||||
from tensorflow.python.keras.utils import version_utils
|
||||
from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.util import nest
|
||||
|
@ -307,3 +308,16 @@ def _enforce_names_consistency(specs):
|
|||
if name_inconsistency:
|
||||
specs = nest.map_structure(_clear_name, specs)
|
||||
return specs
|
||||
|
||||
|
||||
def try_build_compiled_arguments(model):
  """Best-effort build of a model's compiled loss and metrics containers.

  Calls `build` on `model.compiled_loss` and `model.compiled_metrics` using
  `model.outputs` for both the predictions and the targets. Nothing is done
  for V1 layers/models or when `model.outputs` is None.

  Building is deliberately non-fatal: if it raises, a warning is logged and
  the model is returned to the caller with its compiled metrics unbuilt.

  Args:
    model: The model whose compiled loss/metrics should be built. It is
      expected to have been compiled already (callers invoke this right
      after `model.compile`).
  """
  if version_utils.is_v1_layer_or_model(model) or model.outputs is None:
    return

  try:
    model.compiled_loss.build(model.outputs)
    model.compiled_metrics.build(model.outputs, model.outputs)
  # Catch `Exception` rather than using a bare `except:` so that
  # KeyboardInterrupt / SystemExit still propagate.
  except Exception:  # pylint: disable=broad-except
    logging.warning(
        'Compiled the loaded model, but the compiled metrics have yet to '
        'be built. `model.compiled_metrics` will be empty until you train '
        'or evaluate the model.')
|
||||
|
|
Loading…
Reference in New Issue