Customizations are split into tf.distribute.*-specific and general TF. PiperOrigin-RevId: 257903643
157 lines
5.2 KiB
Python
157 lines
5.2 KiB
Python
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Tests generating test combinations."""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
from collections import OrderedDict
|
|
|
|
from absl.testing import parameterized
|
|
|
|
from tensorflow.python.framework import test_combinations as combinations
|
|
from tensorflow.python.eager import test
|
|
|
|
|
|
class TestingCombinationsTest(test.TestCase):
  """Exercises `combine`, `times` and `+` from the test combinations module."""

  def test_combine(self):
    # Two list-valued keywords: one dict is expected per (a, b) pair, with `a`
    # varying slowest.
    self.assertEqual(
        [
            {"a": 1, "b": 2},
            {"a": 1, "b": 3},
            {"a": 2, "b": 2},
            {"a": 2, "b": 3},
        ],
        combinations.combine(a=[1, 2], b=[2, 3]))

  def test_arguments_sorted(self):
    # Keywords are passed as (ab, aa) but the expected entries list "aa"
    # before "ab".
    self.assertEqual(
        [
            OrderedDict([("aa", 1), ("ab", 2)]),
            OrderedDict([("aa", 1), ("ab", 3)]),
            OrderedDict([("aa", 2), ("ab", 2)]),
            OrderedDict([("aa", 2), ("ab", 3)]),
        ],
        combinations.combine(ab=[2, 3], aa=[1, 2]))

  def test_combine_single_parameter(self):
    # A non-list value (b=2) is expected to behave like a one-element list.
    self.assertEqual(
        [
            {"a": 1, "b": 2},
            {"a": 2, "b": 2},
        ],
        combinations.combine(a=[1, 2], b=2))

  def test_add(self):
    # `+` on two results concatenates them without merging any entries.
    self.assertEqual(
        [
            {"a": 1},
            {"a": 2},
            {"b": 2},
            {"b": 3},
        ],
        combinations.combine(a=[1, 2]) + combinations.combine(b=[2, 3]))

  def test_times(self):
    c1 = combinations.combine(mode=["graph"], loss=["callable", "tensor"])
    c2 = combinations.combine(mode=["eager"], loss=["callable"])
    c3 = combinations.combine(distribution=["d1", "d2"])
    # Each entry of c3 is paired with each entry of the concatenation c1 + c2.
    c4 = combinations.times(c3, c1 + c2)
    self.assertEqual(
        [
            OrderedDict([("distribution", "d1"), ("loss", "callable"),
                         ("mode", "graph")]),
            OrderedDict([("distribution", "d1"), ("loss", "tensor"),
                         ("mode", "graph")]),
            OrderedDict([("distribution", "d1"), ("loss", "callable"),
                         ("mode", "eager")]),
            OrderedDict([("distribution", "d2"), ("loss", "callable"),
                         ("mode", "graph")]),
            OrderedDict([("distribution", "d2"), ("loss", "tensor"),
                         ("mode", "graph")]),
            OrderedDict([("distribution", "d2"), ("loss", "callable"),
                         ("mode", "eager")]),
        ],
        c4)

  def test_times_variable_arguments(self):
    c1 = combinations.combine(mode=["graph", "eager"])
    c2 = combinations.combine(optimizer=["adam", "gd"])
    c3 = combinations.combine(distribution=["d1", "d2"])
    # `times` is called with three operands here rather than two.
    c4 = combinations.times(c3, c1, c2)
    self.assertEqual(
        [
            OrderedDict([("distribution", "d1"), ("mode", "graph"),
                         ("optimizer", "adam")]),
            OrderedDict([("distribution", "d1"), ("mode", "graph"),
                         ("optimizer", "gd")]),
            OrderedDict([("distribution", "d1"), ("mode", "eager"),
                         ("optimizer", "adam")]),
            OrderedDict([("distribution", "d1"), ("mode", "eager"),
                         ("optimizer", "gd")]),
            OrderedDict([("distribution", "d2"), ("mode", "graph"),
                         ("optimizer", "adam")]),
            OrderedDict([("distribution", "d2"), ("mode", "graph"),
                         ("optimizer", "gd")]),
            OrderedDict([("distribution", "d2"), ("mode", "eager"),
                         ("optimizer", "adam")]),
            OrderedDict([("distribution", "d2"), ("mode", "eager"),
                         ("optimizer", "gd")]),
        ],
        c4)
    # The same result must come out of a single `combine` call over all keys.
    self.assertEqual(
        combinations.combine(
            mode=["graph", "eager"],
            optimizer=["adam", "gd"],
            distribution=["d1", "d2"]),
        c4)

  def test_overlapping_keys(self):
    c1 = combinations.combine(mode=["graph"], loss=["callable", "tensor"])
    c2 = combinations.combine(mode=["eager"], loss=["callable"])
    # Both operands define "mode" and "loss", so `times` must reject them.
    # `assertRaisesRegex` is used because the `assertRaisesRegexp` alias is
    # deprecated and no longer exists in Python 3.12+.
    with self.assertRaisesRegex(ValueError, ".*Keys.+overlap.+"):
      _ = combinations.times(c1, c2)
|
|
|
|
|
|
@combinations.generate(combinations.combine(a=[1, 0], b=[2, 3], c=[1]))
class CombineTheTestSuite(parameterized.TestCase):
  """Test case whose class, rather than each method, carries the decorator."""

  def test_add_things(self, a, b, c):
    # For the decorated values (a in {1, 0}, b in {2, 3}, c == 1) the sum
    # always lies between 3 and 5 inclusive.
    total = a + b + c
    self.assertLessEqual(3, total)
    self.assertLessEqual(total, 5)

  def test_add_things_one_more(self, a, b, c):
    # Second test method with the same bounds check.
    total = a + b + c
    self.assertLessEqual(3, total)
    self.assertLessEqual(total, 5)

  def not_a_test(self, a=0, b=0, c=0):
    # Name lacks the "test" prefix; fails unconditionally if it is ever run.
    del a, b, c  # Unused.
    self.fail()

  def _test_but_private(self, a=0, b=0, c=0):
    # Underscore-prefixed name; fails unconditionally if it is ever run.
    del a, b, c  # Unused.
    self.fail()

  # A non-callable class attribute whose name starts with "test"; presumably
  # the class decorator has to leave it alone.
  test_member = 0
|
|
|
|
|
|
# Invoke `test.main()` only when this file is executed directly as a script.
if __name__ == "__main__":
  test.main()
|