Add test combinations for Keras test.

Wrt to the test annotations for eager/graph and others, TF API owner reach the consensus that combination framework is more generic and should be used for new API for creating parameterized tests.

This change also adds Keras specific combinations for existing keras_mode and model types. Updating simple_rnn_test as a show case. All the other keras tests will be updated later.

PiperOrigin-RevId: 298939205
Change-Id: I5b63b57717f79ae338cf81d8505897e95ab94b8f
This commit is contained in:
Scott Zhu 2020-03-04 14:18:56 -08:00 committed by TensorFlower Gardener
parent e21667ae73
commit 2733583b75
6 changed files with 326 additions and 6 deletions

View File

@ -148,6 +148,19 @@ py_library(
],
)
py_library(
name = "combinations",
srcs = [
"combinations.py",
],
deps = [
":testing_utils",
"//tensorflow/python:framework_combinations",
"//tensorflow/python:framework_test_combinations_lib",
"//tensorflow/python:tf2",
],
)
py_library(
name = "callbacks_v1",
srcs = [
@ -317,6 +330,22 @@ tf_py_test(
],
)
tf_py_test(
name = "combinations_test",
size = "small",
srcs = ["combinations_test.py"],
python_version = "PY3",
deps = [
":combinations",
":testing_utils",
"//tensorflow/python:client_testlib",
"//tensorflow/python:extra_py_tests_deps",
"//tensorflow/python:tf2",
"//tensorflow/python/eager:context",
"@absl_py//absl/testing:parameterized",
],
)
tf_py_test(
name = "constraints_test",
size = "small",

View File

@ -0,0 +1,110 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""This module customizes `test_combinations` for `tf.keras` related tests."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
from tensorflow.python import tf2
from tensorflow.python.framework import combinations
from tensorflow.python.framework import test_combinations
from tensorflow.python.keras import testing_utils
KERAS_MODEL_TYPES = ['functional', 'subclass', 'sequential']
def keras_mode_combinations(mode=None, run_eagerly=None):
"""Returns the default test combinations for tf.keras tests.
Note that if tf2 is enabled, then v1 session test will be skipped.
Args:
mode: List of modes to run the tests. The valid options are 'graph' and
'eager'. Default to ['graph', 'eager'] if not specified. If a empty list
is provide, then the test will run under the context based on tf's
version, eg graph for v1 and eager for v2.
run_eagerly: List of `run_eagerly` value to be run with the tests.
Default to [True, False] if not specified. Note that for `graph` mode,
run_eagerly value will only be False.
Returns:
A list contains all the combinations to be used to generate test cases.
"""
if mode is None:
mode = ['eager'] if tf2.enabled() else ['graph', 'eager']
if run_eagerly is None:
run_eagerly = [True, False]
result = []
if 'eager' in mode:
result += combinations.combine(mode=['eager'], run_eagerly=run_eagerly)
if 'graph' in mode:
result += combinations.combine(mode=['graph'], run_eagerly=[False])
return result
def keras_model_type_combinations():
return combinations.combine(model_type=KERAS_MODEL_TYPES)
class KerasModeCombination(test_combinations.TestCombination):
"""Combination for Keras test mode.
It by default includes v1_session, v2_eager and v2_tf_function.
"""
def context_managers(self, kwargs):
run_eagerly = kwargs.pop('run_eagerly', None)
if run_eagerly is not None:
return [testing_utils.run_eagerly_scope(run_eagerly)]
else:
return []
def parameter_modifiers(self):
return [test_combinations.OptionalParameter('run_eagerly')]
class KerasModelTypeCombination(test_combinations.TestCombination):
"""Combination for Keras model types when doing model test.
It by default includes 'functional', 'subclass', 'sequential'.
Various methods in `testing_utils` to get models will auto-generate a model
of the currently active Keras model type. This allows unittests to confirm
the equivalence between different Keras models.
"""
def context_managers(self, kwargs):
model_type = kwargs.pop('model_type', None)
if model_type in KERAS_MODEL_TYPES:
return [testing_utils.model_type_scope(model_type)]
else:
return []
def parameter_modifiers(self):
return [test_combinations.OptionalParameter('model_type')]
_defaults = combinations.generate.keywords['test_combinations']
generate = functools.partial(
combinations.generate,
test_combinations=_defaults +
(KerasModeCombination(), KerasModelTypeCombination()))
combine = test_combinations.combine
times = test_combinations.times
NamedObject = test_combinations.NamedObject

View File

@ -0,0 +1,178 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Keras combinations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import unittest
from absl.testing import parameterized
from tensorflow.python import tf2
from tensorflow.python.eager import context
from tensorflow.python.keras import combinations
from tensorflow.python.keras import models as keras_models
from tensorflow.python.keras import testing_utils
from tensorflow.python.platform import test
class CombinationsTest(test.TestCase):
def test_run_all_keras_modes(self):
test_params = []
class ExampleTest(parameterized.TestCase):
def runTest(self):
pass
@combinations.generate(combinations.keras_mode_combinations())
def testBody(self):
mode = "eager" if context.executing_eagerly() else "graph"
should_run_eagerly = testing_utils.should_run_eagerly()
test_params.append((mode, should_run_eagerly))
e = ExampleTest()
if not tf2.enabled():
e.testBody_test_mode_graph_runeagerly_False()
e.testBody_test_mode_eager_runeagerly_True()
e.testBody_test_mode_eager_runeagerly_False()
if not tf2.enabled():
self.assertLen(test_params, 3)
self.assertAllEqual(test_params, [
("graph", False),
("eager", True),
("eager", False),
])
ts = unittest.makeSuite(ExampleTest)
res = unittest.TestResult()
ts.run(res)
self.assertLen(test_params, 6)
else:
self.assertLen(test_params, 2)
self.assertAllEqual(test_params, [
("eager", True),
("eager", False),
])
ts = unittest.makeSuite(ExampleTest)
res = unittest.TestResult()
ts.run(res)
self.assertLen(test_params, 4)
def test_generate_keras_mode_eager_only(self):
result = combinations.keras_mode_combinations(mode=["eager"])
self.assertLen(result, 2)
self.assertEqual(result[0], {"mode": "eager", "run_eagerly": True})
self.assertEqual(result[1], {"mode": "eager", "run_eagerly": False})
def test_generate_keras_mode_skip_run_eagerly(self):
result = combinations.keras_mode_combinations(run_eagerly=[False])
if tf2.enabled():
self.assertLen(result, 1)
self.assertEqual(result[0], {"mode": "eager", "run_eagerly": False})
else:
self.assertLen(result, 2)
self.assertEqual(result[0], {"mode": "eager", "run_eagerly": False})
self.assertEqual(result[1], {"mode": "graph", "run_eagerly": False})
def test_run_all_keras_model_types(self):
model_types = []
models = []
class ExampleTest(parameterized.TestCase):
def runTest(self):
pass
@combinations.generate(combinations.keras_model_type_combinations())
def testBody(self):
model_types.append(testing_utils.get_model_type())
models.append(testing_utils.get_small_mlp(1, 4, input_dim=3))
e = ExampleTest()
e.testBody_test_modeltype_functional()
e.testBody_test_modeltype_subclass()
e.testBody_test_modeltype_sequential()
self.assertLen(model_types, 3)
self.assertAllEqual(model_types, [
"functional",
"subclass",
"sequential"
])
# Validate that the models are what they should be
self.assertTrue(models[0]._is_graph_network)
self.assertFalse(models[1]._is_graph_network)
self.assertNotIsInstance(models[0], keras_models.Sequential)
self.assertNotIsInstance(models[1], keras_models.Sequential)
self.assertIsInstance(models[2], keras_models.Sequential)
ts = unittest.makeSuite(ExampleTest)
res = unittest.TestResult()
ts.run(res)
self.assertLen(model_types, 6)
def test_combine_combinations(self):
test_cases = []
@combinations.generate(combinations.times(
combinations.keras_mode_combinations(),
combinations.keras_model_type_combinations()))
class ExampleTest(parameterized.TestCase):
def runTest(self):
pass
@parameterized.named_parameters(dict(testcase_name="_arg",
arg=True))
def testBody(self, arg):
del arg
mode = "eager" if context.executing_eagerly() else "graph"
should_run_eagerly = testing_utils.should_run_eagerly()
test_cases.append((mode, should_run_eagerly,
testing_utils.get_model_type()))
ts = unittest.makeSuite(ExampleTest)
res = unittest.TestResult()
ts.run(res)
expected_combinations = [
("eager", False, "functional"),
("eager", False, "sequential"),
("eager", False, "subclass"),
("eager", True, "functional"),
("eager", True, "sequential"),
("eager", True, "subclass"),
]
if not tf2.enabled():
expected_combinations.extend([
("graph", False, "functional"),
("graph", False, "sequential"),
("graph", False, "subclass"),
])
self.assertAllEqual(sorted(test_cases), expected_combinations)
if __name__ == "__main__":
test.main()

View File

@ -638,6 +638,7 @@ tf_py_test(
deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python/keras",
"//tensorflow/python/keras:combinations",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
],

View File

@ -18,20 +18,21 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensorflow.python import keras
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util as tf_test_util
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import combinations
from tensorflow.python.keras import testing_utils
from tensorflow.python.platform import test
from tensorflow.python.training import gradient_descent
@keras_parameterized.run_all_keras_modes
class SimpleRNNLayerTest(keras_parameterized.TestCase):
@combinations.generate(combinations.keras_mode_combinations())
class SimpleRNNLayerTest(test.TestCase, parameterized.TestCase):
def test_return_sequences_SimpleRNN(self):
num_samples = 2
@ -145,14 +146,14 @@ class SimpleRNNLayerTest(keras_parameterized.TestCase):
bias_regularizer='l2',
activity_regularizer='l1')
layer.build((None, None, 2))
self.assertEqual(len(layer.losses), 3)
self.assertLen(layer.losses, 3)
x = keras.backend.variable(np.ones((2, 3, 2)))
layer(x)
if context.executing_eagerly():
self.assertEqual(len(layer.losses), 4)
self.assertLen(layer.losses, 4)
else:
self.assertEqual(len(layer.get_losses_for(x)), 1)
self.assertLen(layer.get_losses_for(x), 1)
def test_statefulness_SimpleRNN(self):
num_samples = 2

View File

@ -113,6 +113,7 @@ COMMON_PIP_DEPS = [
"//tensorflow/python/distribute:combinations",
"//tensorflow/python/distribute:multi_process_runner",
"//tensorflow/python/eager:eager_pip",
"//tensorflow/python/keras:combinations",
"//tensorflow/python/keras/layers/preprocessing:preprocessing_test_utils",
"//tensorflow/python/keras/distribute:distribute_strategy_test_lib",
"//tensorflow/python/keras/distribute:multi_worker_testing_utils",