Roll forward of cl/316247127: Make sure compiled metrics are accessible after loading from H5 or SavedModel.

PiperOrigin-RevId: 317329395
Change-Id: I33578515f36aa0ba227e75bda52966d493a4bebb
This commit is contained in:
Katherine Wu 2020-06-19 10:12:17 -07:00 committed by TensorFlower Gardener
parent 7bd8698f34
commit 4d751f9da4
8 changed files with 79 additions and 87 deletions

View File

@ -37,7 +37,7 @@ class Container(object):
def __init__(self, output_names=None):
self._output_names = output_names
def _build(self, y_pred):
def build(self, y_pred):
if self._output_names is None:
# In Subclass API, output names like 'output_1' are used for
# `Metric` names.
@ -131,9 +131,9 @@ class LossesContainer(Container):
]
return [self._loss_metric] + per_output_metrics
def _build(self, y_pred):
def build(self, y_pred):
"""One-time setup of loss objects."""
super(LossesContainer, self)._build(y_pred)
super(LossesContainer, self).build(y_pred)
self._losses = self._maybe_broadcast_to_outputs(y_pred, self._losses)
self._losses = self._conform_to_outputs(y_pred, self._losses)
@ -184,7 +184,7 @@ class LossesContainer(Container):
sample_weight = self._conform_to_outputs(y_pred, sample_weight)
if not self._built:
self._build(y_pred)
self.build(y_pred)
y_pred = nest.flatten(y_pred)
y_true = nest.flatten(y_true)
@ -295,9 +295,9 @@ class MetricsContainer(Container):
return []
return self._metrics_in_order
def _build(self, y_pred, y_true):
def build(self, y_pred, y_true):
"""One-time setup of metric objects."""
super(MetricsContainer, self)._build(y_pred)
super(MetricsContainer, self).build(y_pred)
self._metrics = self._maybe_broadcast_to_outputs(y_pred, self._metrics)
self._metrics = self._conform_to_outputs(y_pred, self._metrics)
@ -385,7 +385,7 @@ class MetricsContainer(Container):
sample_weight = self._conform_to_outputs(y_pred, sample_weight)
if not self._built:
self._build(y_pred, y_true)
self.build(y_pred, y_true)
y_pred = nest.flatten(y_pred)
y_true = nest.flatten(y_true) if y_true is not None else []

View File

@ -1007,10 +1007,12 @@ def _map_subgraph_network(inputs, outputs):
def _should_skip_first_node(layer):
"""Returns True if the first layer node should not be saved or loaded."""
# Networks start with a pre-existing node linking their input to output.
# For a sequential model, it is first created with _is_graph_network = False,
# we have to keep the _is_graph_network check here.
return isinstance(layer, Functional) and layer._is_graph_network
# Networks that are constructed with an Input layer/shape start with a
# pre-existing node linking their input to output. This node is excluded from
# the network config.
return (isinstance(layer, Functional) and
# Filter out Sequential models without an input shape.
isinstance(layer._layers[0], input_layer_module.InputLayer))
def _deserialize_keras_tensors(kwargs, layer_map):

View File

