Add test_util decorator that makes it easier to check forward compatibility horizons.
PiperOrigin-RevId: 267631954
This commit is contained in:
parent
4bed890564
commit
d8551244b8
@ -54,6 +54,7 @@ from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.client import device_lib
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.compat.compat import forward_compatibility_horizon
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import tape
|
||||
@ -1396,6 +1397,41 @@ def run_cuda_only(func=None):
|
||||
return decorator
|
||||
|
||||
|
||||
def with_forward_compatibility_horizons(*horizons):
|
||||
"""Executes the decorated test with the specified forward-compat horizons.
|
||||
|
||||
Args:
|
||||
*horizons: A list of (year, month, day) tuples. If the list includes
|
||||
`None`, then the test will also be run with no forward-compatibility
|
||||
horizon set.
|
||||
|
||||
Returns:
|
||||
A decorator that will execute the test with the specified horizons.
|
||||
"""
|
||||
if not horizons:
|
||||
raise ValueError("Expected at least one horizon.")
|
||||
for horizon in horizons:
|
||||
if not ((horizon is None) or
|
||||
(len(horizon) == 3 and all(isinstance(x, int) for x in horizon))):
|
||||
raise ValueError("Bad horizon value: %r" % horizon)
|
||||
|
||||
def decorator(f):
|
||||
if tf_inspect.isclass(f):
|
||||
raise ValueError("`with_forward_compatibility_horizons` only "
|
||||
"supports test methods.")
|
||||
def decorated(self, *args, **kwargs):
|
||||
for horizon in horizons:
|
||||
if horizon is None:
|
||||
f(self, *args, **kwargs)
|
||||
else:
|
||||
(year, month, day) = horizon
|
||||
with forward_compatibility_horizon(year, month, day):
|
||||
f(self, *args, **kwargs)
|
||||
return decorated
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
@tf_export("test.is_gpu_available")
|
||||
def is_gpu_available(cuda_only=False, min_cuda_compute_capability=None):
|
||||
"""Returns whether TensorFlow can access a GPU.
|
||||
|
@ -31,6 +31,7 @@ from google.protobuf import text_format
|
||||
|
||||
from tensorflow.core.framework import graph_pb2
|
||||
from tensorflow.core.protobuf import meta_graph_pb2
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -768,6 +769,23 @@ class TestUtilTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
self.assertTrue(test_object.graph_mode_tested)
|
||||
self.assertTrue(test_object.inside_function_tested)
|
||||
|
||||
def test_with_forward_compatibility_horizons(self):
|
||||
|
||||
tested_codepaths = set()
|
||||
def some_function_with_forward_compat_behavior():
|
||||
if compat.forward_compatible(2050, 1, 1):
|
||||
tested_codepaths.add("future")
|
||||
else:
|
||||
tested_codepaths.add("present")
|
||||
|
||||
@test_util.with_forward_compatibility_horizons(None, [2051, 1, 1])
|
||||
def some_test(self):
|
||||
del self # unused
|
||||
some_function_with_forward_compat_behavior()
|
||||
|
||||
some_test(None)
|
||||
self.assertEqual(tested_codepaths, set(["present", "future"]))
|
||||
|
||||
|
||||
# Its own test case to reproduce variable sharing issues which only pop up when
|
||||
# setUp() is overridden and super() is not called.
|
||||
|
Loading…
Reference in New Issue
Block a user