Adding serialization unit tests for the different combinations of compile loss inputs.
PiperOrigin-RevId: 286421910
Change-Id: Icbd2cf3b7b21cf6851202e6641f9c95916e971ea
This commit is contained in:
parent 7e1680206a
commit 7def022685
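For context on what "combinations of compile loss inputs" means here: Keras accepts the loss argument of model.compile as a string, a built-in loss function, a built-in Loss class instance, a custom function, a custom Loss subclass, or a list/dict of any of these for multi-output models. The sketch below illustrates those forms using only the public tf.keras API; the helper names my_mae and two_output_model are illustrative placeholders and are not part of this change.

# Sketch only: the compile-loss input forms exercised by the new tests,
# written against the public tf.keras API. Helper names are hypothetical.
import numpy as np
import tensorflow as tf


def my_mae(y_true, y_pred):
  # Custom loss function: mean absolute error over the last axis.
  return tf.reduce_mean(tf.abs(y_pred - y_true), axis=-1)


class MyMeanAbsoluteError(tf.keras.losses.Loss):
  # Custom loss class wrapping the function above.

  def call(self, y_true, y_pred):
    return my_mae(y_true, y_pred)


def two_output_model():
  # Two-output functional model so list/dict losses can be exercised.
  inp = tf.keras.Input(shape=(1,))
  out_1 = tf.keras.layers.Dense(1, name='output')(inp)
  out_2 = tf.keras.layers.Dense(1, name='output_1')(inp)
  return tf.keras.Model(inp, [out_1, out_2])


x = np.random.rand(8, 1).astype('float32')
y = np.random.rand(8, 1).astype('float32')

loss_variants = [
    'mae',                                 # string
    tf.keras.losses.mae,                   # built-in function
    tf.keras.losses.MeanAbsoluteError(),   # built-in class
    my_mae,                                # custom function
    MyMeanAbsoluteError(),                 # custom class
    ['mae', 'mae'],                        # list, one entry per output
    {'output': 'mae', 'output_1': 'mae'},  # dict keyed by output name
]

for loss in loss_variants:
  model = two_output_model()
  model.compile(optimizer='sgd', loss=loss)
  model.fit(x, [y, y], epochs=1, verbose=0)

The parameterized cases in the new losses_serialization_test.py below cover these same forms and additionally verify that a compiled, trained, and saved model reloads with identical evaluation results.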
@@ -737,6 +737,20 @@ tf_py_test(
    ],
)

tf_py_test(
    name = "losses_serialization_test",
    size = "medium",
    srcs = ["saving/losses_serialization_test.py"],
    python_version = "PY3",
    shard_count = 4,
    deps = [
        ":keras",
        "//tensorflow/python:client_testlib",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)

tf_py_test(
    name = "advanced_activations_test",
    size = "medium",
@@ -18,9 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import shutil

import numpy as np

from tensorflow.python import keras
@@ -29,15 +26,9 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import losses_utils
from tensorflow.python.platform import test

try:
  import h5py  # pylint:disable=g-import-not-at-top
except ImportError:
  h5py = None

ALL_LOSSES = [keras.losses.mean_squared_error,
              keras.losses.mean_absolute_error,
              keras.losses.mean_absolute_percentage_error,
@@ -53,20 +44,6 @@ ALL_LOSSES = [keras.losses.mean_squared_error,
              keras.losses.categorical_hinge]


class _MSEMAELoss(object):
  """Loss function with internal state, for testing serialization code."""

  def __init__(self, mse_fraction):
    self.mse_fraction = mse_fraction

  def __call__(self, y_true, y_pred, sample_weight=None):
    return (self.mse_fraction * keras.losses.mse(y_true, y_pred) +
            (1 - self.mse_fraction) * keras.losses.mae(y_true, y_pred))

  def get_config(self):
    return {'mse_fraction': self.mse_fraction}


class KerasLossesTest(test.TestCase):

  def test_objective_shapes_3d(self):
@@ -200,39 +177,6 @@ class KerasLossesTest(test.TestCase):
    loss = keras.backend.eval(keras.losses.categorical_hinge(y_true, y_pred))
    self.assertAllClose(expected_loss, np.mean(loss))

  def test_serializing_loss_class(self):
    orig_loss_class = _MSEMAELoss(0.3)
    with generic_utils.custom_object_scope({'_MSEMAELoss': _MSEMAELoss}):
      serialized = keras.losses.serialize(orig_loss_class)

    with generic_utils.custom_object_scope({'_MSEMAELoss': _MSEMAELoss}):
      deserialized = keras.losses.deserialize(serialized)
    assert isinstance(deserialized, _MSEMAELoss)
    assert deserialized.mse_fraction == 0.3

  def test_serializing_model_with_loss_class(self):
    tmpdir = self.get_temp_dir()
    self.addCleanup(shutil.rmtree, tmpdir)
    model_filename = os.path.join(tmpdir, 'custom_loss.h5')

    with self.cached_session():
      with generic_utils.custom_object_scope({'_MSEMAELoss': _MSEMAELoss}):
        loss = _MSEMAELoss(0.3)
        inputs = keras.layers.Input((2,))
        outputs = keras.layers.Dense(1, name='model_output')(inputs)
        model = keras.models.Model(inputs, outputs)
        model.compile(optimizer='sgd', loss={'model_output': loss})
        model.fit(np.random.rand(256, 2), np.random.rand(256, 1))

      if h5py is None:
        return

      model.save(model_filename)

      with generic_utils.custom_object_scope({'_MSEMAELoss': _MSEMAELoss}):
        loaded_model = keras.models.load_model(model_filename)
        loaded_model.predict(np.random.rand(128, 2))

  def test_loss_wrapper(self):
    loss_fn = keras.losses.get('mse')
    mse_obj = keras.losses.LossFunctionWrapper(loss_fn, name=loss_fn.__name__)
tensorflow/python/keras/saving/losses_serialization_test.py (new file, 197 lines)
@@ -0,0 +1,197 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Keras losses serialization."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import shutil
from absl.testing import parameterized
import numpy as np

from tensorflow.python import keras
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import layers
from tensorflow.python.keras import losses
from tensorflow.python.keras import optimizer_v2
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import losses_utils
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test

try:
  import h5py  # pylint:disable=g-import-not-at-top
except ImportError:
  h5py = None


# Custom loss class
class MyMeanAbsoluteError(losses.LossFunctionWrapper):

  def __init__(self,
               reduction=losses_utils.ReductionV2.AUTO,
               name='mean_absolute_error'):
    super(MyMeanAbsoluteError, self).__init__(
        _my_mae, name=name, reduction=reduction)


# Custom loss function
def _my_mae(y_true, y_pred):
  return keras.backend.mean(math_ops.abs(y_pred - y_true), axis=-1)


def _get_multi_io_model():
  inp_1 = layers.Input(shape=(1,), name='input_1')
  inp_2 = layers.Input(shape=(1,), name='input_2')
  d = testing_utils.Bias(name='output')
  out_1 = d(inp_1)
  out_2 = d(inp_2)
  return keras.Model([inp_1, inp_2], [out_1, out_2])


@keras_parameterized.run_all_keras_modes
@parameterized.named_parameters([
    dict(testcase_name='string', value='mae'),
    dict(testcase_name='built_in_fn', value=losses.mae),
    dict(testcase_name='built_in_class', value=losses.MeanAbsoluteError()),
    dict(testcase_name='custom_fn', value=_my_mae),
    dict(testcase_name='custom_class', value=MyMeanAbsoluteError()),
    dict(testcase_name='list_of_strings', value=['mae', 'mae']),
    dict(testcase_name='list_of_built_in_fns', value=[losses.mae, losses.mae]),
    dict(
        testcase_name='list_of_built_in_classes',
        value=[losses.MeanAbsoluteError(),
               losses.MeanAbsoluteError()]),
    dict(testcase_name='list_of_custom_fns', value=[_my_mae, _my_mae]),
    dict(
        testcase_name='list_of_custom_classes',
        value=[MyMeanAbsoluteError(),
               MyMeanAbsoluteError()]),
    dict(
        testcase_name='dict_of_string',
        value={
            'output': 'mae',
            'output_1': 'mae',
        }),
    dict(
        testcase_name='dict_of_built_in_fn',
        value={
            'output': losses.mae,
            'output_1': losses.mae,
        }),
    dict(
        testcase_name='dict_of_built_in_class',
        value={
            'output': losses.MeanAbsoluteError(),
            'output_1': losses.MeanAbsoluteError(),
        }),
    dict(
        testcase_name='dict_of_custom_fn',
        value={
            'output': _my_mae,
            'output_1': _my_mae
        }),
    dict(
        testcase_name='dict_of_custom_class',
        value={
            'output': MyMeanAbsoluteError(),
            'output_1': MyMeanAbsoluteError(),
        }),
])
class LossesSerialization(keras_parameterized.TestCase):

  def setUp(self):
    super(LossesSerialization, self).setUp()
    tmpdir = self.get_temp_dir()
    self.addCleanup(shutil.rmtree, tmpdir)
    self.model_filename = os.path.join(tmpdir, 'tmp_model_loss.h5')
    self.x = np.array([[0.], [1.], [2.]], dtype='float32')
    self.y = np.array([[0.5], [2.], [3.5]], dtype='float32')
    self.w = np.array([1.25, 0.5, 1.25], dtype='float32')

  def test_serializing_model_with_loss_with_custom_object_scope(self, value):
    with generic_utils.custom_object_scope({
        'MyMeanAbsoluteError': MyMeanAbsoluteError,
        '_my_mae': _my_mae,
        'Bias': testing_utils.Bias,
    }):
      model = _get_multi_io_model()
      model.compile(
          optimizer_v2.gradient_descent.SGD(0.1),
          loss=value,
          run_eagerly=testing_utils.should_run_eagerly(),
          experimental_run_tf_function=testing_utils.should_run_tf_function())
      history = model.fit([self.x, self.x], [self.y, self.y],
                          batch_size=3,
                          epochs=3,
                          sample_weight=[self.w, self.w])

      # Assert training.
      self.assertAllClose(history.history['loss'], [2., 1.6, 1.2], 1e-3)
      eval_results = model.evaluate([self.x, self.x], [self.y, self.y],
                                    sample_weight=[self.w, self.w])

      if h5py is None:
        return
      model.save(self.model_filename)
      loaded_model = keras.models.load_model(self.model_filename)
      loaded_model.predict([self.x, self.x])
      loaded_eval_results = loaded_model.evaluate(
          [self.x, self.x], [self.y, self.y], sample_weight=[self.w, self.w])

      # Assert all evaluation results are the same.
      self.assertAllClose(eval_results, loaded_eval_results, 1e-9)

  def test_serializing_model_with_loss_with_custom_objects(self, value):
    model = _get_multi_io_model()
    model.compile(
        optimizer_v2.gradient_descent.SGD(0.1),
        loss=value,
        run_eagerly=testing_utils.should_run_eagerly(),
        experimental_run_tf_function=testing_utils.should_run_tf_function())
    history = model.fit([self.x, self.x], [self.y, self.y],
                        batch_size=3,
                        epochs=3,
                        sample_weight=[self.w, self.w])

    # Assert training.
    self.assertAllClose(history.history['loss'], [2., 1.6, 1.2], 1e-3)
    eval_results = model.evaluate([self.x, self.x], [self.y, self.y],
                                  sample_weight=[self.w, self.w])

    if h5py is None:
      return
    model.save(self.model_filename)
    loaded_model = keras.models.load_model(
        self.model_filename,
        custom_objects={
            'MyMeanAbsoluteError': MyMeanAbsoluteError,
            '_my_mae': _my_mae,
            'Bias': testing_utils.Bias,
        })
    loaded_model.predict([self.x, self.x])
    loaded_eval_results = loaded_model.evaluate([self.x, self.x],
                                                [self.y, self.y],
                                                sample_weight=[self.w, self.w])

    # Assert all evaluation results are the same.
    self.assertAllClose(eval_results, loaded_eval_results, 1e-9)


if __name__ == '__main__':
  test.main()