@ -436,7 +436,6 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
'Instead, in order to instantiate and build your '
'model, `call` your model on real tensor data (of '
'the correct dtype).')
super(Model, self).build(input_shape)
def call(self, inputs, training=None, mask=None):
@ -2417,6 +2416,12 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
self._saved_model_inputs_spec = specs
# Store the input shapes
if (self.__class__.__name__ == 'Sequential' and
self._build_input_shape is None):
self._build_input_shape = nest.map_structure(
lambda x: None if x is None else x.shape, specs)
def _assert_weights_created(self):
"""Asserts that all the weights for the model have been created.

View File

@ -630,10 +630,9 @@ class MeanMetricWrapper(Mean):
def from_config(cls, config):
# Note that while MeanMetricWrapper itself isn't public, objects of this
# class may be created and added to the model by calling model.compile.
fn = config.pop('fn', None)
if cls is MeanMetricWrapper:
fn = get(config.pop('fn'))
return cls(fn, **config)
return cls(get(fn), **config)
return super(MeanMetricWrapper, cls).from_config(config)

View File

@ -192,6 +192,7 @@ def load_model_from_hdf5(filepath, custom_objects=None, compile=True): # pylint
# Compile model.
model.compile(**saving_utils.compile_args_from_training_config(
training_config, custom_objects))
saving_utils.try_build_compiled_arguments(model)
# Set optimizer weights.
if 'optimizer_weights' in f:

View File

@ -26,10 +26,12 @@ from absl.testing import parameterized
import numpy as np
from tensorflow.python import keras
from tensorflow.python import tf2
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras import combinations
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import optimizers
@ -368,48 +370,54 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
@keras_parameterized.run_with_all_saved_model_formats
class TestWholeModelSaving(test.TestCase, parameterized.TestCase):
class TestWholeModelSaving(keras_parameterized.TestCase):
def _save_model_dir(self, dirname='saved_model'):
temp_dir = self.get_temp_dir()
self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
return os.path.join(temp_dir, dirname)
def _assert_same_weights(self, model, loaded_model,
original_optimizer_has_iterations_variable=True):
"""Checks that the loaded weighs are the same as the original weights.
def _assert_same_weights_and_metrics(self, model, loaded_model):
"""Checks that the loaded weights and metrics are the same as the original.
Args:
model: original model
loaded_model: loaded model
original_optimizer_has_iterations_variable: If the original optimizer
uses an iterations variable. The loaded model will have a v2
optimizer, which always contains an iterations variable. So when
comparing the weights, the first variable in the loaded optimizer
weights may need to be ignored.
"""
self.assertAllClose(model.weights, loaded_model.weights)
if loaded_model.optimizer:
if testing_utils.get_save_format() == 'tf':
# TODO(b/153110928): Keras TF format doesn't restore optimizer weights
# currently.
return
if original_optimizer_has_iterations_variable:
self.assertAllClose(model.optimizer.weights,
loaded_model.optimizer.weights)
else:
self.assertAllClose(model.optimizer.weights,
loaded_model.optimizer.weights[1:])
self.assertAllClose(model.optimizer.weights,
loaded_model.optimizer.weights)
def test_sequential_model_saving(self):
# In V1/Graph mode, the model isn't built, so the metrics are not loaded
# immediately (requires model to be called on some data before building
# metrics).
check_metrics = tf2.enabled() and context.executing_eagerly()
if check_metrics:
self.assertAllEqual([m.name for m in model.metrics],
[m.name for m in loaded_model.metrics])
@keras_parameterized.run_with_all_model_types
@keras_parameterized.run_all_keras_modes
def test_save_and_load(self):
saved_model_dir = self._save_model_dir()
save_format = testing_utils.get_save_format()
if save_format == 'h5' and testing_utils.get_model_type() == 'subclass':
return # HDF5 format currently does not allow saving classed models.
with self.cached_session():
model = keras.models.Sequential()
model.add(keras.layers.Dense(2, input_shape=(3,)))
model.add(keras.layers.RepeatVector(3))
model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
model = testing_utils.get_model_from_layers(
[keras.layers.Dense(2),
keras.layers.RepeatVector(3),
keras.layers.TimeDistributed(keras.layers.Dense(3))],
input_shape=(3,))
model.compile(
loss=keras.losses.MSE,
optimizer=keras.optimizer_v2.rmsprop.RMSprop(lr=0.0001),
@ -432,43 +440,35 @@ class TestWholeModelSaving(test.TestCase, parameterized.TestCase):
out = model.predict(x)
keras.models.save_model(model, saved_model_dir, save_format=save_format)
new_model = keras.models.load_model(saved_model_dir)
self._assert_same_weights(model, new_model)
loaded_model = keras.models.load_model(saved_model_dir)
self._assert_same_weights_and_metrics(model, loaded_model)
out2 = new_model.predict(x)
out2 = loaded_model.predict(x)
self.assertAllClose(out, out2, atol=1e-05)
# test that new updates are the same with both models
model.train_on_batch(x, y)
new_model.train_on_batch(x, y)
eval_out = model.evaluate(x, y)
eval_out2 = new_model.evaluate(x, y)
eval_out2 = loaded_model.evaluate(x, y)
self.assertArrayNear(eval_out, eval_out2, 0.001)
out = model.predict(x)
out2 = new_model.predict(x)
# The model has been trained on two batches. So the tolerance is larger.
self.assertAllClose(out, out2, atol=0.01)
@test_util.run_in_graph_and_eager_modes
def test_sequential_model_saving_without_input_shape(self):
saved_model_dir = self._save_model_dir()
save_format = testing_utils.get_save_format()
with ops.Graph().as_default(), self.cached_session():
with self.cached_session():
model = keras.models.Sequential()
model.add(keras.layers.Dense(2))
model.add(keras.layers.RepeatVector(3))
model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
model.compile(
loss=keras.losses.MSE,
optimizer=keras.optimizers.RMSprop(lr=0.0001),
optimizer='rmsprop',
metrics=[
keras.metrics.categorical_accuracy,
keras.metrics.CategoricalAccuracy()
keras.metrics.CategoricalAccuracy(name='cat_acc')
],
weighted_metrics=[
keras.metrics.categorical_accuracy,
keras.metrics.CategoricalAccuracy()
keras.metrics.CategoricalAccuracy(name='cat_acc2')
],
sample_weight_mode='temporal')
x = np.random.random((1, 3))
@ -479,12 +479,13 @@ class TestWholeModelSaving(test.TestCase, parameterized.TestCase):
model.save(saved_model_dir, save_format=save_format)
new_model = keras.models.load_model(saved_model_dir)
self._assert_same_weights(
model, new_model, original_optimizer_has_iterations_variable=False)
self._assert_same_weights_and_metrics(model, new_model)
out2 = new_model.predict(x)
self.assertAllClose(out, out2, atol=1e-05)
@test_util.run_in_graph_and_eager_modes
def test_sequential_model_saving_without_compile(self):
saved_model_dir = self._save_model_dir()
save_format = testing_utils.get_save_format()
@ -501,7 +502,7 @@ class TestWholeModelSaving(test.TestCase, parameterized.TestCase):
keras.models.save_model(model, saved_model_dir, save_format=save_format)
new_model = keras.models.load_model(saved_model_dir)
self._assert_same_weights(model, new_model)
self._assert_same_weights_and_metrics(model, new_model)
out2 = new_model.predict(x)
self.assertAllClose(out, out2, atol=1e-05)
@ -535,42 +536,11 @@ class TestWholeModelSaving(test.TestCase, parameterized.TestCase):
saved_model_dir,
custom_objects={'CustomOp': CustomOp,
'custom_loss': custom_loss})
self._assert_same_weights(model, new_model)
self._assert_same_weights_and_metrics(model, new_model)
out2 = new_model.predict(x)
self.assertAllClose(out, out2, atol=1e-05)
def test_functional_model_saving(self):
saved_model_dir = self._save_model_dir()
save_format = testing_utils.get_save_format()
with ops.Graph().as_default(), self.cached_session():
inputs = keras.layers.Input(shape=(3,))
x = keras.layers.Dense(2)(inputs)
output = keras.layers.Dense(3)(x)
model = keras.models.Model(inputs, output)
model.compile(
loss=keras.losses.MSE,
optimizer=keras.optimizers.RMSprop(lr=0.0001),
metrics=[
keras.metrics.categorical_accuracy,
keras.metrics.CategoricalAccuracy()
],
weighted_metrics=[
keras.metrics.categorical_accuracy,
keras.metrics.CategoricalAccuracy()
])
x = np.random.random((1, 3))
y = np.random.random((1, 3))
model.train_on_batch(x, y)
out = model.predict(x)
keras.models.save_model(model, saved_model_dir, save_format=save_format)
model = keras.models.load_model(saved_model_dir)
out2 = model.predict(x)
self.assertAllClose(out, out2, atol=1e-05)
def test_saving_without_compilation(self):
saved_model_dir = self._save_model_dir()
save_format = testing_utils.get_save_format()

View File

@ -129,6 +129,7 @@ def load(path, compile=True, options=None): # pylint: disable=redefined-builtin
if training_config is not None:
model.compile(**saving_utils.compile_args_from_training_config(
training_config))
saving_utils.try_build_compiled_arguments(model)
else:
logging.warning('No training configuration found in save file, so the '
'model was *not* compiled. Compile it manually.')

View File

@ -27,6 +27,7 @@ from tensorflow.python.keras import losses
from tensorflow.python.keras import optimizers
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import version_utils
from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
@ -307,3 +308,16 @@ def _enforce_names_consistency(specs):
if name_inconsistency:
specs = nest.map_structure(_clear_name, specs)
return specs
def try_build_compiled_arguments(model):
if (not version_utils.is_v1_layer_or_model(model) and
model.outputs is not None):
try:
model.compiled_loss.build(model.outputs)
model.compiled_metrics.build(model.outputs, model.outputs)
except: # pylint: disable=bare-except
logging.warning(
'Compiled the loaded model, but the compiled metrics have yet to '
'be built. `model.compile_metrics` will be empty until you train '
'or evaluate the model.')