`combinations.generate` would incorrectly generate an extra graph context in the following example:

  @combinations.generate(combinations.combine(mode=['eager']))
  class TestClass(test.TestCase, parameterized.TestCase):

    @combinations.generate(combinations.combine(foo=[1, 2]))
    def test_foo(self, foo):
      self.assertTrue(context.executing_eagerly())

Note that the combinations.generate for test_foo will silently add a graph context to the test case, since the default value for EagerGraphCombination is graph. This overrides the eager context from the test class and causes the assertion to fail.

A similar error is also raised in the following case:

  @combinations.generate(combinations.combine(foo=[1, 2]))
  class TestClass(test.TestCase, parameterized.TestCase):

    def test_foo(self, foo):
      self.assertTrue(context.executing_eagerly())

  if __name__ == '__main__':
    ops.enable_eager_execution()
    test.main()

Note that ops.enable_eager_execution() should force all test cases to run under an eager context, but the silently added graph context overrides that context.

PiperOrigin-RevId: 300409634
Change-Id: If379db35a980193a49b5d74910e67be13e9af30c
84 lines
2.9 KiB
Python
84 lines
2.9 KiB
Python
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""This module customizes `test_combinations` for TensorFlow.

Additionally, it provides `generate()`, `combine()` and `times()` with
TensorFlow customizations as defaults.
"""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import functools
|
|
|
|
from tensorflow.python import tf2
|
|
from tensorflow.python.eager import context
|
|
from tensorflow.python.framework import ops
|
|
from tensorflow.python.framework import test_combinations
|
|
|
|
|
|
class EagerGraphCombination(test_combinations.TestCombination):
  """Runs the test in graph mode or eager mode, as requested.

  Tests opt in through the optional `mode` parameter, whose value must be the
  literal "graph" or "eager". When `mode` is not supplied, no context manager
  is added, so the test keeps whatever execution context is already active
  (for example one established by an enclosing `generate` decorator or by
  `ops.enable_eager_execution()`).
  """

  def context_managers(self, kwargs):
    """Returns the context managers that realize the requested `mode`.

    Args:
      kwargs: The test's keyword parameters. The "mode" entry, if present, is
        removed from this dict.

    Returns:
      An empty list when no mode was requested; otherwise the list of context
      managers to enter for eager or graph execution.

    Raises:
      ValueError: If `mode` is given but is neither "eager" nor "graph".
    """
    requested_mode = kwargs.pop("mode", None)
    if requested_mode is None:
      # No explicit mode: do not override the surrounding execution context.
      return []
    if requested_mode == "eager":
      return [context.eager_mode()]
    if requested_mode == "graph":
      # Enter a brand-new default graph together with graph mode.
      return [ops.Graph().as_default(), context.graph_mode()]
    raise ValueError(
        "'mode' has to be either 'eager' or 'graph' and not {}".format(
            requested_mode))

  def parameter_modifiers(self):
    """Declares `mode` as an optional parameter of this combination."""
    return [test_combinations.OptionalParameter("mode")]
|
|
|
|
|
|
class TFVersionCombination(test_combinations.TestCombination):
  """Restricts a test to TF1.x or TF2 via the optional `tf_api_version`.

  A test marked with tf_api_version=1 is skipped when TF2 is enabled, and one
  marked with tf_api_version=2 is skipped when TF2 is not enabled. Tests that
  do not specify a version are never skipped by this combination.

  Test targets continuously run in TF2 thanks to the tensorflow.v2 TAP target.
  A test can be run in TF2 with bazel by passing --test_env=TF2_BEHAVIOR=1.
  """

  def should_execute_combination(self, kwargs):
    """Decides whether the combination runs under the active TF version.

    Args:
      kwargs: The test's keyword parameters. The "tf_api_version" entry, if
        present, is removed from this dict.

    Returns:
      A `(should_run, reason)` tuple; `reason` is None when the test runs.
    """
    wanted_version = kwargs.pop("tf_api_version", None)
    if wanted_version == 1:
      if tf2.enabled():
        return (False, "Skipping a TF1.x test when TF2 is enabled.")
    elif wanted_version == 2:
      if not tf2.enabled():
        return (False, "Skipping a TF2 test when TF2 is not enabled.")
    # Unspecified (or any other) versions always run.
    return (True, None)

  def parameter_modifiers(self):
    """Declares `tf_api_version` as an optional parameter of this combination."""
    return [test_combinations.OptionalParameter("tf_api_version")]
|
|
|
|
|
|
# `generate` pre-binds the TensorFlow-specific combinations defined above, so
# decorated tests may use the optional `mode` and `tf_api_version` parameters.
generate = functools.partial(
    test_combinations.generate,
    test_combinations=(
        EagerGraphCombination(),
        TFVersionCombination(),
    ),
)

# The remaining helpers are re-exported from `test_combinations` unchanged.
combine, times, NamedObject = (
    test_combinations.combine,
    test_combinations.times,
    test_combinations.NamedObject,
)
|