Make test_run_in_graph_and_eager work with test combinations
PiperOrigin-RevId: 295887061 Change-Id: I83ca68a1e01ad124cc25dff071affdc8c6413b55
This commit is contained in:
parent
ae7a428bfa
commit
eebf50dd9e
@ -2531,6 +2531,7 @@ tf_py_test(
|
|||||||
deps = [
|
deps = [
|
||||||
":control_flow_ops",
|
":control_flow_ops",
|
||||||
":errors",
|
":errors",
|
||||||
|
":framework_combinations",
|
||||||
":framework_for_generated_wrappers",
|
":framework_for_generated_wrappers",
|
||||||
":framework_test_lib",
|
":framework_test_lib",
|
||||||
":platform_test",
|
":platform_test",
|
||||||
|
@ -1136,7 +1136,7 @@ def run_in_graph_and_eager_modes(func=None,
|
|||||||
run_eagerly(self, **kwargs)
|
run_eagerly(self, **kwargs)
|
||||||
ops.dismantle_graph(graph_for_eager_test)
|
ops.dismantle_graph(graph_for_eager_test)
|
||||||
|
|
||||||
return decorated
|
return tf_decorator.make_decorator(f, decorated)
|
||||||
|
|
||||||
if func is not None:
|
if func is not None:
|
||||||
return decorator(func)
|
return decorator(func)
|
||||||
|
@ -33,6 +33,7 @@ from tensorflow.core.framework import graph_pb2
|
|||||||
from tensorflow.core.protobuf import meta_graph_pb2
|
from tensorflow.core.protobuf import meta_graph_pb2
|
||||||
from tensorflow.python.compat import compat
|
from tensorflow.python.compat import compat
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
|
from tensorflow.python.framework import combinations
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
@ -742,6 +743,11 @@ class TestUtilTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|||||||
def test_run_in_graph_and_eager_works_with_parameterized_keyword(self, arg):
|
def test_run_in_graph_and_eager_works_with_parameterized_keyword(self, arg):
|
||||||
self.assertEqual(arg, True)
|
self.assertEqual(arg, True)
|
||||||
|
|
||||||
|
@combinations.generate(combinations.combine(arg=True))
|
||||||
|
@test_util.run_in_graph_and_eager_modes
|
||||||
|
def test_run_in_graph_and_eager_works_with_combinations(self, arg):
|
||||||
|
self.assertEqual(arg, True)
|
||||||
|
|
||||||
def test_build_as_function_and_v1_graph(self):
|
def test_build_as_function_and_v1_graph(self):
|
||||||
|
|
||||||
class GraphModeAndFunctionTest(parameterized.TestCase):
|
class GraphModeAndFunctionTest(parameterized.TestCase):
|
||||||
|
Loading…
Reference in New Issue
Block a user