STT-tensorflow/tensorflow/python/keras/combinations_test.py
Scott Zhu 2733583b75 Add test combinations for Keras test.
With respect to the test annotations for eager/graph and others, the TF API owners reached a consensus that the combination framework is more generic and should be used as the new API for creating parameterized tests.

This change also adds Keras-specific combinations for the existing keras_mode and model types, and updates simple_rnn_test as a showcase. All the other Keras tests will be updated later.

PiperOrigin-RevId: 298939205
Change-Id: I5b63b57717f79ae338cf81d8505897e95ab94b8f
2020-03-04 14:24:59 -08:00

179 lines
5.6 KiB
Python

# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Keras combinations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import unittest
from absl.testing import parameterized
from tensorflow.python import tf2
from tensorflow.python.eager import context
from tensorflow.python.keras import combinations
from tensorflow.python.keras import models as keras_models
from tensorflow.python.keras import testing_utils
from tensorflow.python.platform import test
class CombinationsTest(test.TestCase):
  """Tests for the Keras-specific helpers in `combinations`.

  Each test defines a throwaway `ExampleTest` case decorated with
  `combinations.generate`, runs the generated test methods, and checks which
  parameter combinations were actually observed while they ran.
  """

  def _run_generated_tests(self, test_class):
    """Runs every generated `test*` method of `test_class` as a suite.

    Uses `unittest.defaultTestLoader.loadTestsFromTestCase` rather than
    `unittest.makeSuite`, which is deprecated since Python 3.11 and removed in
    Python 3.13.

    Args:
      test_class: The `unittest.TestCase` subclass whose tests should run.

    Returns:
      The `unittest.TestResult` populated by the run.
    """
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(test_class)
    result = unittest.TestResult()
    suite.run(result)
    # Counting recorded parameters alone would not notice a generated test that
    # raised after recording them, so surface suite errors explicitly.
    self.assertTrue(
        result.wasSuccessful(),
        msg="Generated tests failed: %s" % (result.errors + result.failures))
    return result

  def test_run_all_keras_modes(self):
    """`keras_mode_combinations()` generates one test per mode/run_eagerly."""
    test_params = []

    class ExampleTest(parameterized.TestCase):

      def runTest(self):
        # Lets `ExampleTest()` be constructed without naming a test method.
        pass

      @combinations.generate(combinations.keras_mode_combinations())
      def testBody(self):
        mode = "eager" if context.executing_eagerly() else "graph"
        should_run_eagerly = testing_utils.should_run_eagerly()
        test_params.append((mode, should_run_eagerly))

    e = ExampleTest()
    expected_params = []
    # The graph-mode variant is only exercised when TF2 behavior is disabled.
    if not tf2.enabled():
      e.testBody_test_mode_graph_runeagerly_False()
      expected_params.append(("graph", False))
    e.testBody_test_mode_eager_runeagerly_True()
    expected_params.append(("eager", True))
    e.testBody_test_mode_eager_runeagerly_False()
    expected_params.append(("eager", False))

    self.assertLen(test_params, len(expected_params))
    self.assertAllEqual(test_params, expected_params)

    # Running the whole suite executes every generated test once more, so the
    # number of recorded parameter tuples doubles.
    self._run_generated_tests(ExampleTest)
    self.assertLen(test_params, 2 * len(expected_params))

  def test_generate_keras_mode_eager_only(self):
    """Restricting `mode` to eager yields both `run_eagerly` values."""
    result = combinations.keras_mode_combinations(mode=["eager"])
    self.assertLen(result, 2)
    self.assertEqual(result[0], {"mode": "eager", "run_eagerly": True})
    self.assertEqual(result[1], {"mode": "eager", "run_eagerly": False})

  def test_generate_keras_mode_skip_run_eagerly(self):
    """Restricting `run_eagerly` to False yields one entry per mode."""
    result = combinations.keras_mode_combinations(run_eagerly=[False])
    if tf2.enabled():
      self.assertLen(result, 1)
      self.assertEqual(result[0], {"mode": "eager", "run_eagerly": False})
    else:
      self.assertLen(result, 2)
      self.assertEqual(result[0], {"mode": "eager", "run_eagerly": False})
      self.assertEqual(result[1], {"mode": "graph", "run_eagerly": False})

  def test_run_all_keras_model_types(self):
    """`keras_model_type_combinations()` generates one test per model type."""
    model_types = []
    models = []

    class ExampleTest(parameterized.TestCase):

      def runTest(self):
        # Lets `ExampleTest()` be constructed without naming a test method.
        pass

      @combinations.generate(combinations.keras_model_type_combinations())
      def testBody(self):
        model_types.append(testing_utils.get_model_type())
        models.append(testing_utils.get_small_mlp(1, 4, input_dim=3))

    e = ExampleTest()
    e.testBody_test_modeltype_functional()
    e.testBody_test_modeltype_subclass()
    e.testBody_test_modeltype_sequential()

    self.assertLen(model_types, 3)
    self.assertAllEqual(model_types, [
        "functional",
        "subclass",
        "sequential",
    ])

    # Validate that the models are what they should be
    self.assertTrue(models[0]._is_graph_network)
    self.assertFalse(models[1]._is_graph_network)
    self.assertNotIsInstance(models[0], keras_models.Sequential)
    self.assertNotIsInstance(models[1], keras_models.Sequential)
    self.assertIsInstance(models[2], keras_models.Sequential)

    # Running the whole suite executes the three generated tests once more.
    self._run_generated_tests(ExampleTest)
    self.assertLen(model_types, 6)

  def test_combine_combinations(self):
    """Mode and model-type combinations compose via `combinations.times`."""
    test_cases = []

    @combinations.generate(combinations.times(
        combinations.keras_mode_combinations(),
        combinations.keras_model_type_combinations()))
    class ExampleTest(parameterized.TestCase):

      def runTest(self):
        pass

      @parameterized.named_parameters(dict(testcase_name="_arg",
                                           arg=True))
      def testBody(self, arg):
        del arg
        mode = "eager" if context.executing_eagerly() else "graph"
        should_run_eagerly = testing_utils.should_run_eagerly()
        test_cases.append((mode, should_run_eagerly,
                           testing_utils.get_model_type()))

    self._run_generated_tests(ExampleTest)

    expected_combinations = [
        ("eager", False, "functional"),
        ("eager", False, "sequential"),
        ("eager", False, "subclass"),
        ("eager", True, "functional"),
        ("eager", True, "sequential"),
        ("eager", True, "subclass"),
    ]

    # Graph-mode cases are only expected when TF2 behavior is disabled.
    if not tf2.enabled():
      expected_combinations.extend([
          ("graph", False, "functional"),
          ("graph", False, "sequential"),
          ("graph", False, "subclass"),
      ])

    # Compare against the sorted observations; the expected list above is
    # written in sorted order.
    self.assertAllEqual(sorted(test_cases), expected_combinations)
# Invoke `test.main()` (from tensorflow.python.platform) only when this file is
# executed directly as a script, not when it is imported.
if __name__ == "__main__":
  test.main()