Make sure compiled metrics are accessible after loading from H5 or SavedModel.

With the updated compile and Sequential changes, compiled metrics are not added to `model.metrics` until `fit` is called. This unifies the behavior of creating and compiling a new model, since sometimes the metrics can't be generated until the inputs and outputs are known. When loading a pre-existing model, however, it makes more sense to have the metrics be immediately available. These changes build the compiled metrics right after loading the model.

This CL also modifies Sequential so that, if the `build_input_shape` is available, `from_config` always builds the model, regardless of context.

PiperOrigin-RevId: 315800928
Change-Id: Iea73bdedf3e05bd5595be892661ee163f2279c0e
This commit is contained in:
Katherine Wu 2020-06-10 17:22:45 -07:00 committed by TensorFlower Gardener
parent 1936a8120d
commit c3da1a69c3
6 changed files with 71 additions and 84 deletions

View File

@ -37,7 +37,7 @@ class Container(object):
def __init__(self, output_names=None): def __init__(self, output_names=None):
self._output_names = output_names self._output_names = output_names
def _build(self, y_pred): def build(self, y_pred):
if self._output_names is None: if self._output_names is None:
# In Subclass API, output names like 'output_1' are used for # In Subclass API, output names like 'output_1' are used for
# `Metric` names. # `Metric` names.
@ -131,9 +131,9 @@ class LossesContainer(Container):
] ]
return [self._loss_metric] + per_output_metrics return [self._loss_metric] + per_output_metrics
def _build(self, y_pred): def build(self, y_pred):
"""One-time setup of loss objects.""" """One-time setup of loss objects."""
super(LossesContainer, self)._build(y_pred) super(LossesContainer, self).build(y_pred)
self._losses = self._maybe_broadcast_to_outputs(y_pred, self._losses) self._losses = self._maybe_broadcast_to_outputs(y_pred, self._losses)
self._losses = self._conform_to_outputs(y_pred, self._losses) self._losses = self._conform_to_outputs(y_pred, self._losses)
@ -184,7 +184,7 @@ class LossesContainer(Container):
sample_weight = self._conform_to_outputs(y_pred, sample_weight) sample_weight = self._conform_to_outputs(y_pred, sample_weight)
if not self._built: if not self._built:
self._build(y_pred) self.build(y_pred)
y_pred = nest.flatten(y_pred) y_pred = nest.flatten(y_pred)
y_true = nest.flatten(y_true) y_true = nest.flatten(y_true)
@ -295,9 +295,9 @@ class MetricsContainer(Container):
return [] return []
return self._metrics_in_order return self._metrics_in_order
def _build(self, y_pred, y_true): def build(self, y_pred, y_true):
"""One-time setup of metric objects.""" """One-time setup of metric objects."""
super(MetricsContainer, self)._build(y_pred) super(MetricsContainer, self).build(y_pred)
self._metrics = self._maybe_broadcast_to_outputs(y_pred, self._metrics) self._metrics = self._maybe_broadcast_to_outputs(y_pred, self._metrics)
self._metrics = self._conform_to_outputs(y_pred, self._metrics) self._metrics = self._conform_to_outputs(y_pred, self._metrics)
@ -385,7 +385,7 @@ class MetricsContainer(Container):
sample_weight = self._conform_to_outputs(y_pred, sample_weight) sample_weight = self._conform_to_outputs(y_pred, sample_weight)
if not self._built: if not self._built:
self._build(y_pred, y_true) self.build(y_pred, y_true)
y_pred = nest.flatten(y_pred) y_pred = nest.flatten(y_pred)
y_true = nest.flatten(y_true) if y_true is not None else [] y_true = nest.flatten(y_true) if y_true is not None else []

View File

@ -1007,10 +1007,12 @@ def _map_subgraph_network(inputs, outputs):
def _should_skip_first_node(layer): def _should_skip_first_node(layer):
"""Returns True if the first layer node should not be saved or loaded.""" """Returns True if the first layer node should not be saved or loaded."""
# Networks start with a pre-existing node linking their input to output. # Networks that are constructed with an Input layer/shape start with a
# For a sequential model, it is first created with _is_graph_network = False, # pre-existing node linking their input to output. This node is excluded from
# we have to keep the _is_graph_network check here. # the network config.
return isinstance(layer, Functional) and layer._is_graph_network return (isinstance(layer, Functional) and
# Filter out Sequential models without an input shape.
isinstance(layer._layers[0], input_layer_module.InputLayer))
def _deserialize_keras_tensors(kwargs, layer_map): def _deserialize_keras_tensors(kwargs, layer_map):

