Utility to run tests inside tf.function and eager.
Relies on being able to run the assert* test methods inside a py_func to run them inside the graph, so there's no need for self.evaluate or similar methods which create a graph/eager hybrid programming model. PiperOrigin-RevId: 224575790
This commit is contained in:
parent
71ea120a1b
commit
093f036323
@ -428,20 +428,21 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.assertEqual(self.evaluate(value), 2.0)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
@test_util.also_run_as_tf_function
|
||||
def testInitScopeTensorInitializationInFunction(self):
|
||||
|
||||
@def_function.function
|
||||
def tensor_init():
|
||||
with ops.init_scope():
|
||||
const = constant_op.constant(2.0)
|
||||
# Note: this variable bypasses tf.function's variable creation
|
||||
# requirements by bypassing variable_creator_scope by using
|
||||
# ResourceVariable instead of Variable.
|
||||
self.v = resource_variable_ops.ResourceVariable(const)
|
||||
return self.v.read_value()
|
||||
|
||||
value = tensor_init()
|
||||
if not context.executing_eagerly():
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.assertEqual(self.evaluate(value), 2.0)
|
||||
self.assertAllEqual(value, 2.0)
|
||||
|
||||
def testDefunShapeInferenceWithCapturedResourceVariable(self):
|
||||
v = resource_variable_ops.ResourceVariable([[1, 2], [3, 4]])
|
||||
|
@ -54,6 +54,7 @@ from tensorflow.python import tf2
|
||||
from tensorflow.python.client import device_lib
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import tape
|
||||
from tensorflow.python.framework import device as pydev
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -67,6 +68,7 @@ from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import versions
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import script_ops
|
||||
from tensorflow.python.ops import tensor_array_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import googletest
|
||||
@ -76,6 +78,7 @@ from tensorflow.python.util import compat
|
||||
from tensorflow.python.util import deprecation
|
||||
from tensorflow.python.util import memory
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import tf_decorator
|
||||
from tensorflow.python.util import tf_inspect
|
||||
from tensorflow.python.util.protobuf import compare
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
@ -1009,6 +1012,58 @@ def run_in_graph_and_eager_modes(func=None,
|
||||
return decorator
|
||||
|
||||
|
||||
def py_func_if_in_function(f):
|
||||
|
||||
def decorated(*args, **kwds):
|
||||
if not ops.get_default_graph()._building_function:
|
||||
return f(*args, **kwds)
|
||||
|
||||
tensor_args, tensor_indices = zip(
|
||||
*[(x, i) for i, x in enumerate(args)
|
||||
if isinstance(x, (ops.Tensor, variables.Variable))])
|
||||
|
||||
def inner_f(*inner_tensor_args):
|
||||
my_args = list(args)
|
||||
for i, n in zip(tensor_indices, inner_tensor_args):
|
||||
my_args[i] = n
|
||||
return f(*my_args, **kwds)
|
||||
|
||||
return script_ops.py_func(inner_f, tensor_args, [])
|
||||
|
||||
return tf_decorator.make_decorator(f, decorated)
|
||||
|
||||
|
||||
def also_run_as_tf_function(f):
|
||||
"""Runs the decorated test twice--once as is, once inside a tf.function.
|
||||
|
||||
This allows you to run a test both in eager execution and inside a
|
||||
tf.function, exercising the two execution modes supported in tf 2.0. The test
|
||||
assertions are automatically done inside tf.py_funcs, and tf.function ensures
|
||||
that they run in the proper order and with the proper side effects.
|
||||
|
||||
Currently variable creation is not supported in tests annotated with this
|
||||
decorator since it's tricky to ensure the variable doesn't get repeatedly
|
||||
created when retracing the tf.function.
|
||||
|
||||
Args:
|
||||
f: the test method to be decorated
|
||||
|
||||
Returns:
|
||||
The decorated test method, which will run both in eager and inside a
|
||||
tf.function.
|
||||
"""
|
||||
|
||||
def decorated(*args, **kwds):
|
||||
with context.eager_mode():
|
||||
# Running in eager mode
|
||||
f(*args, **kwds)
|
||||
|
||||
defun_f = def_function.function(f)
|
||||
defun_f(*args, **kwds)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
def run_deprecated_v1(func=None):
|
||||
"""Execute the decorated test in graph mode.
|
||||
|
||||
@ -1783,8 +1838,8 @@ class TensorFlowTestCase(googletest.TestCase):
|
||||
return ret
|
||||
|
||||
|
||||
# pylint: enable=invalid-name
|
||||
|
||||
# pylint: enable=invalid-name
|
||||
@py_func_if_in_function
|
||||
def assertNear(self, f1, f2, err, msg=None):
|
||||
"""Asserts that two floats are near each other.
|
||||
|
||||
@ -1803,6 +1858,7 @@ class TensorFlowTestCase(googletest.TestCase):
|
||||
"%f != %f +/- %f%s" % (f1, f2, err, " (%s)" % msg
|
||||
if msg is not None else ""))
|
||||
|
||||
@py_func_if_in_function
|
||||
def assertArrayNear(self, farray1, farray2, err, msg=None):
|
||||
"""Asserts that two float arrays are near each other.
|
||||
|
||||
@ -1822,6 +1878,7 @@ class TensorFlowTestCase(googletest.TestCase):
|
||||
def _NDArrayNear(self, ndarray1, ndarray2, err):
|
||||
return np.linalg.norm(ndarray1 - ndarray2) < err
|
||||
|
||||
@py_func_if_in_function
|
||||
def assertNDArrayNear(self, ndarray1, ndarray2, err, msg=None):
|
||||
"""Asserts that two numpy arrays have near values.
|
||||
|
||||
@ -1959,6 +2016,7 @@ class TensorFlowTestCase(googletest.TestCase):
|
||||
e.args = ((e.args[0] + " : " + msg,) + e.args[1:])
|
||||
raise
|
||||
|
||||
@py_func_if_in_function
|
||||
def assertAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None):
|
||||
"""Asserts that two structures of numpy arrays or Tensors, have near values.
|
||||
|
||||
@ -1984,6 +2042,7 @@ class TensorFlowTestCase(googletest.TestCase):
|
||||
"""
|
||||
self._assertAllCloseRecursive(a, b, rtol=rtol, atol=atol, msg=msg)
|
||||
|
||||
@py_func_if_in_function
|
||||
def assertAllCloseAccordingToType(self,
|
||||
a,
|
||||
b,
|
||||
@ -2031,6 +2090,7 @@ class TensorFlowTestCase(googletest.TestCase):
|
||||
|
||||
self.assertAllClose(a, b, rtol=rtol, atol=atol, msg=msg)
|
||||
|
||||
@py_func_if_in_function
|
||||
def assertNotAllClose(self, a, b, **kwargs):
|
||||
"""Assert that two numpy arrays, or or Tensors, do not have near values.
|
||||
|
||||
@ -2049,6 +2109,7 @@ class TensorFlowTestCase(googletest.TestCase):
|
||||
return
|
||||
raise AssertionError("The two values are close at all elements")
|
||||
|
||||
@py_func_if_in_function
|
||||
def assertAllEqual(self, a, b, msg=None):
|
||||
"""Asserts that two numpy arrays or Tensors have the same values.
|
||||
|
||||
@ -2091,6 +2152,7 @@ class TensorFlowTestCase(googletest.TestCase):
|
||||
msgs.append("not equal rhs = {}".format(y))
|
||||
np.testing.assert_array_equal(a, b, err_msg="\n".join(msgs))
|
||||
|
||||
@py_func_if_in_function
|
||||
def assertAllGreater(self, a, comparison_target):
|
||||
"""Assert element values are all greater than a target value.
|
||||
|
||||
@ -2102,6 +2164,7 @@ class TensorFlowTestCase(googletest.TestCase):
|
||||
a = self._GetNdArray(a)
|
||||
self.assertGreater(np.min(a), comparison_target)
|
||||
|
||||
@py_func_if_in_function
|
||||
def assertAllLess(self, a, comparison_target):
|
||||
"""Assert element values are all less than a target value.
|
||||
|
||||
@ -2113,6 +2176,7 @@ class TensorFlowTestCase(googletest.TestCase):
|
||||
a = self._GetNdArray(a)
|
||||
self.assertLess(np.max(a), comparison_target)
|
||||
|
||||
@py_func_if_in_function
|
||||
def assertAllGreaterEqual(self, a, comparison_target):
|
||||
"""Assert element values are all greater than or equal to a target value.
|
||||
|
||||
@ -2124,6 +2188,7 @@ class TensorFlowTestCase(googletest.TestCase):
|
||||
a = self._GetNdArray(a)
|
||||
self.assertGreaterEqual(np.min(a), comparison_target)
|
||||
|
||||
@py_func_if_in_function
|
||||
def assertAllLessEqual(self, a, comparison_target):
|
||||
"""Assert element values are all less than or equal to a target value.
|
||||
|
||||
@ -2166,6 +2231,7 @@ class TensorFlowTestCase(googletest.TestCase):
|
||||
lines.append(prefix + "...")
|
||||
return lines
|
||||
|
||||
@py_func_if_in_function
|
||||
def assertAllInRange(self,
|
||||
target,
|
||||
lower_bound,
|
||||
@ -2224,6 +2290,7 @@ class TensorFlowTestCase(googletest.TestCase):
|
||||
"Subscript(s) and value(s) of the offending elements:\n" +
|
||||
"\n".join(self._format_subscripts(violation_subscripts, target)))
|
||||
|
||||
@py_func_if_in_function
|
||||
def assertAllInSet(self, target, expected_set):
|
||||
"""Assert that elements of a Tensor are all in a given closed set.
|
||||
|
||||
@ -2245,6 +2312,7 @@ class TensorFlowTestCase(googletest.TestCase):
|
||||
raise AssertionError("%d unique element(s) are not in the set %s: %s" %
|
||||
(np.size(diff), expected_set, diff))
|
||||
|
||||
@py_func_if_in_function
|
||||
def assertDTypeEqual(self, target, expected_dtype):
|
||||
"""Assert ndarray data type is equal to expected.
|
||||
|
||||
|
@ -237,7 +237,8 @@ class VariableScopeTest(test.TestCase):
|
||||
_ = d2(x)
|
||||
self.assertEqual(len(d2.variables), 2)
|
||||
v3, v4 = d2.variables
|
||||
self.assertAllEqual([v1, v2], [v3, v4])
|
||||
self.assertEqual(v1, v3)
|
||||
self.assertEqual(v2, v4)
|
||||
f()
|
||||
|
||||
# TODO(mihaimaruseac): Not converted to use wrap_function because of
|
||||
@ -1684,7 +1685,7 @@ class VariableScopeWithCustomGetterTest(test.TestCase):
|
||||
with variable_scope.variable_creator_scope(creator_b):
|
||||
variable_scope.variable(1.0, name="one_name")
|
||||
|
||||
self.assertAllEqual(variable_names, ["forced_name"])
|
||||
self.assertEqual(variable_names[0], "forced_name")
|
||||
|
||||
called = [False]
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user