STT-tensorflow/tensorflow/python/framework/test_combinations_test.py
Igor Saprykin 42b8511f63 Move combinations.py closer to TF's test_util.py.
Customizations are split into tf.distribute.*-specific and general TF.

PiperOrigin-RevId: 257903643
2019-07-12 18:19:51 -07:00

157 lines
5.2 KiB
Python

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests generating test combinations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import OrderedDict
from absl.testing import parameterized
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.eager import test
class TestingCombinationsTest(test.TestCase):
  """Exercises `combinations.combine`, `+` on its results, and `times`."""

  def test_combine(self):
    """Two list-valued arguments produce all four value pairings."""
    self.assertEqual([{
        "a": 1,
        "b": 2
    }, {
        "a": 1,
        "b": 3
    }, {
        "a": 2,
        "b": 2
    }, {
        "a": 2,
        "b": 3
    }], combinations.combine(a=[1, 2], b=[2, 3]))

  def test_arguments_sorted(self):
    """Keyword order (`ab` before `aa`) must not leak into the result.

    The expectation is written with `OrderedDict`s keyed `aa` then `ab`,
    even though the call passes `ab` first.
    """
    self.assertEqual([
        OrderedDict([("aa", 1), ("ab", 2)]),
        OrderedDict([("aa", 1), ("ab", 3)]),
        OrderedDict([("aa", 2), ("ab", 2)]),
        OrderedDict([("aa", 2), ("ab", 3)])
    ], combinations.combine(ab=[2, 3], aa=[1, 2]))

  def test_combine_single_parameter(self):
    """A bare (non-list) value such as `b=2` is used in every combination."""
    self.assertEqual([{
        "a": 1,
        "b": 2
    }, {
        "a": 2,
        "b": 2
    }], combinations.combine(a=[1, 2], b=2))

  def test_add(self):
    """Adding two `combine` results concatenates them in order."""
    self.assertEqual(
        [{
            "a": 1
        }, {
            "a": 2
        }, {
            "b": 2
        }, {
            "b": 3
        }],
        combinations.combine(a=[1, 2]) + combinations.combine(b=[2, 3]))

  def test_times(self):
    """`times` crosses each `distribution` with every entry of `c1 + c2`."""
    c1 = combinations.combine(mode=["graph"], loss=["callable", "tensor"])
    c2 = combinations.combine(mode=["eager"], loss=["callable"])
    c3 = combinations.combine(distribution=["d1", "d2"])
    c4 = combinations.times(c3, c1 + c2)
    self.assertEqual([
        OrderedDict([("distribution", "d1"), ("loss", "callable"),
                     ("mode", "graph")]),
        OrderedDict([("distribution", "d1"), ("loss", "tensor"),
                     ("mode", "graph")]),
        OrderedDict([("distribution", "d1"), ("loss", "callable"),
                     ("mode", "eager")]),
        OrderedDict([("distribution", "d2"), ("loss", "callable"),
                     ("mode", "graph")]),
        OrderedDict([("distribution", "d2"), ("loss", "tensor"),
                     ("mode", "graph")]),
        OrderedDict([("distribution", "d2"), ("loss", "callable"),
                     ("mode", "eager")])
    ], c4)

  def test_times_variable_arguments(self):
    """`times` accepts more than two operands.

    The three-way product must also equal a single `combine` call that is
    given all three keyword arguments at once.
    """
    c1 = combinations.combine(mode=["graph", "eager"])
    c2 = combinations.combine(optimizer=["adam", "gd"])
    c3 = combinations.combine(distribution=["d1", "d2"])
    c4 = combinations.times(c3, c1, c2)
    self.assertEqual([
        OrderedDict([("distribution", "d1"), ("mode", "graph"),
                     ("optimizer", "adam")]),
        OrderedDict([("distribution", "d1"), ("mode", "graph"),
                     ("optimizer", "gd")]),
        OrderedDict([("distribution", "d1"), ("mode", "eager"),
                     ("optimizer", "adam")]),
        OrderedDict([("distribution", "d1"), ("mode", "eager"),
                     ("optimizer", "gd")]),
        OrderedDict([("distribution", "d2"), ("mode", "graph"),
                     ("optimizer", "adam")]),
        OrderedDict([("distribution", "d2"), ("mode", "graph"),
                     ("optimizer", "gd")]),
        OrderedDict([("distribution", "d2"), ("mode", "eager"),
                     ("optimizer", "adam")]),
        OrderedDict([("distribution", "d2"), ("mode", "eager"),
                     ("optimizer", "gd")])
    ], c4)
    self.assertEqual(
        combinations.combine(
            mode=["graph", "eager"],
            optimizer=["adam", "gd"],
            distribution=["d1", "d2"]), c4)

  def test_overlapping_keys(self):
    """`times` raises `ValueError` when operands share keys (`mode`, `loss`)."""
    c1 = combinations.combine(mode=["graph"], loss=["callable", "tensor"])
    c2 = combinations.combine(mode=["eager"], loss=["callable"])
    # `assertRaisesRegexp` is a deprecated alias that Python 3.12 removed.
    # Prefer the modern name and fall back to the old one only when the
    # running TestCase does not provide it (plain Python 2 unittest).
    assert_raises_regex = (
        getattr(self, "assertRaisesRegex", None) or self.assertRaisesRegexp)
    with assert_raises_regex(ValueError, ".*Keys.+overlap.+"):
      _ = combinations.times(c1, c2)
@combinations.generate(combinations.combine(a=[1, 0], b=[2, 3], c=[1]))
class CombineTheTestSuite(parameterized.TestCase):
  """Test case decorated at class level with a set of combinations.

  The decorator is given a in {1, 0}, b in {2, 3} and c == 1, so for every
  combination the sum a + b + c lies between 3 and 5 inclusive; the two
  `test_*` methods assert exactly that range.
  """

  def test_add_things(self, a, b, c):
    """The sum of the supplied parameters stays within [3, 5]."""
    total = a + b + c
    self.assertLessEqual(3, total)
    self.assertLessEqual(total, 5)

  def test_add_things_one_more(self, a, b, c):
    """Second test method with the same bounds check on the sum."""
    total = a + b + c
    self.assertLessEqual(3, total)
    self.assertLessEqual(total, 5)

  def not_a_test(self, a=0, b=0, c=0):
    """Fails unconditionally if called.

    NOTE(review): presumably here to confirm that methods without the
    "test" prefix are not run as tests -- confirm against the decorator.
    """
    del a, b, c
    self.fail()

  def _test_but_private(self, a=0, b=0, c=0):
    """Fails unconditionally if called.

    NOTE(review): presumably here to confirm that an underscore-prefixed
    name is not run as a test -- confirm against the decorator.
    """
    del a, b, c
    self.fail()

  # Check that nothing funny happens to a non-callable class attribute whose
  # name starts with "test".
  test_member = 0
if __name__ == "__main__":
  # Run this module's tests through `test.main()` only when the file is
  # executed directly, not when it is imported.
  test.main()