Make custom metrics and losses saveable with the Functional API
Currently metrics added through `model.add_metric` are not saveable without adding `AddMetric` as a custom object. To that end, `AddMetric` is now included in the default layers. An E2E test was added to ensure correctness and prevent a regression. PiperOrigin-RevId: 256415004
This commit is contained in:
parent
84bc93d748
commit
1ad6aae48d
@@ -23,6 +23,7 @@ from __future__ import print_function
 from tensorflow.python import tf2
 from tensorflow.python.keras.engine.base_layer import AddLoss
+from tensorflow.python.keras.engine.base_layer import AddMetric
 from tensorflow.python.keras.engine.base_layer import TensorFlowOpLayer
 from tensorflow.python.keras.engine.input_layer import Input
 from tensorflow.python.keras.engine.input_layer import InputLayer
@@ -807,6 +807,48 @@ class TestWholeModelSaving(test.TestCase):
     self.assertRegexpMatches(
         h5file.attrs['keras_version'], r'^[\d]+\.[\d]+\.[\S]+$')
@test_util.run_in_graph_and_eager_modes
def test_functional_model_with_custom_loss_and_metric(self):
  """E2E check: a functional model using add_loss/add_metric survives save/load.

  Builds a functional model that attaches a custom activity loss via
  `add_loss` and tracks it via `add_metric`, trains one batch, saves to HDF5,
  reloads, and verifies the reloaded model evaluates identically.
  """
  if h5py is None:
    self.skipTest('h5py required to run this test')

  def _build_model():
    # Functional graph with an extra squared-activation loss on the
    # hidden layer, wired in through a Lambda so it is part of the graph.
    inputs = keras.Input(shape=(4,))
    hidden = keras.layers.Dense(8, activation='relu')(inputs)
    probs = keras.layers.Dense(3, activation='softmax')(hidden)
    custom_loss = keras.layers.Lambda(
        lambda t: keras.backend.sum(t * t))(hidden)
    # Connect the loss to the network.
    outputs = keras.layers.Lambda(lambda pair: pair[0])((probs, custom_loss))
    net = keras.Model(inputs=inputs, outputs=outputs)
    net.add_loss(custom_loss)
    net.add_metric(custom_loss, aggregation='mean', name='custom_loss')
    return net

  model = _build_model()
  model.compile(
      loss=keras.losses.SparseCategoricalCrossentropy(),
      optimizer=optimizers.gradient_descent_v2.SGD(),
      metrics=[keras.metrics.SparseCategoricalCrossentropy()])

  features = np.random.normal(size=(32, 4))
  labels = np.random.randint(0, 3, size=32)
  model.train_on_batch(features, labels)
  evaluation_results = model.evaluate(features, labels)

  # Save and reload model.
  model_path = os.path.join(self.get_temp_dir(), 'model.h5')
  model.save(model_path)
  del model  # Prevent misuse.
  loaded_model = keras.models.load_model(model_path)
  os.remove(model_path)

  # Assert all evaluation results are the same.
  self.assertAllClose(
      evaluation_results, loaded_model.evaluate(features, labels), 1e-9)
  # Check correctness of the loss calculation.
  self.assertAllGreater(evaluation_results, 0.)
  evaluation_results = dict(
      zip(loaded_model.metrics_names, evaluation_results))
  # Total loss must equal crossentropy plus the custom added loss.
  self.assertNear(
      evaluation_results['sparse_categorical_crossentropy'] +
      evaluation_results['custom_loss'], evaluation_results['loss'], 1e-6)
|
||||||
|
|
||||||
|
|
||||||
 class SubclassedModel(training.Model):
|
Loading…
x
Reference in New Issue
Block a user