Move combinations.py closer to TF's test_util.py.

Customizations are split into tf.distribute.*-specific and general TF.

PiperOrigin-RevId: 257903643
This commit is contained in:
Igor Saprykin 2019-07-12 18:14:31 -07:00 committed by TensorFlower Gardener
parent ab8a7e47f9
commit 42b8511f63
7 changed files with 111 additions and 64 deletions

View File

@ -113,6 +113,7 @@ py_library(
":cudnn_rnn_ops_gen", ":cudnn_rnn_ops_gen",
":errors", ":errors",
":framework", ":framework",
":framework_combinations",
":framework_for_generated_wrappers", ":framework_for_generated_wrappers",
":functional_ops", ":functional_ops",
":gradient_checker", ":gradient_checker",
@ -1326,6 +1327,39 @@ py_library(
deps = [":framework_test_lib"], deps = [":framework_test_lib"],
) )
py_library(
name = "framework_combinations",
srcs = ["framework/combinations.py"],
srcs_version = "PY2AND3",
deps = [
":framework_ops",
":framework_test_combinations_lib",
":util",
"//tensorflow/python/eager:context",
],
)
py_library(
name = "framework_test_combinations_lib",
srcs = ["framework/test_combinations.py"],
srcs_version = "PY2AND3",
deps = [
":util",
"@absl_py//absl/testing:parameterized",
],
)
py_test(
name = "test_combinations_test",
srcs = ["framework/test_combinations_test.py"],
python_version = "PY3",
srcs_version = "PY2AND3",
deps = [
":framework_test_combinations_lib",
"//tensorflow/python/eager:test",
],
)
py_library( py_library(
name = "client_testlib", name = "client_testlib",
srcs = ["platform/test.py"], srcs = ["platform/test.py"],
@ -6706,8 +6740,8 @@ py_test(
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
deps = [ deps = [
":client_testlib", ":client_testlib",
":framework_combinations",
":tf2", ":tf2",
"//tensorflow/python/distribute:combinations",
], ],
) )

View File

@ -20,7 +20,6 @@ py_library(
":single_loss_example", ":single_loss_example",
":strategy_combinations", ":strategy_combinations",
":strategy_test_lib", ":strategy_test_lib",
":test_combinations",
"//tensorflow/python/keras/distribute:keras_correctness_test_lib", "//tensorflow/python/keras/distribute:keras_correctness_test_lib",
"//tensorflow/python/keras/distribute:keras_test_lib", "//tensorflow/python/keras/distribute:keras_test_lib",
], ],
@ -593,41 +592,17 @@ py_library(
], ],
) )
py_library(
name = "test_combinations",
srcs = ["test_combinations.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:util",
"@absl_py//absl/testing:parameterized",
],
)
py_library( py_library(
name = "combinations", name = "combinations",
# TODO(isaprykin): Rename "combinations" to "tf_combinations" and
# "test_combinations" to "combinations".
srcs = ["combinations.py"], srcs = ["combinations.py"],
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
deps = [ deps = [
":test_combinations", "//tensorflow/python:framework_combinations",
"//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_combinations_lib",
"//tensorflow/python:util",
"//tensorflow/python/eager:context", "//tensorflow/python/eager:context",
], ],
) )
py_test(
name = "test_combinations_test",
srcs = ["test_combinations_test.py"],
python_version = "PY2",
deps = [
":test_combinations",
"//tensorflow/python/eager:test",
"//tensorflow/python/keras:backend",
],
)
py_library( py_library(
name = "strategy_combinations", name = "strategy_combinations",
srcs = ["strategy_combinations.py"], srcs = ["strategy_combinations.py"],

View File

@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""This module customizes `test_combinations` for Tensorflow. """This module customizes `test_combinations` for `tf.distribute.Strategy`.
Additionally it provides `generate()`, `combine()` and `times()` with Tensorflow Additionally it provides `generate()`, `combine()` and `times()` with
customizations as a default. `tf.distribute.Strategy` customizations as a default.
""" """
from __future__ import absolute_import from __future__ import absolute_import
@ -25,9 +25,9 @@ from __future__ import print_function
import functools import functools
import sys import sys
from tensorflow.python.distribute import test_combinations
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.framework import ops from tensorflow.python.framework import combinations as framework_combinations
from tensorflow.python.framework import test_combinations
# TODO(rchao): Rename `distribution` parameter to `strategy` or # TODO(rchao): Rename `distribution` parameter to `strategy` or
@ -156,28 +156,6 @@ class TPUCombination(NamedTPUCombination):
] + NamedTPUCombination.parameter_modifiers(self) ] + NamedTPUCombination.parameter_modifiers(self)
class EagerGraphCombination(test_combinations.TestCombination):
"""Run the test in Graph or Eager mode. Graph is the default.
The optional `mode` parameter controls the test's execution mode. Its
accepted values are "graph" or "eager" literals.
"""
def context_managers(self, kwargs):
# TODO(isaprykin): Switch the default to eager.
mode = kwargs.pop("mode", "graph")
if mode == "eager":
return [context.eager_mode()]
elif mode == "graph":
return [ops.Graph().as_default(), context.graph_mode()]
else:
raise ValueError(
"'mode' has to be either 'eager' or 'graph' and not {}".format(mode))
def parameter_modifiers(self):
return [test_combinations.OptionalParameter("mode")]
class NamedDistribution(object): class NamedDistribution(object):
"""Wraps a `tf.distribute.Strategy` and adds a name for test titles.""" """Wraps a `tf.distribute.Strategy` and adds a name for test titles."""
@ -205,10 +183,11 @@ class NamedDistribution(object):
return self._name return self._name
_defaults = framework_combinations.generate.keywords["test_combinations"]
generate = functools.partial( generate = functools.partial(
test_combinations.generate, framework_combinations.generate,
test_combinations=(EagerGraphCombination(), GPUCombination(), test_combinations=_defaults + (GPUCombination(), TPUCombination()))
TPUCombination())) combine = framework_combinations.combine
combine = test_combinations.combine times = framework_combinations.times
times = test_combinations.times NamedObject = framework_combinations.NamedObject
NamedObject = test_combinations.NamedObject

View File

@ -0,0 +1,59 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""This module customizes `test_combinations` for Tensorflow.
Additionally it provides `generate()`, `combine()` and `times()` with Tensorflow
customizations as a default.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
from tensorflow.python.framework import test_combinations
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
class EagerGraphCombination(test_combinations.TestCombination):
"""Run the test in Graph or Eager mode. Graph is the default.
The optional `mode` parameter controls the test's execution mode. Its
accepted values are "graph" or "eager" literals.
"""
def context_managers(self, kwargs):
# TODO(isaprykin): Switch the default to eager.
mode = kwargs.pop("mode", "graph")
if mode == "eager":
return [context.eager_mode()]
elif mode == "graph":
return [ops.Graph().as_default(), context.graph_mode()]
else:
raise ValueError(
"'mode' has to be either 'eager' or 'graph' and not {}".format(mode))
def parameter_modifiers(self):
return [test_combinations.OptionalParameter("mode")]
generate = functools.partial(
test_combinations.generate,
test_combinations=(EagerGraphCombination(),))
combine = test_combinations.combine
times = test_combinations.times
NamedObject = test_combinations.NamedObject

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Tests for some testing utils from strategy_test_lib.""" """Tests generating test combinations."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
@ -22,7 +22,7 @@ from collections import OrderedDict
from absl.testing import parameterized from absl.testing import parameterized
from tensorflow.python.distribute import test_combinations as combinations from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.eager import test from tensorflow.python.eager import test

View File

@ -23,7 +23,7 @@ import os
from absl.testing import parameterized from absl.testing import parameterized
from tensorflow.python import tf2 from tensorflow.python import tf2
from tensorflow.python.distribute import combinations from tensorflow.python.framework import combinations
from tensorflow.python.platform import test from tensorflow.python.platform import test