Roll forward of cl/316247127: Make sure compiled metrics are accessible after loading from H5 or SavedModel.

PiperOrigin-RevId: 317329395
Change-Id: I33578515f36aa0ba227e75bda52966d493a4bebb
This commit is contained in:
Katherine Wu 2020-06-19 10:12:17 -07:00 committed by TensorFlower Gardener
parent 7bd8698f34
commit 4d751f9da4
8 changed files with 79 additions and 87 deletions

View File

@ -37,7 +37,7 @@ class Container(object):
def __init__(self, output_names=None): def __init__(self, output_names=None):
self._output_names = output_names self._output_names = output_names
def _build(self, y_pred): def build(self, y_pred):
if self._output_names is None: if self._output_names is None:
# In Subclass API, output names like 'output_1' are used for # In Subclass API, output names like 'output_1' are used for
# `Metric` names. # `Metric` names.
@ -131,9 +131,9 @@ class LossesContainer(Container):
] ]
return [self._loss_metric] + per_output_metrics return [self._loss_metric] + per_output_metrics
def _build(self, y_pred): def build(self, y_pred):
"""One-time setup of loss objects.""" """One-time setup of loss objects."""
super(LossesContainer, self)._build(y_pred) super(LossesContainer, self).build(y_pred)
self._losses = self._maybe_broadcast_to_outputs(y_pred, self._losses) self._losses = self._maybe_broadcast_to_outputs(y_pred, self._losses)
self._losses = self._conform_to_outputs(y_pred, self._losses) self._losses = self._conform_to_outputs(y_pred, self._losses)
@ -184,7 +184,7 @@ class LossesContainer(Container):
sample_weight = self._conform_to_outputs(y_pred, sample_weight) sample_weight = self._conform_to_outputs(y_pred, sample_weight)
if not self._built: if not self._built:
self._build(y_pred) self.build(y_pred)
y_pred = nest.flatten(y_pred) y_pred = nest.flatten(y_pred)
y_true = nest.flatten(y_true) y_true = nest.flatten(y_true)
@ -295,9 +295,9 @@ class MetricsContainer(Container):
return [] return []
return self._metrics_in_order return self._metrics_in_order
def _build(self, y_pred, y_true): def build(self, y_pred, y_true):
"""One-time setup of metric objects.""" """One-time setup of metric objects."""
super(MetricsContainer, self)._build(y_pred) super(MetricsContainer, self).build(y_pred)
self._metrics = self._maybe_broadcast_to_outputs(y_pred, self._metrics) self._metrics = self._maybe_broadcast_to_outputs(y_pred, self._metrics)
self._metrics = self._conform_to_outputs(y_pred, self._metrics) self._metrics = self._conform_to_outputs(y_pred, self._metrics)
@ -385,7 +385,7 @@ class MetricsContainer(Container):
sample_weight = self._conform_to_outputs(y_pred, sample_weight) sample_weight = self._conform_to_outputs(y_pred, sample_weight)
if not self._built: if not self._built:
self._build(y_pred, y_true) self.build(y_pred, y_true)
y_pred = nest.flatten(y_pred) y_pred = nest.flatten(y_pred)
y_true = nest.flatten(y_true) if y_true is not None else [] y_true = nest.flatten(y_true) if y_true is not None else []

View File

@ -1007,10 +1007,12 @@ def _map_subgraph_network(inputs, outputs):
def _should_skip_first_node(layer): def _should_skip_first_node(layer):
"""Returns True if the first layer node should not be saved or loaded.""" """Returns True if the first layer node should not be saved or loaded."""
# Networks start with a pre-existing node linking their input to output. # Networks that are constructed with an Input layer/shape start with a
# For a sequential model, it is first created with _is_graph_network = False, # pre-existing node linking their input to output. This node is excluded from
# we have to keep the _is_graph_network check here. # the network config.
return isinstance(layer, Functional) and layer._is_graph_network return (isinstance(layer, Functional) and
# Filter out Sequential models without an input shape.
isinstance(layer._layers[0], input_layer_module.InputLayer))
def _deserialize_keras_tensors(kwargs, layer_map): def _deserialize_keras_tensors(kwargs, layer_map):

View File

@ -436,7 +436,6 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
'Instead, in order to instantiate and build your ' 'Instead, in order to instantiate and build your '
'model, `call` your model on real tensor data (of ' 'model, `call` your model on real tensor data (of '
'the correct dtype).') 'the correct dtype).')
super(Model, self).build(input_shape) super(Model, self).build(input_shape)
def call(self, inputs, training=None, mask=None): def call(self, inputs, training=None, mask=None):
@ -2417,6 +2416,12 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
self._saved_model_inputs_spec = specs self._saved_model_inputs_spec = specs
# Store the input shapes
if (self.__class__.__name__ == 'Sequential' and
self._build_input_shape is None):
self._build_input_shape = nest.map_structure(
lambda x: None if x is None else x.shape, specs)
def _assert_weights_created(self): def _assert_weights_created(self):
"""Asserts that all the weights for the model have been created. """Asserts that all the weights for the model have been created.

