Remove default context for EagerGraphCombination.
It will incorrectly generate an extra graph context in following example: @combinations.generate(combinations.combine(mode=['eager'])) class TestClass(test.TestCase, parameterized.TestCase): @combinations.generate(combinations.combine(foo=[1, 2])) def test_foo(self, foo): self.assertTrue(context.executing_eagerly()) Note that the combinations.generate for test_foo will silently add a graph context to the test case, since the default value for EagerGraphCombination is graph. This will override the eager context from the test class, and cause the assertion to fail. Similar error will also raise for following case: @combinations.generate(combinations.combine(foo=[1, 2])) class TestClass(test.TestCase, parameterized.TestCase): def test_foo(self, foo): self.assertTrue(context.executing_eagerly()) if __name__ == '__main__': ops.enable_eager_execution() test.main() Note that ops.enable_eager_execution() should force all the test case to run under eager context, but the silently added graph context is overriding the context value. PiperOrigin-RevId: 300409634 Change-Id: If379db35a980193a49b5d74910e67be13e9af30c
This commit is contained in:
parent
37e7693c78
commit
15845c945e
@ -31,16 +31,17 @@ from tensorflow.python.framework import test_combinations
|
||||
|
||||
|
||||
class EagerGraphCombination(test_combinations.TestCombination):
|
||||
"""Run the test in Graph or Eager mode. Graph is the default.
|
||||
"""Run the test in Graph or Eager mode.
|
||||
|
||||
The optional `mode` parameter controls the test's execution mode. Its
|
||||
accepted values are "graph" or "eager" literals.
|
||||
"""
|
||||
|
||||
def context_managers(self, kwargs):
|
||||
# TODO(isaprykin): Switch the default to eager.
|
||||
mode = kwargs.pop("mode", "graph")
|
||||
if mode == "eager":
|
||||
mode = kwargs.pop("mode", None)
|
||||
if mode is None:
|
||||
return []
|
||||
elif mode == "eager":
|
||||
return [context.eager_mode()]
|
||||
elif mode == "graph":
|
||||
return [ops.Graph().as_default(), context.graph_mode()]
|
||||
|
Loading…
Reference in New Issue
Block a user