View File

@ -436,7 +436,6 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
'Instead, in order to instantiate and build your ' 'Instead, in order to instantiate and build your '
'model, `call` your model on real tensor data (of ' 'model, `call` your model on real tensor data (of '
'the correct dtype).') 'the correct dtype).')
super(Model, self).build(input_shape) super(Model, self).build(input_shape)
def call(self, inputs, training=None, mask=None): def call(self, inputs, training=None, mask=None):
@ -2382,6 +2381,12 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
self._saved_model_inputs_spec = specs self._saved_model_inputs_spec = specs
# Store the input shapes
if (self.__class__.__name__ == 'Sequential' and
self._build_input_shape is None):
self._build_input_shape = nest.map_structure(
lambda x: None if x is None else x.shape, specs)
def _assert_weights_created(self): def _assert_weights_created(self):
"""Asserts that all the weights for the model have been created. """Asserts that all the weights for the model have been created.

View File

@ -30,6 +30,7 @@ from tensorflow.python.keras import optimizers
from tensorflow.python.keras.saving import model_config as model_config_lib from tensorflow.python.keras.saving import model_config as model_config_lib
from tensorflow.python.keras.saving import saving_utils from tensorflow.python.keras.saving import saving_utils
from tensorflow.python.keras.utils import conv_utils from tensorflow.python.keras.utils import conv_utils
from tensorflow.python.keras.utils import version_utils
from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
from tensorflow.python.ops import variables as variables_module from tensorflow.python.ops import variables as variables_module
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
@ -193,6 +194,10 @@ def load_model_from_hdf5(filepath, custom_objects=None, compile=True): # pylint
model.compile(**saving_utils.compile_args_from_training_config( model.compile(**saving_utils.compile_args_from_training_config(
training_config, custom_objects)) training_config, custom_objects))
if not version_utils.is_v1_layer_or_model(model):
model.compiled_metrics.build(model.outputs, model.outputs)
model.compiled_loss.build(model.outputs)
# Set optimizer weights. # Set optimizer weights.
if 'optimizer_weights' in f: if 'optimizer_weights' in f:
try: try:

View File