View File

@ -630,10 +630,9 @@ class MeanMetricWrapper(Mean):
def from_config(cls, config): def from_config(cls, config):
# Note that while MeanMetricWrapper itself isn't public, objects of this # Note that while MeanMetricWrapper itself isn't public, objects of this
# class may be created and added to the model by calling model.compile. # class may be created and added to the model by calling model.compile.
fn = config.pop('fn', None)
if cls is MeanMetricWrapper: if cls is MeanMetricWrapper:
fn = get(config.pop('fn')) return cls(get(fn), **config)
return cls(fn, **config)
return super(MeanMetricWrapper, cls).from_config(config) return super(MeanMetricWrapper, cls).from_config(config)

View File

@ -192,6 +192,7 @@ def load_model_from_hdf5(filepath, custom_objects=None, compile=True): # pylint
# Compile model. # Compile model.
model.compile(**saving_utils.compile_args_from_training_config( model.compile(**saving_utils.compile_args_from_training_config(
training_config, custom_objects)) training_config, custom_objects))
saving_utils.try_build_compiled_arguments(model)
# Set optimizer weights. # Set optimizer weights.
if 'optimizer_weights' in f: if 'optimizer_weights' in f:

View File

@ -26,10 +26,12 @@ from absl.testing import parameterized
import numpy as np import numpy as np
from tensorflow.python import keras from tensorflow.python import keras
from tensorflow.python import tf2
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras import combinations from tensorflow.python.keras import combinations
from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import optimizers from tensorflow.python.keras import optimizers
@ -368,48 +370,54 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
@keras_parameterized.run_with_all_saved_model_formats @keras_parameterized.run_with_all_saved_model_formats
class TestWholeModelSaving(test.TestCase, parameterized.TestCase): class TestWholeModelSaving(keras_parameterized.TestCase):
def _save_model_dir(self, dirname='saved_model'): def _save_model_dir(self, dirname='saved_model'):
temp_dir = self.get_temp_dir() temp_dir = self.get_temp_dir()
self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True) self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
return os.path.join(temp_dir, dirname) return os.path.join(temp_dir, dirname)
def _assert_same_weights(self, model, loaded_model, def _assert_same_weights_and_metrics(self, model, loaded_model):
original_optimizer_has_iterations_variable=True): """Checks that the loaded weights and metrics are the same as the original.
"""Checks that the loaded weighs are the same as the original weights.
Args: Args:
model: original model model: original model
loaded_model: loaded model loaded_model: loaded model
original_optimizer_has_iterations_variable: If the original optimizer
uses an iterations variable. The loaded model will have a v2
optimizer, which always contains an iterations variable. So when
comparing the weights, the first variable in the loaded optimizer
weights may need to be ignored.
""" """
self.assertAllClose(model.weights, loaded_model.weights) self.assertAllClose(model.weights, loaded_model.weights)
if loaded_model.optimizer: if loaded_model.optimizer:
if testing_utils.get_save_format() == 'tf': if testing_utils.get_save_format() == 'tf':
# TODO(b/153110928): Keras TF format doesn't restore optimizer weights # TODO(b/153110928): Keras TF format doesn't restore optimizer weights
# currently. # currently.
return return
if original_optimizer_has_iterations_variable: self.assertAllClose(model.optimizer.weights,
self.assertAllClose(model.optimizer.weights, loaded_model.optimizer.weights)
loaded_model.optimizer.weights)
else:
self.assertAllClose(model.optimizer.weights,
loaded_model.optimizer.weights[1:])
def test_sequential_model_saving(self): # In V1/Graph mode, the model isn't built, so the metrics are not loaded
# immediately (requires model to be called on some data before building
# metrics).
check_metrics = tf2.enabled() and context.executing_eagerly()
if check_metrics:
self.assertAllEqual([m.name for m in model.metrics],
[m.name for m in loaded_model.metrics])
@keras_parameterized.run_with_all_model_types
@keras_parameterized.run_all_keras_modes
def test_save_and_load(self):
saved_model_dir = self._save_model_dir() saved_model_dir = self._save_model_dir()
save_format = testing_utils.get_save_format() save_format = testing_utils.get_save_format()
if save_format == 'h5' and testing_utils.get_model_type() == 'subclass':
return # HDF5 format currently does not allow saving classed models.
with self.cached_session(): with self.cached_session():
model = keras.models.Sequential() model = testing_utils.get_model_from_layers(
model.add(keras.layers.Dense(2, input_shape=(3,))) [keras.layers.Dense(2),
model.add(keras.layers.RepeatVector(3)) keras.layers.RepeatVector(3),
model.add(keras.layers.TimeDistributed(keras.layers.Dense(3))) keras.layers.TimeDistributed(keras.layers.Dense(3))],
input_shape=(3,))
model.compile( model.compile(
loss=keras.losses.MSE, loss=keras.losses.MSE,
optimizer=keras.optimizer_v2.rmsprop.RMSprop(lr=0.0001), optimizer=keras.optimizer_v2.rmsprop.RMSprop(lr=0.0001),
@ -432,43 +440,35 @@ class TestWholeModelSaving(test.TestCase, parameterized.TestCase):
out = model.predict(x) out = model.predict(x)
keras.models.save_model(model, saved_model_dir, save_format=save_format) keras.models.save_model(model, saved_model_dir, save_format=save_format)
new_model = keras.models.load_model(saved_model_dir) loaded_model = keras.models.load_model(saved_model_dir)
self._assert_same_weights(model, new_model) self._assert_same_weights_and_metrics(model, loaded_model)
out2 = new_model.predict(x) out2 = loaded_model.predict(x)
self.assertAllClose(out, out2, atol=1e-05) self.assertAllClose(out, out2, atol=1e-05)
# test that new updates are the same with both models
model.train_on_batch(x, y)
new_model.train_on_batch(x, y)
eval_out = model.evaluate(x, y) eval_out = model.evaluate(x, y)
eval_out2 = new_model.evaluate(x, y) eval_out2 = loaded_model.evaluate(x, y)
self.assertArrayNear(eval_out, eval_out2, 0.001) self.assertArrayNear(eval_out, eval_out2, 0.001)
out = model.predict(x) @test_util.run_in_graph_and_eager_modes
out2 = new_model.predict(x)
# The model has been trained on two batches. So the tolerance is larger.
self.assertAllClose(out, out2, atol=0.01)
def test_sequential_model_saving_without_input_shape(self): def test_sequential_model_saving_without_input_shape(self):
saved_model_dir = self._save_model_dir() saved_model_dir = self._save_model_dir()
save_format = testing_utils.get_save_format() save_format = testing_utils.get_save_format()
with ops.Graph().as_default(), self.cached_session(): with self.cached_session():
model = keras.models.Sequential() model = keras.models.Sequential()
model.add(keras.layers.Dense(2)) model.add(keras.layers.Dense(2))
model.add(keras.layers.RepeatVector(3)) model.add(keras.layers.RepeatVector(3))
model.add(keras.layers.TimeDistributed(keras.layers.Dense(3))) model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
model.compile( model.compile(
loss=keras.losses.MSE, loss=keras.losses.MSE,
optimizer=keras.optimizers.RMSprop(lr=0.0001), optimizer='rmsprop',
metrics=[ metrics=[
keras.metrics.categorical_accuracy, keras.metrics.categorical_accuracy,
keras.metrics.CategoricalAccuracy() keras.metrics.CategoricalAccuracy(name='cat_acc')
], ],
weighted_metrics=[ weighted_metrics=[
keras.metrics.categorical_accuracy, keras.metrics.categorical_accuracy,
keras.metrics.CategoricalAccuracy() keras.metrics.CategoricalAccuracy(name='cat_acc2')
], ],
sample_weight_mode='temporal') sample_weight_mode='temporal')
x = np.random.random((1, 3)) x = np.random.random((1, 3))
@ -479,12 +479,13 @@ class TestWholeModelSaving(test.TestCase, parameterized.TestCase):
model.save(saved_model_dir, save_format=save_format) model.save(saved_model_dir, save_format=save_format)
new_model = keras.models.load_model(saved_model_dir) new_model = keras.models.load_model(saved_model_dir)
self._assert_same_weights(
model, new_model, original_optimizer_has_iterations_variable=False) self._assert_same_weights_and_metrics(model, new_model)
out2 = new_model.predict(x) out2 = new_model.predict(x)
self.assertAllClose(out, out2, atol=1e-05) self.assertAllClose(out, out2, atol=1e-05)
@test_util.run_in_graph_and_eager_modes
def test_sequential_model_saving_without_compile(self): def test_sequential_model_saving_without_compile(self):
saved_model_dir = self._save_model_dir() saved_model_dir = self._save_model_dir()
save_format = testing_utils.get_save_format() save_format = testing_utils.get_save_format()
@ -501,7 +502,7 @@ class TestWholeModelSaving(test.TestCase, parameterized.TestCase):
keras.models.save_model(model, saved_model_dir, save_format=save_format) keras.models.save_model(model, saved_model_dir, save_format=save_format)
new_model = keras.models.load_model(saved_model_dir) new_model = keras.models.load_model(saved_model_dir)
self._assert_same_weights(model, new_model) self._assert_same_weights_and_metrics(model, new_model)
out2 = new_model.predict(x) out2 = new_model.predict(x)
self.assertAllClose(out, out2, atol=1e-05) self.assertAllClose(out, out2, atol=1e-05)
@ -535,42 +536,11 @@ class TestWholeModelSaving(test.TestCase, parameterized.TestCase):
saved_model_dir, saved_model_dir,
custom_objects={'CustomOp': CustomOp, custom_objects={'CustomOp': CustomOp,
'custom_loss': custom_loss}) 'custom_loss': custom_loss})
self._assert_same_weights(model, new_model) self._assert_same_weights_and_metrics(model, new_model)
out2 = new_model.predict(x) out2 = new_model.predict(x)
self.assertAllClose(out, out2, atol=1e-05) self.assertAllClose(out, out2, atol=1e-05)
def test_functional_model_saving(self):
saved_model_dir = self._save_model_dir()
save_format = testing_utils.get_save_format()
with ops.Graph().as_default(), self.cached_session():
inputs = keras.layers.Input(shape=(3,))
x = keras.layers.Dense(2)(inputs)
output = keras.layers.Dense(3)(x)
model = keras.models.Model(inputs, output)
model.compile(
loss=keras.losses.MSE,
optimizer=keras.optimizers.RMSprop(lr=0.0001),
metrics=[
keras.metrics.categorical_accuracy,
keras.metrics.CategoricalAccuracy()
],
weighted_metrics=[
keras.metrics.categorical_accuracy,
keras.metrics.CategoricalAccuracy()
])
x = np.random.random((1, 3))
y = np.random.random((1, 3))
model.train_on_batch(x, y)
out = model.predict(x)
keras.models.save_model(model, saved_model_dir, save_format=save_format)
model = keras.models.load_model(saved_model_dir)
out2 = model.predict(x)
self.assertAllClose(out, out2, atol=1e-05)
def test_saving_without_compilation(self): def test_saving_without_compilation(self):
saved_model_dir = self._save_model_dir() saved_model_dir = self._save_model_dir()
save_format = testing_utils.get_save_format() save_format = testing_utils.get_save_format()

