Add convenient methods to write test combinations with and without tf.function
Instead of using def_function, in test we can parameterize with these two objects to test both tf.function and eager execution. PiperOrigin-RevId: 351420316 Change-Id: I037d1678ca843f6df88694981efd4519c2947cd3
This commit is contained in:
parent
80498bb94f
commit
3b5bb5a706
@ -808,10 +808,13 @@ py_library(
|
|||||||
"//tensorflow/python:framework_combinations",
|
"//tensorflow/python:framework_combinations",
|
||||||
"//tensorflow/python:framework_ops",
|
"//tensorflow/python:framework_ops",
|
||||||
"//tensorflow/python:framework_test_combinations_lib",
|
"//tensorflow/python:framework_test_combinations_lib",
|
||||||
|
"//tensorflow/python:framework_test_lib",
|
||||||
"//tensorflow/python:platform",
|
"//tensorflow/python:platform",
|
||||||
"//tensorflow/python:session",
|
"//tensorflow/python:session",
|
||||||
"//tensorflow/python:tf_decorator",
|
"//tensorflow/python:tf_decorator",
|
||||||
"//tensorflow/python/eager:context",
|
"//tensorflow/python/eager:context",
|
||||||
|
"//tensorflow/python/eager:def_function",
|
||||||
|
"//tensorflow/python/util:tf_export",
|
||||||
"@six_archive//:six",
|
"@six_archive//:six",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -826,6 +829,7 @@ py_test(
|
|||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
"//tensorflow/python:framework_combinations",
|
"//tensorflow/python:framework_combinations",
|
||||||
"//tensorflow/python/distribute/cluster_resolver:tfconfig_cluster_resolver_py",
|
"//tensorflow/python/distribute/cluster_resolver:tfconfig_cluster_resolver_py",
|
||||||
|
"//tensorflow/python/eager:context",
|
||||||
"@absl_py//absl/testing:parameterized",
|
"@absl_py//absl/testing:parameterized",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -37,6 +37,7 @@ from tensorflow.python.distribute import distribute_lib
|
|||||||
from tensorflow.python.distribute import multi_process_runner
|
from tensorflow.python.distribute import multi_process_runner
|
||||||
from tensorflow.python.distribute import multi_worker_test_base
|
from tensorflow.python.distribute import multi_worker_test_base
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
|
from tensorflow.python.eager import def_function
|
||||||
from tensorflow.python.framework import combinations as framework_combinations
|
from tensorflow.python.framework import combinations as framework_combinations
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import test_combinations as combinations_lib
|
from tensorflow.python.framework import test_combinations as combinations_lib
|
||||||
@ -298,6 +299,24 @@ class NamedDistribution(object):
|
|||||||
return self._name
|
return self._name
|
||||||
|
|
||||||
|
|
||||||
|
# This is to allow adding combinations that runs a function both as a
|
||||||
|
# tf.function and eagerly.
|
||||||
|
#
|
||||||
|
# @combinations.generate(
|
||||||
|
# combinations.combine(
|
||||||
|
# tf_function = [combinations.tf_function, combinations.no_tf_function]
|
||||||
|
# )
|
||||||
|
# )
|
||||||
|
# def testXXX(tf_function):
|
||||||
|
# @tf_function
|
||||||
|
# def foo():
|
||||||
|
# tf.add(1., 1.)
|
||||||
|
#
|
||||||
|
# foo()
|
||||||
|
tf_function = combinations_lib.NamedObject("TfFunction", def_function.function)
|
||||||
|
no_tf_function = combinations_lib.NamedObject("NoTfFunction", lambda f: f)
|
||||||
|
|
||||||
|
|
||||||
def concat(*combined):
|
def concat(*combined):
|
||||||
"""Concats combinations."""
|
"""Concats combinations."""
|
||||||
result = []
|
result = []
|
||||||
|
@ -27,6 +27,7 @@ from absl.testing import parameterized
|
|||||||
from tensorflow.python.distribute import combinations
|
from tensorflow.python.distribute import combinations
|
||||||
from tensorflow.python.distribute import test_util
|
from tensorflow.python.distribute import test_util
|
||||||
from tensorflow.python.distribute.cluster_resolver import tfconfig_cluster_resolver
|
from tensorflow.python.distribute.cluster_resolver import tfconfig_cluster_resolver
|
||||||
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.framework import combinations as framework_combinations
|
from tensorflow.python.framework import combinations as framework_combinations
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
@ -174,5 +175,27 @@ class CombinationsOnClassMultiWorkerExpectedFailureTest(test.TestCase,
|
|||||||
self.assertIsNone(resolver.task_id)
|
self.assertIsNone(resolver.task_id)
|
||||||
|
|
||||||
|
|
||||||
|
class TfFunctionTest(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
|
@combinations.generate(
|
||||||
|
combinations.combine(
|
||||||
|
tf_function_1=combinations.tf_function,
|
||||||
|
tf_function_2=combinations.no_tf_function,
|
||||||
|
mode="eager",
|
||||||
|
))
|
||||||
|
def testFunc(self, tf_function_1, tf_function_2):
|
||||||
|
|
||||||
|
@tf_function_1
|
||||||
|
def foo():
|
||||||
|
self.assertFalse(context.executing_eagerly())
|
||||||
|
|
||||||
|
@tf_function_2
|
||||||
|
def bar():
|
||||||
|
self.assertTrue(context.executing_eagerly())
|
||||||
|
|
||||||
|
foo()
|
||||||
|
bar()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_util.main()
|
test_util.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user