@ -26,10 +26,12 @@ from absl.testing import parameterized
import numpy as np import numpy as np
from tensorflow.python import keras from tensorflow.python import keras
from tensorflow.python import tf2
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras import combinations from tensorflow.python.keras import combinations
from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import optimizers from tensorflow.python.keras import optimizers
@ -368,48 +370,54 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
@keras_parameterized.run_with_all_saved_model_formats @keras_parameterized.run_with_all_saved_model_formats
class TestWholeModelSaving(test.TestCase, parameterized.TestCase): class TestWholeModelSaving(keras_parameterized.TestCase):
def _save_model_dir(self, dirname='saved_model'): def _save_model_dir(self, dirname='saved_model'):
temp_dir = self.get_temp_dir() temp_dir = self.get_temp_dir()
self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True) self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
return os.path.join(temp_dir, dirname) return os.path.join(temp_dir, dirname)
def _assert_same_weights(self, model, loaded_model, def _assert_same_weights_and_metrics(self, model, loaded_model):
original_optimizer_has_iterations_variable=True): """Checks that the loaded weights and metrics are the same as the original.
"""Checks that the loaded weighs are the same as the original weights.
Args: Args:
model: original model model: original model
loaded_model: loaded model loaded_model: loaded model
original_optimizer_has_iterations_variable: If the original optimizer
uses an iterations variable. The loaded model will have a v2
optimizer, which always contains an iterations variable. So when
comparing the weights, the first variable in the loaded optimizer
weights may need to be ignored.
""" """
self.assertAllClose(model.weights, loaded_model.weights) self.assertAllClose(model.weights, loaded_model.weights)
if loaded_model.optimizer: if loaded_model.optimizer:
if testing_utils.get_save_format() == 'tf': if testing_utils.get_save_format() == 'tf':
# TODO(b/153110928): Keras TF format doesn't restore optimizer weights # TODO(b/153110928): Keras TF format doesn't restore optimizer weights
# currently. # currently.
return return
if original_optimizer_has_iterations_variable:
self.assertAllClose(model.optimizer.weights, self.assertAllClose(model.optimizer.weights,
loaded_model.optimizer.weights) loaded_model.optimizer.weights)
else:
self.assertAllClose(model.optimizer.weights,
loaded_model.optimizer.weights[1:])
def test_sequential_model_saving(self): # In V1/Graph mode, the model isn't built, so the metrics are not loaded
# immediately (requires model to be called on some data before building
# metrics).
check_metrics = tf2.enabled and context.executing_eagerly()
if check_metrics:
self.assertAllEqual([m.name for m in model.metrics],
[m.name for m in loaded_model.metrics])
@keras_parameterized.run_with_all_model_types
@keras_parameterized.run_all_keras_modes
def test_save_and_load(self):
saved_model_dir = self._save_model_dir() saved_model_dir = self._save_model_dir()
save_format = testing_utils.get_save_format() save_format = testing_utils.get_save_format()
if save_format == 'h5' and testing_utils.get_model_type() == 'subclass':
return # HDF5 format currently does not allow saving classed models.
with self.cached_session(): with self.cached_session():
model = keras.models.Sequential() model = testing_utils.get_model_from_layers(
model.add(keras.layers.Dense(2, input_shape=(3,))) [keras.layers.Dense(2),
model.add(keras.layers.RepeatVector(3)) keras.layers.RepeatVector(3),
model.add(keras.layers.TimeDistributed(keras.layers.Dense(3))) keras.layers.TimeDistributed(keras.layers.Dense(3))],
input_shape=(3,))
model.compile( model.compile(
loss=keras.losses.MSE, loss=keras.losses.MSE,
optimizer=keras.optimizer_v2.rmsprop.RMSprop(lr=0.0001), optimizer=keras.optimizer_v2.rmsprop.RMSprop(lr=0.0001),
@ -432,43 +440,35 @@ class TestWholeModelSaving(test.TestCase, parameterized.TestCase):
out = model.predict(x) out = model.predict(x)
keras.models.save_model(model, saved_model_dir, save_format=save_format) keras.models.save_model(model, saved_model_dir, save_format=save_format)
new_model = keras.models.load_model(saved_model_dir) loaded_model = keras.models.load_model(saved_model_dir)
self._assert_same_weights(model, new_model) self._assert_same_weights_and_metrics(model, loaded_model)
out2 = new_model.predict(x) out2 = loaded_model.predict(x)
self.assertAllClose(out, out2, atol=1e-05) self.assertAllClose(out, out2, atol=1e-05)
# test that new updates are the same with both models
model.train_on_batch(x, y)
new_model.train_on_batch(x, y)
eval_out = model.evaluate(x, y) eval_out = model.evaluate(x, y)
eval_out2 = new_model.evaluate(x, y) eval_out2 = loaded_model.evaluate(x, y)
self.assertArrayNear(eval_out, eval_out2, 0.001) self.assertArrayNear(eval_out, eval_out2, 0.001)
out = model.predict(x) @test_util.run_in_graph_and_eager_modes
out2 = new_model.predict(x)
# The model has been trained on two batches. So the tolerance is larger.
self.assertAllClose(out, out2, atol=0.01)
def test_sequential_model_saving_without_input_shape(self): def test_sequential_model_saving_without_input_shape(self):
saved_model_dir = self._save_model_dir() saved_model_dir = self._save_model_dir()
save_format = testing_utils.get_save_format() save_format = testing_utils.get_save_format()
with ops.Graph().as_default(), self.cached_session(): with self.cached_session():
model = keras.models.Sequential() model = keras.models.Sequential()
model.add(keras.layers.Dense(2)) model.add(keras.layers.Dense(2))
model.add(keras.layers.RepeatVector(3)) model.add(keras.layers.RepeatVector(3))
model.add(keras.layers.TimeDistributed(keras.layers.Dense(3))) model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
model.compile( model.compile(
loss=keras.losses.MSE, loss=keras.losses.MSE,
optimizer=keras.optimizers.RMSprop(lr=0.0001), optimizer='rmsprop',
metrics=[ metrics=[
keras.metrics.categorical_accuracy, keras.metrics.categorical_accuracy,
keras.metrics.CategoricalAccuracy() keras.metrics.CategoricalAccuracy(name='cat_acc')
], ],
weighted_metrics=[ weighted_metrics=[
keras.metrics.categorical_accuracy, keras.metrics.categorical_accuracy,
keras.metrics.CategoricalAccuracy() keras.metrics.CategoricalAccuracy(name='cat_acc2')
], ],
sample_weight_mode='temporal') sample_weight_mode='temporal')
x = np.random.random((1, 3)) x = np.random.random((1, 3))
@ -479,12 +479,13 @@ class TestWholeModelSaving(test.TestCase, parameterized.TestCase):
model.save(saved_model_dir, save_format=save_format) model.save(saved_model_dir, save_format=save_format)
new_model = keras.models.load_model(saved_model_dir) new_model = keras.models.load_model(saved_model_dir)
self._assert_same_weights(
model, new_model, original_optimizer_has_iterations_variable=False) self._assert_same_weights_and_metrics(model, new_model)
out2 = new_model.predict(x) out2 = new_model.predict(x)
self.assertAllClose(out, out2, atol=1e-05) self.assertAllClose(out, out2, atol=1e-05)
@test_util.run_in_graph_and_eager_modes
def test_sequential_model_saving_without_compile(self): def test_sequential_model_saving_without_compile(self):
saved_model_dir = self._save_model_dir() saved_model_dir = self._save_model_dir()
save_format = testing_utils.get_save_format() save_format = testing_utils.get_save_format()
@ -501,7 +502,7 @@ class TestWholeModelSaving(test.TestCase, parameterized.TestCase):
keras.models.save_model(model, saved_model_dir, save_format=save_format) keras.models.save_model(model, saved_model_dir, save_format=save_format)
new_model = keras.models.load_model(saved_model_dir) new_model = keras.models.load_model(saved_model_dir)
self._assert_same_weights(model, new_model) self._assert_same_weights_and_metrics(model, new_model)
out2 = new_model.predict(x) out2 = new_model.predict(x)
self.assertAllClose(out, out2, atol=1e-05) self.assertAllClose(out, out2, atol=1e-05)
@ -535,42 +536,11 @@ class TestWholeModelSaving(test.TestCase, parameterized.TestCase):
saved_model_dir, saved_model_dir,
custom_objects={'CustomOp': CustomOp, custom_objects={'CustomOp': CustomOp,
'custom_loss': custom_loss}) 'custom_loss': custom_loss})
self._assert_same_weights(model, new_model) self._assert_same_weights_and_metrics(model, new_model)
out2 = new_model.predict(x) out2 = new_model.predict(x)
self.assertAllClose(out, out2, atol=1e-05) self.assertAllClose(out, out2, atol=1e-05)
def test_functional_model_saving(self):
saved_model_dir = self._save_model_dir()
save_format = testing_utils.get_save_format()
with ops.Graph().as_default(), self.cached_session():
inputs = keras.layers.Input(shape=(3,))
x = keras.layers.Dense(2)(inputs)
output = keras.layers.Dense(3)(x)
model = keras.models.Model(inputs, output)
model.compile(
loss=keras.losses.MSE,
optimizer=keras.optimizers.RMSprop(lr=0.0001),
metrics=[
keras.metrics.categorical_accuracy,
keras.metrics.CategoricalAccuracy()
],
weighted_metrics=[
keras.metrics.categorical_accuracy,
keras.metrics.CategoricalAccuracy()
])
x = np.random.random((1, 3))
y = np.random.random((1, 3))
model.train_on_batch(x, y)
out = model.predict(x)
keras.models.save_model(model, saved_model_dir, save_format=save_format)
model = keras.models.load_model(saved_model_dir)
out2 = model.predict(x)
self.assertAllClose(out, out2, atol=1e-05)
def test_saving_without_compilation(self): def test_saving_without_compilation(self):
saved_model_dir = self._save_model_dir() saved_model_dir = self._save_model_dir()
save_format = testing_utils.get_save_format() save_format = testing_utils.get_save_format()