View File

@ -129,6 +129,7 @@ def load(path, compile=True, options=None): # pylint: disable=redefined-builtin
if training_config is not None: if training_config is not None:
model.compile(**saving_utils.compile_args_from_training_config( model.compile(**saving_utils.compile_args_from_training_config(
training_config)) training_config))
saving_utils.try_build_compiled_arguments(model)
else: else:
logging.warning('No training configuration found in save file, so the ' logging.warning('No training configuration found in save file, so the '
'model was *not* compiled. Compile it manually.') 'model was *not* compiled. Compile it manually.')

View File

@ -27,6 +27,7 @@ from tensorflow.python.keras import losses
from tensorflow.python.keras import optimizers from tensorflow.python.keras import optimizers
from tensorflow.python.keras.engine import base_layer_utils from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.utils import generic_utils from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import version_utils
from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest from tensorflow.python.util import nest
@ -307,3 +308,16 @@ def _enforce_names_consistency(specs):
if name_inconsistency: if name_inconsistency:
specs = nest.map_structure(_clear_name, specs) specs = nest.map_structure(_clear_name, specs)
return specs return specs
def try_build_compiled_arguments(model):
  """Best-effort build of a freshly compiled model's loss and metric containers.

  Called after `model.compile(...)` on a loaded model so that the compiled
  metrics can be inspected before the model is trained or evaluated.

  Nothing is attempted when `version_utils.is_v1_layer_or_model(model)` is true
  or when `model.outputs` is None. If building raises, a warning is logged and
  the error is not propagated.

  Args:
    model: A compiled model exposing `outputs`, `compiled_loss` and
      `compiled_metrics`.
  """
  if version_utils.is_v1_layer_or_model(model) or model.outputs is None:
    return

  try:
    model.compiled_loss.build(model.outputs)
    # The outputs are passed for both arguments of the metrics container's
    # `build`; no separate target structure is available at this point.
    model.compiled_metrics.build(model.outputs, model.outputs)
  except Exception:  # pylint: disable=broad-except
    # Deliberately best-effort: any ordinary failure only postpones the build.
    # `Exception` (rather than a bare `except`) lets KeyboardInterrupt and
    # SystemExit propagate.
    logging.warning(
        'Compiled the loaded model, but the compiled metrics have yet to '
        'be built. `model.compiled_metrics` will be empty until you train '
        'or evaluate the model.')