From d8551244b8aa8a2724e8e9306a528b7b95f3f128 Mon Sep 17 00:00:00 2001 From: Edward Loper Date: Fri, 6 Sep 2019 10:58:08 -0700 Subject: [PATCH] Add test_util decorator that makes it easier to check forward compatibility horizons. PiperOrigin-RevId: 267631954 --- tensorflow/python/framework/test_util.py | 36 +++++++++++++++++++ tensorflow/python/framework/test_util_test.py | 18 ++++++++++ 2 files changed, 54 insertions(+) diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 01e257cfdd1..fa3699c026d 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -54,6 +54,7 @@ from tensorflow.python import pywrap_tensorflow from tensorflow.python import tf2 from tensorflow.python.client import device_lib from tensorflow.python.client import session +from tensorflow.python.compat.compat import forward_compatibility_horizon from tensorflow.python.eager import context from tensorflow.python.eager import def_function from tensorflow.python.eager import tape @@ -1396,6 +1397,41 @@ def run_cuda_only(func=None): return decorator +def with_forward_compatibility_horizons(*horizons): + """Executes the decorated test with the specified forward-compat horizons. + + Args: + *horizons: A list of (year, month, day) tuples. If the list includes + `None`, then the test will also be run with no forward-compatibility + horizon set. + + Returns: + A decorator that will execute the test with the specified horizons. + """ + if not horizons: + raise ValueError("Expected at least one horizon.") + for horizon in horizons: + if not ((horizon is None) or + (len(horizon) == 3 and all(isinstance(x, int) for x in horizon))): + raise ValueError("Bad horizon value: %r" % horizon) + + def decorator(f): + if tf_inspect.isclass(f): + raise ValueError("`with_forward_compatibility_horizons` only " + "supports test methods.") + def decorated(self, *args, **kwargs): + for horizon in horizons: + if horizon is None: + f(self, *args, **kwargs) + else: + (year, month, day) = horizon + with forward_compatibility_horizon(year, month, day): + f(self, *args, **kwargs) + return decorated + + return decorator + + @tf_export("test.is_gpu_available") def is_gpu_available(cuda_only=False, min_cuda_compute_capability=None): """Returns whether TensorFlow can access a GPU. diff --git a/tensorflow/python/framework/test_util_test.py b/tensorflow/python/framework/test_util_test.py index 6657d887592..6278bb4e270 100644 --- a/tensorflow/python/framework/test_util_test.py +++ b/tensorflow/python/framework/test_util_test.py @@ -31,6 +31,7 @@ from google.protobuf import text_format from tensorflow.core.framework import graph_pb2 from tensorflow.core.protobuf import meta_graph_pb2 +from tensorflow.python.compat import compat from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -768,6 +769,23 @@ class TestUtilTest(test_util.TensorFlowTestCase, parameterized.TestCase): self.assertTrue(test_object.graph_mode_tested) self.assertTrue(test_object.inside_function_tested) + def test_with_forward_compatibility_horizons(self): + + tested_codepaths = set() + def some_function_with_forward_compat_behavior(): + if compat.forward_compatible(2050, 1, 1): + tested_codepaths.add("future") + else: + tested_codepaths.add("present") + + @test_util.with_forward_compatibility_horizons(None, [2051, 1, 1]) + def some_test(self): + del self # unused + some_function_with_forward_compat_behavior() + + some_test(None) + self.assertEqual(tested_codepaths, set(["present", "future"])) + # Its own test case to reproduce variable sharing issues which only pop up when # setUp() is overridden and super() is not called.