Add test combinations for Keras test.
Wrt to the test annotations for eager/graph and others, TF API owner reach the consensus that combination framework is more generic and should be used for new API for creating parameterized tests. This change also adds Keras specific combinations for existing keras_mode and model types. Updating simple_rnn_test as a show case. All the other keras tests will be updated later. PiperOrigin-RevId: 298939205 Change-Id: I5b63b57717f79ae338cf81d8505897e95ab94b8f
This commit is contained in:
parent
e21667ae73
commit
2733583b75
@ -148,6 +148,19 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "combinations",
|
||||
srcs = [
|
||||
"combinations.py",
|
||||
],
|
||||
deps = [
|
||||
":testing_utils",
|
||||
"//tensorflow/python:framework_combinations",
|
||||
"//tensorflow/python:framework_test_combinations_lib",
|
||||
"//tensorflow/python:tf2",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "callbacks_v1",
|
||||
srcs = [
|
||||
@ -317,6 +330,22 @@ tf_py_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "combinations_test",
|
||||
size = "small",
|
||||
srcs = ["combinations_test.py"],
|
||||
python_version = "PY3",
|
||||
deps = [
|
||||
":combinations",
|
||||
":testing_utils",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:extra_py_tests_deps",
|
||||
"//tensorflow/python:tf2",
|
||||
"//tensorflow/python/eager:context",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "constraints_test",
|
||||
size = "small",
|
||||
|
110
tensorflow/python/keras/combinations.py
Normal file
110
tensorflow/python/keras/combinations.py
Normal file
@ -0,0 +1,110 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""This module customizes `test_combinations` for `tf.keras` related tests."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import functools
|
||||
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import test_combinations
|
||||
from tensorflow.python.keras import testing_utils
|
||||
|
||||
KERAS_MODEL_TYPES = ['functional', 'subclass', 'sequential']
|
||||
|
||||
|
||||
def keras_mode_combinations(mode=None, run_eagerly=None):
|
||||
"""Returns the default test combinations for tf.keras tests.
|
||||
|
||||
Note that if tf2 is enabled, then v1 session test will be skipped.
|
||||
|
||||
Args:
|
||||
mode: List of modes to run the tests. The valid options are 'graph' and
|
||||
'eager'. Default to ['graph', 'eager'] if not specified. If a empty list
|
||||
is provide, then the test will run under the context based on tf's
|
||||
version, eg graph for v1 and eager for v2.
|
||||
run_eagerly: List of `run_eagerly` value to be run with the tests.
|
||||
Default to [True, False] if not specified. Note that for `graph` mode,
|
||||
run_eagerly value will only be False.
|
||||
|
||||
Returns:
|
||||
A list contains all the combinations to be used to generate test cases.
|
||||
"""
|
||||
if mode is None:
|
||||
mode = ['eager'] if tf2.enabled() else ['graph', 'eager']
|
||||
if run_eagerly is None:
|
||||
run_eagerly = [True, False]
|
||||
result = []
|
||||
if 'eager' in mode:
|
||||
result += combinations.combine(mode=['eager'], run_eagerly=run_eagerly)
|
||||
if 'graph' in mode:
|
||||
result += combinations.combine(mode=['graph'], run_eagerly=[False])
|
||||
return result
|
||||
|
||||
|
||||
def keras_model_type_combinations():
|
||||
return combinations.combine(model_type=KERAS_MODEL_TYPES)
|
||||
|
||||
|
||||
class KerasModeCombination(test_combinations.TestCombination):
|
||||
"""Combination for Keras test mode.
|
||||
|
||||
It by default includes v1_session, v2_eager and v2_tf_function.
|
||||
"""
|
||||
|
||||
def context_managers(self, kwargs):
|
||||
run_eagerly = kwargs.pop('run_eagerly', None)
|
||||
|
||||
if run_eagerly is not None:
|
||||
return [testing_utils.run_eagerly_scope(run_eagerly)]
|
||||
else:
|
||||
return []
|
||||
|
||||
def parameter_modifiers(self):
|
||||
return [test_combinations.OptionalParameter('run_eagerly')]
|
||||
|
||||
|
||||
class KerasModelTypeCombination(test_combinations.TestCombination):
|
||||
"""Combination for Keras model types when doing model test.
|
||||
|
||||
It by default includes 'functional', 'subclass', 'sequential'.
|
||||
|
||||
Various methods in `testing_utils` to get models will auto-generate a model
|
||||
of the currently active Keras model type. This allows unittests to confirm
|
||||
the equivalence between different Keras models.
|
||||
"""
|
||||
|
||||
def context_managers(self, kwargs):
|
||||
model_type = kwargs.pop('model_type', None)
|
||||
if model_type in KERAS_MODEL_TYPES:
|
||||
return [testing_utils.model_type_scope(model_type)]
|
||||
else:
|
||||
return []
|
||||
|
||||
def parameter_modifiers(self):
|
||||
return [test_combinations.OptionalParameter('model_type')]
|
||||
|
||||
|
||||
_defaults = combinations.generate.keywords['test_combinations']
|
||||
generate = functools.partial(
|
||||
combinations.generate,
|
||||
test_combinations=_defaults +
|
||||
(KerasModeCombination(), KerasModelTypeCombination()))
|
||||
combine = test_combinations.combine
|
||||
times = test_combinations.times
|
||||
NamedObject = test_combinations.NamedObject
|
178
tensorflow/python/keras/combinations_test.py
Normal file
178
tensorflow/python/keras/combinations_test.py
Normal file
@ -0,0 +1,178 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for Keras combinations."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.keras import combinations
|
||||
from tensorflow.python.keras import models as keras_models
|
||||
from tensorflow.python.keras import testing_utils
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class CombinationsTest(test.TestCase):
|
||||
|
||||
def test_run_all_keras_modes(self):
|
||||
test_params = []
|
||||
|
||||
class ExampleTest(parameterized.TestCase):
|
||||
|
||||
def runTest(self):
|
||||
pass
|
||||
|
||||
@combinations.generate(combinations.keras_mode_combinations())
|
||||
def testBody(self):
|
||||
mode = "eager" if context.executing_eagerly() else "graph"
|
||||
should_run_eagerly = testing_utils.should_run_eagerly()
|
||||
test_params.append((mode, should_run_eagerly))
|
||||
|
||||
e = ExampleTest()
|
||||
if not tf2.enabled():
|
||||
e.testBody_test_mode_graph_runeagerly_False()
|
||||
e.testBody_test_mode_eager_runeagerly_True()
|
||||
e.testBody_test_mode_eager_runeagerly_False()
|
||||
|
||||
if not tf2.enabled():
|
||||
self.assertLen(test_params, 3)
|
||||
self.assertAllEqual(test_params, [
|
||||
("graph", False),
|
||||
("eager", True),
|
||||
("eager", False),
|
||||
])
|
||||
|
||||
ts = unittest.makeSuite(ExampleTest)
|
||||
res = unittest.TestResult()
|
||||
ts.run(res)
|
||||
self.assertLen(test_params, 6)
|
||||
else:
|
||||
self.assertLen(test_params, 2)
|
||||
self.assertAllEqual(test_params, [
|
||||
("eager", True),
|
||||
("eager", False),
|
||||
])
|
||||
|
||||
ts = unittest.makeSuite(ExampleTest)
|
||||
res = unittest.TestResult()
|
||||
ts.run(res)
|
||||
self.assertLen(test_params, 4)
|
||||
|
||||
def test_generate_keras_mode_eager_only(self):
|
||||
result = combinations.keras_mode_combinations(mode=["eager"])
|
||||
self.assertLen(result, 2)
|
||||
self.assertEqual(result[0], {"mode": "eager", "run_eagerly": True})
|
||||
self.assertEqual(result[1], {"mode": "eager", "run_eagerly": False})
|
||||
|
||||
def test_generate_keras_mode_skip_run_eagerly(self):
|
||||
result = combinations.keras_mode_combinations(run_eagerly=[False])
|
||||
if tf2.enabled():
|
||||
self.assertLen(result, 1)
|
||||
self.assertEqual(result[0], {"mode": "eager", "run_eagerly": False})
|
||||
else:
|
||||
self.assertLen(result, 2)
|
||||
self.assertEqual(result[0], {"mode": "eager", "run_eagerly": False})
|
||||
self.assertEqual(result[1], {"mode": "graph", "run_eagerly": False})
|
||||
|
||||
def test_run_all_keras_model_types(self):
|
||||
model_types = []
|
||||
models = []
|
||||
|
||||
class ExampleTest(parameterized.TestCase):
|
||||
|
||||
def runTest(self):
|
||||
pass
|
||||
|
||||
@combinations.generate(combinations.keras_model_type_combinations())
|
||||
def testBody(self):
|
||||
model_types.append(testing_utils.get_model_type())
|
||||
models.append(testing_utils.get_small_mlp(1, 4, input_dim=3))
|
||||
|
||||
e = ExampleTest()
|
||||
e.testBody_test_modeltype_functional()
|
||||
e.testBody_test_modeltype_subclass()
|
||||
e.testBody_test_modeltype_sequential()
|
||||
|
||||
self.assertLen(model_types, 3)
|
||||
self.assertAllEqual(model_types, [
|
||||
"functional",
|
||||
"subclass",
|
||||
"sequential"
|
||||
])
|
||||
|
||||
# Validate that the models are what they should be
|
||||
self.assertTrue(models[0]._is_graph_network)
|
||||
self.assertFalse(models[1]._is_graph_network)
|
||||
self.assertNotIsInstance(models[0], keras_models.Sequential)
|
||||
self.assertNotIsInstance(models[1], keras_models.Sequential)
|
||||
self.assertIsInstance(models[2], keras_models.Sequential)
|
||||
|
||||
ts = unittest.makeSuite(ExampleTest)
|
||||
res = unittest.TestResult()
|
||||
ts.run(res)
|
||||
|
||||
self.assertLen(model_types, 6)
|
||||
|
||||
def test_combine_combinations(self):
|
||||
test_cases = []
|
||||
|
||||
@combinations.generate(combinations.times(
|
||||
combinations.keras_mode_combinations(),
|
||||
combinations.keras_model_type_combinations()))
|
||||
class ExampleTest(parameterized.TestCase):
|
||||
|
||||
def runTest(self):
|
||||
pass
|
||||
|
||||
@parameterized.named_parameters(dict(testcase_name="_arg",
|
||||
arg=True))
|
||||
def testBody(self, arg):
|
||||
del arg
|
||||
mode = "eager" if context.executing_eagerly() else "graph"
|
||||
should_run_eagerly = testing_utils.should_run_eagerly()
|
||||
test_cases.append((mode, should_run_eagerly,
|
||||
testing_utils.get_model_type()))
|
||||
|
||||
ts = unittest.makeSuite(ExampleTest)
|
||||
res = unittest.TestResult()
|
||||
ts.run(res)
|
||||
|
||||
expected_combinations = [
|
||||
("eager", False, "functional"),
|
||||
("eager", False, "sequential"),
|
||||
("eager", False, "subclass"),
|
||||
("eager", True, "functional"),
|
||||
("eager", True, "sequential"),
|
||||
("eager", True, "subclass"),
|
||||
]
|
||||
|
||||
if not tf2.enabled():
|
||||
expected_combinations.extend([
|
||||
("graph", False, "functional"),
|
||||
("graph", False, "sequential"),
|
||||
("graph", False, "subclass"),
|
||||
])
|
||||
|
||||
self.assertAllEqual(sorted(test_cases), expected_combinations)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -638,6 +638,7 @@ tf_py_test(
|
||||
deps = [
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python/keras",
|
||||
"//tensorflow/python/keras:combinations",
|
||||
"//third_party/py/numpy",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
|
@ -18,20 +18,21 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python import keras
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import test_util as tf_test_util
|
||||
from tensorflow.python.keras import keras_parameterized
|
||||
from tensorflow.python.keras import combinations
|
||||
from tensorflow.python.keras import testing_utils
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import gradient_descent
|
||||
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
class SimpleRNNLayerTest(keras_parameterized.TestCase):
|
||||
@combinations.generate(combinations.keras_mode_combinations())
|
||||
class SimpleRNNLayerTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def test_return_sequences_SimpleRNN(self):
|
||||
num_samples = 2
|
||||
@ -145,14 +146,14 @@ class SimpleRNNLayerTest(keras_parameterized.TestCase):
|
||||
bias_regularizer='l2',
|
||||
activity_regularizer='l1')
|
||||
layer.build((None, None, 2))
|
||||
self.assertEqual(len(layer.losses), 3)
|
||||
self.assertLen(layer.losses, 3)
|
||||
|
||||
x = keras.backend.variable(np.ones((2, 3, 2)))
|
||||
layer(x)
|
||||
if context.executing_eagerly():
|
||||
self.assertEqual(len(layer.losses), 4)
|
||||
self.assertLen(layer.losses, 4)
|
||||
else:
|
||||
self.assertEqual(len(layer.get_losses_for(x)), 1)
|
||||
self.assertLen(layer.get_losses_for(x), 1)
|
||||
|
||||
def test_statefulness_SimpleRNN(self):
|
||||
num_samples = 2
|
||||
|
@ -113,6 +113,7 @@ COMMON_PIP_DEPS = [
|
||||
"//tensorflow/python/distribute:combinations",
|
||||
"//tensorflow/python/distribute:multi_process_runner",
|
||||
"//tensorflow/python/eager:eager_pip",
|
||||
"//tensorflow/python/keras:combinations",
|
||||
"//tensorflow/python/keras/layers/preprocessing:preprocessing_test_utils",
|
||||
"//tensorflow/python/keras/distribute:distribute_strategy_test_lib",
|
||||
"//tensorflow/python/keras/distribute:multi_worker_testing_utils",
|
||||
|
Loading…
x
Reference in New Issue
Block a user