From 7def022685f47a21715e051be02633f539e6df7c Mon Sep 17 00:00:00 2001 From: Pavithra Vijay Date: Thu, 19 Dec 2019 10:48:11 -0800 Subject: [PATCH] Adding serialization unit tests for the different combinations of compile loss inputs. PiperOrigin-RevId: 286421910 Change-Id: Icbd2cf3b7b21cf6851202e6641f9c95916e971ea --- tensorflow/python/keras/BUILD | 14 ++ tensorflow/python/keras/losses_test.py | 56 ----- .../keras/saving/losses_serialization_test.py | 197 ++++++++++++++++++ 3 files changed, 211 insertions(+), 56 deletions(-) create mode 100644 tensorflow/python/keras/saving/losses_serialization_test.py diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD index 1c14fb1d678..7c52dea48d6 100755 --- a/tensorflow/python/keras/BUILD +++ b/tensorflow/python/keras/BUILD @@ -737,6 +737,20 @@ tf_py_test( ], ) +tf_py_test( + name = "losses_serialization_test", + size = "medium", + srcs = ["saving/losses_serialization_test.py"], + python_version = "PY3", + shard_count = 4, + deps = [ + ":keras", + "//tensorflow/python:client_testlib", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) + tf_py_test( name = "advanced_activations_test", size = "medium", diff --git a/tensorflow/python/keras/losses_test.py b/tensorflow/python/keras/losses_test.py index 5776ebd0b4e..3a500bf22d9 100644 --- a/tensorflow/python/keras/losses_test.py +++ b/tensorflow/python/keras/losses_test.py @@ -18,9 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os -import shutil - import numpy as np from tensorflow.python import keras @@ -29,15 +26,9 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.framework import test_util -from tensorflow.python.keras.utils import generic_utils from tensorflow.python.keras.utils import losses_utils from tensorflow.python.platform import test -try: - import h5py # pylint:disable=g-import-not-at-top -except ImportError: - h5py = None - ALL_LOSSES = [keras.losses.mean_squared_error, keras.losses.mean_absolute_error, keras.losses.mean_absolute_percentage_error, @@ -53,20 +44,6 @@ ALL_LOSSES = [keras.losses.mean_squared_error, keras.losses.categorical_hinge] -class _MSEMAELoss(object): - """Loss function with internal state, for testing serialization code.""" - - def __init__(self, mse_fraction): - self.mse_fraction = mse_fraction - - def __call__(self, y_true, y_pred, sample_weight=None): - return (self.mse_fraction * keras.losses.mse(y_true, y_pred) + - (1 - self.mse_fraction) * keras.losses.mae(y_true, y_pred)) - - def get_config(self): - return {'mse_fraction': self.mse_fraction} - - class KerasLossesTest(test.TestCase): def test_objective_shapes_3d(self): @@ -200,39 +177,6 @@ class KerasLossesTest(test.TestCase): loss = keras.backend.eval(keras.losses.categorical_hinge(y_true, y_pred)) self.assertAllClose(expected_loss, np.mean(loss)) - def test_serializing_loss_class(self): - orig_loss_class = _MSEMAELoss(0.3) - with generic_utils.custom_object_scope({'_MSEMAELoss': _MSEMAELoss}): - serialized = keras.losses.serialize(orig_loss_class) - - with generic_utils.custom_object_scope({'_MSEMAELoss': _MSEMAELoss}): - deserialized = keras.losses.deserialize(serialized) - assert isinstance(deserialized, _MSEMAELoss) - assert deserialized.mse_fraction == 0.3 - - def test_serializing_model_with_loss_class(self): - tmpdir = self.get_temp_dir() - self.addCleanup(shutil.rmtree, tmpdir) - model_filename = os.path.join(tmpdir, 'custom_loss.h5') - - with self.cached_session(): - with generic_utils.custom_object_scope({'_MSEMAELoss': _MSEMAELoss}): - loss = _MSEMAELoss(0.3) - inputs = keras.layers.Input((2,)) - outputs = keras.layers.Dense(1, name='model_output')(inputs) - model = keras.models.Model(inputs, outputs) - model.compile(optimizer='sgd', loss={'model_output': loss}) - model.fit(np.random.rand(256, 2), np.random.rand(256, 1)) - - if h5py is None: - return - - model.save(model_filename) - - with generic_utils.custom_object_scope({'_MSEMAELoss': _MSEMAELoss}): - loaded_model = keras.models.load_model(model_filename) - loaded_model.predict(np.random.rand(128, 2)) - def test_loss_wrapper(self): loss_fn = keras.losses.get('mse') mse_obj = keras.losses.LossFunctionWrapper(loss_fn, name=loss_fn.__name__) diff --git a/tensorflow/python/keras/saving/losses_serialization_test.py b/tensorflow/python/keras/saving/losses_serialization_test.py new file mode 100644 index 00000000000..61851b809de --- /dev/null +++ b/tensorflow/python/keras/saving/losses_serialization_test.py @@ -0,0 +1,197 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Keras losses serialization.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import shutil +from absl.testing import parameterized +import numpy as np + +from tensorflow.python import keras +from tensorflow.python.keras import keras_parameterized +from tensorflow.python.keras import layers +from tensorflow.python.keras import losses +from tensorflow.python.keras import optimizer_v2 +from tensorflow.python.keras import testing_utils +from tensorflow.python.keras.utils import generic_utils +from tensorflow.python.keras.utils import losses_utils +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + +try: + import h5py # pylint:disable=g-import-not-at-top +except ImportError: + h5py = None + + +# Custom loss class +class MyMeanAbsoluteError(losses.LossFunctionWrapper): + + def __init__(self, + reduction=losses_utils.ReductionV2.AUTO, + name='mean_absolute_error'): + super(MyMeanAbsoluteError, self).__init__( + _my_mae, name=name, reduction=reduction) + + +# Custom loss function +def _my_mae(y_true, y_pred): + return keras.backend.mean(math_ops.abs(y_pred - y_true), axis=-1) + + +def _get_multi_io_model(): + inp_1 = layers.Input(shape=(1,), name='input_1') + inp_2 = layers.Input(shape=(1,), name='input_2') + d = testing_utils.Bias(name='output') + out_1 = d(inp_1) + out_2 = d(inp_2) + return keras.Model([inp_1, inp_2], [out_1, out_2]) + + +@keras_parameterized.run_all_keras_modes +@parameterized.named_parameters([ + dict(testcase_name='string', value='mae'), + dict(testcase_name='built_in_fn', value=losses.mae), + dict(testcase_name='built_in_class', value=losses.MeanAbsoluteError()), + dict(testcase_name='custom_fn', value=_my_mae), + dict(testcase_name='custom_class', value=MyMeanAbsoluteError()), + dict(testcase_name='list_of_strings', value=['mae', 'mae']), + dict(testcase_name='list_of_built_in_fns', value=[losses.mae, losses.mae]), + dict( + testcase_name='list_of_built_in_classes', + value=[losses.MeanAbsoluteError(), + losses.MeanAbsoluteError()]), + dict(testcase_name='list_of_custom_fns', value=[_my_mae, _my_mae]), + dict( + testcase_name='list_of_custom_classes', + value=[MyMeanAbsoluteError(), + MyMeanAbsoluteError()]), + dict( + testcase_name='dict_of_string', + value={ + 'output': 'mae', + 'output_1': 'mae', + }), + dict( + testcase_name='dict_of_built_in_fn', + value={ + 'output': losses.mae, + 'output_1': losses.mae, + }), + dict( + testcase_name='dict_of_built_in_class', + value={ + 'output': losses.MeanAbsoluteError(), + 'output_1': losses.MeanAbsoluteError(), + }), + dict( + testcase_name='dict_of_custom_fn', + value={ + 'output': _my_mae, + 'output_1': _my_mae + }), + dict( + testcase_name='dict_of_custom_class', + value={ + 'output': MyMeanAbsoluteError(), + 'output_1': MyMeanAbsoluteError(), + }), +]) +class LossesSerialization(keras_parameterized.TestCase): + + def setUp(self): + super(LossesSerialization, self).setUp() + tmpdir = self.get_temp_dir() + self.addCleanup(shutil.rmtree, tmpdir) + self.model_filename = os.path.join(tmpdir, 'tmp_model_loss.h5') + self.x = np.array([[0.], [1.], [2.]], dtype='float32') + self.y = np.array([[0.5], [2.], [3.5]], dtype='float32') + self.w = np.array([1.25, 0.5, 1.25], dtype='float32') + + def test_serializing_model_with_loss_with_custom_object_scope(self, value): + with generic_utils.custom_object_scope({ + 'MyMeanAbsoluteError': MyMeanAbsoluteError, + '_my_mae': _my_mae, + 'Bias': testing_utils.Bias, + }): + model = _get_multi_io_model() + model.compile( + optimizer_v2.gradient_descent.SGD(0.1), + loss=value, + run_eagerly=testing_utils.should_run_eagerly(), + experimental_run_tf_function=testing_utils.should_run_tf_function()) + history = model.fit([self.x, self.x], [self.y, self.y], + batch_size=3, + epochs=3, + sample_weight=[self.w, self.w]) + + # Assert training. + self.assertAllClose(history.history['loss'], [2., 1.6, 1.2], 1e-3) + eval_results = model.evaluate([self.x, self.x], [self.y, self.y], + sample_weight=[self.w, self.w]) + + if h5py is None: + return + model.save(self.model_filename) + loaded_model = keras.models.load_model(self.model_filename) + loaded_model.predict([self.x, self.x]) + loaded_eval_results = loaded_model.evaluate( + [self.x, self.x], [self.y, self.y], sample_weight=[self.w, self.w]) + + # Assert all evaluation results are the same. + self.assertAllClose(eval_results, loaded_eval_results, 1e-9) + + def test_serializing_model_with_loss_with_custom_objects(self, value): + model = _get_multi_io_model() + model.compile( + optimizer_v2.gradient_descent.SGD(0.1), + loss=value, + run_eagerly=testing_utils.should_run_eagerly(), + experimental_run_tf_function=testing_utils.should_run_tf_function()) + history = model.fit([self.x, self.x], [self.y, self.y], + batch_size=3, + epochs=3, + sample_weight=[self.w, self.w]) + + # Assert training. + self.assertAllClose(history.history['loss'], [2., 1.6, 1.2], 1e-3) + eval_results = model.evaluate([self.x, self.x], [self.y, self.y], + sample_weight=[self.w, self.w]) + + if h5py is None: + return + model.save(self.model_filename) + loaded_model = keras.models.load_model( + self.model_filename, + custom_objects={ + 'MyMeanAbsoluteError': MyMeanAbsoluteError, + '_my_mae': _my_mae, + 'Bias': testing_utils.Bias, + }) + loaded_model.predict([self.x, self.x]) + loaded_eval_results = loaded_model.evaluate([self.x, self.x], + [self.y, self.y], + sample_weight=[self.w, self.w]) + + # Assert all evaluation results are the same. + self.assertAllClose(eval_results, loaded_eval_results, 1e-9) + + +if __name__ == '__main__': + test.main()