Remove default context for EagerGraphCombination.

It will incorrectly generate an extra graph context in following example:

@combinations.generate(combinations.combine(mode=['eager']))
class TestClass(test.TestCase, parameterized.TestCase):

  @combinations.generate(combinations.combine(foo=[1, 2]))
  def test_foo(self, foo):
    self.assertTrue(context.executing_eagerly())

Note that the combinations.generate for test_foo will silently add a graph context to the test case, since the default value for EagerGraphCombination is graph. This will override the eager context from the test class, and cause the assertion to fail.

Similar error will also raise for following case:

@combinations.generate(combinations.combine(foo=[1, 2]))
class TestClass(test.TestCase, parameterized.TestCase):

  def test_foo(self, foo):
    self.assertTrue(context.executing_eagerly())

if __name__ == '__main__':
  ops.enable_eager_execution()
  test.main()

Note that ops.enable_eager_execution() should force all the test case to run under eager context, but the silently added graph context is overriding the context value.

PiperOrigin-RevId: 300409634
Change-Id: If379db35a980193a49b5d74910e67be13e9af30c
This commit is contained in:
Scott Zhu 2020-03-11 14:39:12 -07:00 committed by TensorFlower Gardener
parent 37e7693c78
commit 15845c945e

View File

@ -31,16 +31,17 @@ from tensorflow.python.framework import test_combinations
class EagerGraphCombination(test_combinations.TestCombination):
"""Run the test in Graph or Eager mode. Graph is the default.
"""Run the test in Graph or Eager mode.
The optional `mode` parameter controls the test's execution mode. Its
accepted values are "graph" or "eager" literals.
"""
def context_managers(self, kwargs):
# TODO(isaprykin): Switch the default to eager.
mode = kwargs.pop("mode", "graph")
if mode == "eager":
mode = kwargs.pop("mode", None)
if mode is None:
return []
elif mode == "eager":
return [context.eager_mode()]
elif mode == "graph":
return [ops.Graph().as_default(), context.graph_mode()]