Move combinations.py closer to TF's test_util.py.
Customizations are split into tf.distribute.*-specific and general TF. PiperOrigin-RevId: 257903643
This commit is contained in:
parent
ab8a7e47f9
commit
42b8511f63
tensorflow/python
@ -113,6 +113,7 @@ py_library(
|
||||
":cudnn_rnn_ops_gen",
|
||||
":errors",
|
||||
":framework",
|
||||
":framework_combinations",
|
||||
":framework_for_generated_wrappers",
|
||||
":functional_ops",
|
||||
":gradient_checker",
|
||||
@ -1326,6 +1327,39 @@ py_library(
|
||||
deps = [":framework_test_lib"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "framework_combinations",
|
||||
srcs = ["framework/combinations.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":framework_ops",
|
||||
":framework_test_combinations_lib",
|
||||
":util",
|
||||
"//tensorflow/python/eager:context",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "framework_test_combinations_lib",
|
||||
srcs = ["framework/test_combinations.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":util",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_combinations_test",
|
||||
srcs = ["framework/test_combinations_test.py"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":framework_test_combinations_lib",
|
||||
"//tensorflow/python/eager:test",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "client_testlib",
|
||||
srcs = ["platform/test.py"],
|
||||
@ -6706,8 +6740,8 @@ py_test(
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":client_testlib",
|
||||
":framework_combinations",
|
||||
":tf2",
|
||||
"//tensorflow/python/distribute:combinations",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -20,7 +20,6 @@ py_library(
|
||||
":single_loss_example",
|
||||
":strategy_combinations",
|
||||
":strategy_test_lib",
|
||||
":test_combinations",
|
||||
"//tensorflow/python/keras/distribute:keras_correctness_test_lib",
|
||||
"//tensorflow/python/keras/distribute:keras_test_lib",
|
||||
],
|
||||
@ -593,41 +592,17 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "test_combinations",
|
||||
srcs = ["test_combinations.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python:util",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "combinations",
|
||||
# TODO(isaprykin): Rename "combinations" to "tf_combinations" and
|
||||
# "test_combinations" to "combinations".
|
||||
srcs = ["combinations.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":test_combinations",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python:framework_combinations",
|
||||
"//tensorflow/python:framework_test_combinations_lib",
|
||||
"//tensorflow/python/eager:context",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_combinations_test",
|
||||
srcs = ["test_combinations_test.py"],
|
||||
python_version = "PY2",
|
||||
deps = [
|
||||
":test_combinations",
|
||||
"//tensorflow/python/eager:test",
|
||||
"//tensorflow/python/keras:backend",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "strategy_combinations",
|
||||
srcs = ["strategy_combinations.py"],
|
||||
|
@ -12,10 +12,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""This module customizes `test_combinations` for Tensorflow.
|
||||
"""This module customizes `test_combinations` for `tf.distribute.Strategy`.
|
||||
|
||||
Additionally it provides `generate()`, `combine()` and `times()` with Tensorflow
|
||||
customizations as a default.
|
||||
Additionally it provides `generate()`, `combine()` and `times()` with
|
||||
`tf.distribute.Strategy` customizations as a default.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
@ -25,9 +25,9 @@ from __future__ import print_function
|
||||
import functools
|
||||
import sys
|
||||
|
||||
from tensorflow.python.distribute import test_combinations
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import combinations as framework_combinations
|
||||
from tensorflow.python.framework import test_combinations
|
||||
|
||||
|
||||
# TODO(rchao): Rename `distribution` parameter to `strategy` or
|
||||
@ -156,28 +156,6 @@ class TPUCombination(NamedTPUCombination):
|
||||
] + NamedTPUCombination.parameter_modifiers(self)
|
||||
|
||||
|
||||
class EagerGraphCombination(test_combinations.TestCombination):
|
||||
"""Run the test in Graph or Eager mode. Graph is the default.
|
||||
|
||||
The optional `mode` parameter controls the test's execution mode. Its
|
||||
accepted values are "graph" or "eager" literals.
|
||||
"""
|
||||
|
||||
def context_managers(self, kwargs):
|
||||
# TODO(isaprykin): Switch the default to eager.
|
||||
mode = kwargs.pop("mode", "graph")
|
||||
if mode == "eager":
|
||||
return [context.eager_mode()]
|
||||
elif mode == "graph":
|
||||
return [ops.Graph().as_default(), context.graph_mode()]
|
||||
else:
|
||||
raise ValueError(
|
||||
"'mode' has to be either 'eager' or 'graph' and not {}".format(mode))
|
||||
|
||||
def parameter_modifiers(self):
|
||||
return [test_combinations.OptionalParameter("mode")]
|
||||
|
||||
|
||||
class NamedDistribution(object):
|
||||
"""Wraps a `tf.distribute.Strategy` and adds a name for test titles."""
|
||||
|
||||
@ -205,10 +183,11 @@ class NamedDistribution(object):
|
||||
return self._name
|
||||
|
||||
|
||||
_defaults = framework_combinations.generate.keywords["test_combinations"]
|
||||
|
||||
generate = functools.partial(
|
||||
test_combinations.generate,
|
||||
test_combinations=(EagerGraphCombination(), GPUCombination(),
|
||||
TPUCombination()))
|
||||
combine = test_combinations.combine
|
||||
times = test_combinations.times
|
||||
NamedObject = test_combinations.NamedObject
|
||||
framework_combinations.generate,
|
||||
test_combinations=_defaults + (GPUCombination(), TPUCombination()))
|
||||
combine = framework_combinations.combine
|
||||
times = framework_combinations.times
|
||||
NamedObject = framework_combinations.NamedObject
|
||||
|
59
tensorflow/python/framework/combinations.py
Normal file
59
tensorflow/python/framework/combinations.py
Normal file
@ -0,0 +1,59 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""This module customizes `test_combinations` for Tensorflow.
|
||||
|
||||
Additionally it provides `generate()`, `combine()` and `times()` with Tensorflow
|
||||
customizations as a default.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import functools
|
||||
|
||||
from tensorflow.python.framework import test_combinations
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import ops
|
||||
|
||||
|
||||
class EagerGraphCombination(test_combinations.TestCombination):
|
||||
"""Run the test in Graph or Eager mode. Graph is the default.
|
||||
|
||||
The optional `mode` parameter controls the test's execution mode. Its
|
||||
accepted values are "graph" or "eager" literals.
|
||||
"""
|
||||
|
||||
def context_managers(self, kwargs):
|
||||
# TODO(isaprykin): Switch the default to eager.
|
||||
mode = kwargs.pop("mode", "graph")
|
||||
if mode == "eager":
|
||||
return [context.eager_mode()]
|
||||
elif mode == "graph":
|
||||
return [ops.Graph().as_default(), context.graph_mode()]
|
||||
else:
|
||||
raise ValueError(
|
||||
"'mode' has to be either 'eager' or 'graph' and not {}".format(mode))
|
||||
|
||||
def parameter_modifiers(self):
|
||||
return [test_combinations.OptionalParameter("mode")]
|
||||
|
||||
|
||||
generate = functools.partial(
|
||||
test_combinations.generate,
|
||||
test_combinations=(EagerGraphCombination(),))
|
||||
combine = test_combinations.combine
|
||||
times = test_combinations.times
|
||||
NamedObject = test_combinations.NamedObject
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for some testing utils from strategy_test_lib."""
|
||||
"""Tests generating test combinations."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
@ -22,7 +22,7 @@ from collections import OrderedDict
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.distribute import test_combinations as combinations
|
||||
from tensorflow.python.framework import test_combinations as combinations
|
||||
from tensorflow.python.eager import test
|
||||
|
||||
|
@ -23,7 +23,7 @@ import os
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.distribute import combinations
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user