Move combinations.py closer to TF's test_util.py.

Customizations are split into tf.distribute.*-specific and general TF.

PiperOrigin-RevId: 257903643
This commit is contained in:
Igor Saprykin 2019-07-12 18:14:31 -07:00 committed by TensorFlower Gardener
parent ab8a7e47f9
commit 42b8511f63
7 changed files with 111 additions and 64 deletions

View File

@ -113,6 +113,7 @@ py_library(
":cudnn_rnn_ops_gen",
":errors",
":framework",
":framework_combinations",
":framework_for_generated_wrappers",
":functional_ops",
":gradient_checker",
@ -1326,6 +1327,39 @@ py_library(
deps = [":framework_test_lib"],
)
py_library(
name = "framework_combinations",
srcs = ["framework/combinations.py"],
srcs_version = "PY2AND3",
deps = [
":framework_ops",
":framework_test_combinations_lib",
":util",
"//tensorflow/python/eager:context",
],
)
py_library(
name = "framework_test_combinations_lib",
srcs = ["framework/test_combinations.py"],
srcs_version = "PY2AND3",
deps = [
":util",
"@absl_py//absl/testing:parameterized",
],
)
py_test(
name = "test_combinations_test",
srcs = ["framework/test_combinations_test.py"],
python_version = "PY3",
srcs_version = "PY2AND3",
deps = [
":framework_test_combinations_lib",
"//tensorflow/python/eager:test",
],
)
py_library(
name = "client_testlib",
srcs = ["platform/test.py"],
@ -6706,8 +6740,8 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":client_testlib",
":framework_combinations",
":tf2",
"//tensorflow/python/distribute:combinations",
],
)

View File

@ -20,7 +20,6 @@ py_library(
":single_loss_example",
":strategy_combinations",
":strategy_test_lib",
":test_combinations",
"//tensorflow/python/keras/distribute:keras_correctness_test_lib",
"//tensorflow/python/keras/distribute:keras_test_lib",
],
@ -593,41 +592,17 @@ py_library(
],
)
py_library(
name = "test_combinations",
srcs = ["test_combinations.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:util",
"@absl_py//absl/testing:parameterized",
],
)
py_library(
name = "combinations",
# TODO(isaprykin): Rename "combinations" to "tf_combinations" and
# "test_combinations" to "combinations".
srcs = ["combinations.py"],
srcs_version = "PY2AND3",
deps = [
":test_combinations",
"//tensorflow/python:framework_ops",
"//tensorflow/python:util",
"//tensorflow/python:framework_combinations",
"//tensorflow/python:framework_test_combinations_lib",
"//tensorflow/python/eager:context",
],
)
py_test(
name = "test_combinations_test",
srcs = ["test_combinations_test.py"],
python_version = "PY2",
deps = [
":test_combinations",
"//tensorflow/python/eager:test",
"//tensorflow/python/keras:backend",
],
)
py_library(
name = "strategy_combinations",
srcs = ["strategy_combinations.py"],

View File

@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""This module customizes `test_combinations` for Tensorflow.
"""This module customizes `test_combinations` for `tf.distribute.Strategy`.
Additionally it provides `generate()`, `combine()` and `times()` with Tensorflow
customizations as a default.
Additionally it provides `generate()`, `combine()` and `times()` with
`tf.distribute.Strategy` customizations as a default.
"""
from __future__ import absolute_import
@ -25,9 +25,9 @@ from __future__ import print_function
import functools
import sys
from tensorflow.python.distribute import test_combinations
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.framework import combinations as framework_combinations
from tensorflow.python.framework import test_combinations
# TODO(rchao): Rename `distribution` parameter to `strategy` or
@ -156,28 +156,6 @@ class TPUCombination(NamedTPUCombination):
] + NamedTPUCombination.parameter_modifiers(self)
class EagerGraphCombination(test_combinations.TestCombination):
"""Run the test in Graph or Eager mode. Graph is the default.
The optional `mode` parameter controls the test's execution mode. Its
accepted values are "graph" or "eager" literals.
"""
def context_managers(self, kwargs):
# TODO(isaprykin): Switch the default to eager.
mode = kwargs.pop("mode", "graph")
if mode == "eager":
return [context.eager_mode()]
elif mode == "graph":
return [ops.Graph().as_default(), context.graph_mode()]
else:
raise ValueError(
"'mode' has to be either 'eager' or 'graph' and not {}".format(mode))
def parameter_modifiers(self):
return [test_combinations.OptionalParameter("mode")]
class NamedDistribution(object):
"""Wraps a `tf.distribute.Strategy` and adds a name for test titles."""
@ -205,10 +183,11 @@ class NamedDistribution(object):
return self._name
_defaults = framework_combinations.generate.keywords["test_combinations"]
generate = functools.partial(
test_combinations.generate,
test_combinations=(EagerGraphCombination(), GPUCombination(),
TPUCombination()))
combine = test_combinations.combine
times = test_combinations.times
NamedObject = test_combinations.NamedObject
framework_combinations.generate,
test_combinations=_defaults + (GPUCombination(), TPUCombination()))
combine = framework_combinations.combine
times = framework_combinations.times
NamedObject = framework_combinations.NamedObject

View File

@ -0,0 +1,59 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""This module customizes `test_combinations` for Tensorflow.
Additionally it provides `generate()`, `combine()` and `times()` with Tensorflow
customizations as a default.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
from tensorflow.python.framework import test_combinations
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
class EagerGraphCombination(test_combinations.TestCombination):
"""Run the test in Graph or Eager mode. Graph is the default.
The optional `mode` parameter controls the test's execution mode. Its
accepted values are "graph" or "eager" literals.
"""
def context_managers(self, kwargs):
# TODO(isaprykin): Switch the default to eager.
mode = kwargs.pop("mode", "graph")
if mode == "eager":
return [context.eager_mode()]
elif mode == "graph":
return [ops.Graph().as_default(), context.graph_mode()]
else:
raise ValueError(
"'mode' has to be either 'eager' or 'graph' and not {}".format(mode))
def parameter_modifiers(self):
return [test_combinations.OptionalParameter("mode")]
generate = functools.partial(
test_combinations.generate,
test_combinations=(EagerGraphCombination(),))
combine = test_combinations.combine
times = test_combinations.times
NamedObject = test_combinations.NamedObject

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for some testing utils from strategy_test_lib."""
"""Tests generating test combinations."""
from __future__ import absolute_import
from __future__ import division
@ -22,7 +22,7 @@ from collections import OrderedDict
from absl.testing import parameterized
from tensorflow.python.distribute import test_combinations as combinations
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.eager import test

View File

@ -23,7 +23,7 @@ import os
from absl.testing import parameterized
from tensorflow.python import tf2
from tensorflow.python.distribute import combinations
from tensorflow.python.framework import combinations
from tensorflow.python.platform import test