Add convenient methods to write test combinations with and without tf.function

Instead of using def_function, in test we can parameterize with these two
objects to test both tf.function and eager execution.

PiperOrigin-RevId: 351420316
Change-Id: I037d1678ca843f6df88694981efd4519c2947cd3
This commit is contained in:
Ran Chen 2021-01-12 12:06:02 -08:00 committed by TensorFlower Gardener
parent 80498bb94f
commit 3b5bb5a706
3 changed files with 46 additions and 0 deletions

View File

@ -808,10 +808,13 @@ py_library(
"//tensorflow/python:framework_combinations", "//tensorflow/python:framework_combinations",
"//tensorflow/python:framework_ops", "//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_combinations_lib", "//tensorflow/python:framework_test_combinations_lib",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform", "//tensorflow/python:platform",
"//tensorflow/python:session", "//tensorflow/python:session",
"//tensorflow/python:tf_decorator", "//tensorflow/python:tf_decorator",
"//tensorflow/python/eager:context", "//tensorflow/python/eager:context",
"//tensorflow/python/eager:def_function",
"//tensorflow/python/util:tf_export",
"@six_archive//:six", "@six_archive//:six",
], ],
) )
@ -826,6 +829,7 @@ py_test(
"//tensorflow/python:client_testlib", "//tensorflow/python:client_testlib",
"//tensorflow/python:framework_combinations", "//tensorflow/python:framework_combinations",
"//tensorflow/python/distribute/cluster_resolver:tfconfig_cluster_resolver_py", "//tensorflow/python/distribute/cluster_resolver:tfconfig_cluster_resolver_py",
"//tensorflow/python/eager:context",
"@absl_py//absl/testing:parameterized", "@absl_py//absl/testing:parameterized",
], ],
) )

View File

@ -37,6 +37,7 @@ from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import multi_process_runner from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import combinations as framework_combinations from tensorflow.python.framework import combinations as framework_combinations
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import test_combinations as combinations_lib from tensorflow.python.framework import test_combinations as combinations_lib
@ -298,6 +299,24 @@ class NamedDistribution(object):
return self._name return self._name
# This is to allow adding combinations that runs a function both as a
# tf.function and eagerly.
#
# @combinations.generate(
# combinations.combine(
# tf_function = [combinations.tf_function, combinations.no_tf_function]
# )
# )
# def testXXX(tf_function):
# @tf_function
# def foo():
# tf.add(1., 1.)
#
# foo()
tf_function = combinations_lib.NamedObject("TfFunction", def_function.function)
no_tf_function = combinations_lib.NamedObject("NoTfFunction", lambda f: f)
def concat(*combined): def concat(*combined):
"""Concats combinations.""" """Concats combinations."""
result = [] result = []

View File

@ -27,6 +27,7 @@ from absl.testing import parameterized
from tensorflow.python.distribute import combinations from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import test_util from tensorflow.python.distribute import test_util
from tensorflow.python.distribute.cluster_resolver import tfconfig_cluster_resolver from tensorflow.python.distribute.cluster_resolver import tfconfig_cluster_resolver
from tensorflow.python.eager import context
from tensorflow.python.framework import combinations as framework_combinations from tensorflow.python.framework import combinations as framework_combinations
from tensorflow.python.platform import test from tensorflow.python.platform import test
@ -174,5 +175,27 @@ class CombinationsOnClassMultiWorkerExpectedFailureTest(test.TestCase,
self.assertIsNone(resolver.task_id) self.assertIsNone(resolver.task_id)
class TfFunctionTest(test.TestCase, parameterized.TestCase):
@combinations.generate(
combinations.combine(
tf_function_1=combinations.tf_function,
tf_function_2=combinations.no_tf_function,
mode="eager",
))
def testFunc(self, tf_function_1, tf_function_2):
@tf_function_1
def foo():
self.assertFalse(context.executing_eagerly())
@tf_function_2
def bar():
self.assertTrue(context.executing_eagerly())
foo()
bar()
if __name__ == "__main__": if __name__ == "__main__":
test_util.main() test_util.main()