View File

@ -34,6 +34,7 @@ from tensorflow.python.keras.saving.saved_model import utils
from tensorflow.python.keras.saving.saved_model.serialized_attributes import CommonEndpoints from tensorflow.python.keras.saving.saved_model.serialized_attributes import CommonEndpoints
from tensorflow.python.keras.utils import generic_utils from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import metrics_utils from tensorflow.python.keras.utils import metrics_utils
from tensorflow.python.keras.utils import version_utils
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import load as tf_load from tensorflow.python.saved_model import load as tf_load
from tensorflow.python.saved_model import nested_structure_coder from tensorflow.python.saved_model import nested_structure_coder
@ -124,6 +125,10 @@ def load(path, compile=True): # pylint: disable=redefined-builtin
if training_config is not None: if training_config is not None:
model.compile(**saving_utils.compile_args_from_training_config( model.compile(**saving_utils.compile_args_from_training_config(
training_config)) training_config))
if (not version_utils.is_v1_layer_or_model(model) and
model.outputs is not None):
model.compiled_metrics.build(model.outputs, model.outputs)
model.compiled_loss.build(model.outputs)
else: else:
logging.warning('No training configuration found in save file, so the ' logging.warning('No training configuration found in save file, so the '
'model was *not* compiled. Compile it manually.') 'model was *not* compiled. Compile it manually.')