Move combinations.py closer to TF's test_util.py.
Customizations are split into tf.distribute.*-specific and general TF. PiperOrigin-RevId: 257903643
This commit is contained in:
parent
ab8a7e47f9
commit
42b8511f63
@ -113,6 +113,7 @@ py_library(
|
|||||||
":cudnn_rnn_ops_gen",
|
":cudnn_rnn_ops_gen",
|
||||||
":errors",
|
":errors",
|
||||||
":framework",
|
":framework",
|
||||||
|
":framework_combinations",
|
||||||
":framework_for_generated_wrappers",
|
":framework_for_generated_wrappers",
|
||||||
":functional_ops",
|
":functional_ops",
|
||||||
":gradient_checker",
|
":gradient_checker",
|
||||||
@ -1326,6 +1327,39 @@ py_library(
|
|||||||
deps = [":framework_test_lib"],
|
deps = [":framework_test_lib"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "framework_combinations",
|
||||||
|
srcs = ["framework/combinations.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
":framework_ops",
|
||||||
|
":framework_test_combinations_lib",
|
||||||
|
":util",
|
||||||
|
"//tensorflow/python/eager:context",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "framework_test_combinations_lib",
|
||||||
|
srcs = ["framework/test_combinations.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
":util",
|
||||||
|
"@absl_py//absl/testing:parameterized",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "test_combinations_test",
|
||||||
|
srcs = ["framework/test_combinations_test.py"],
|
||||||
|
python_version = "PY3",
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
":framework_test_combinations_lib",
|
||||||
|
"//tensorflow/python/eager:test",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "client_testlib",
|
name = "client_testlib",
|
||||||
srcs = ["platform/test.py"],
|
srcs = ["platform/test.py"],
|
||||||
@ -6706,8 +6740,8 @@ py_test(
|
|||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
":client_testlib",
|
":client_testlib",
|
||||||
|
":framework_combinations",
|
||||||
":tf2",
|
":tf2",
|
||||||
"//tensorflow/python/distribute:combinations",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -20,7 +20,6 @@ py_library(
|
|||||||
":single_loss_example",
|
":single_loss_example",
|
||||||
":strategy_combinations",
|
":strategy_combinations",
|
||||||
":strategy_test_lib",
|
":strategy_test_lib",
|
||||||
":test_combinations",
|
|
||||||
"//tensorflow/python/keras/distribute:keras_correctness_test_lib",
|
"//tensorflow/python/keras/distribute:keras_correctness_test_lib",
|
||||||
"//tensorflow/python/keras/distribute:keras_test_lib",
|
"//tensorflow/python/keras/distribute:keras_test_lib",
|
||||||
],
|
],
|
||||||
@ -593,41 +592,17 @@ py_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
py_library(
|
|
||||||
name = "test_combinations",
|
|
||||||
srcs = ["test_combinations.py"],
|
|
||||||
srcs_version = "PY2AND3",
|
|
||||||
deps = [
|
|
||||||
"//tensorflow/python:util",
|
|
||||||
"@absl_py//absl/testing:parameterized",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "combinations",
|
name = "combinations",
|
||||||
# TODO(isaprykin): Rename "combinations" to "tf_combinations" and
|
|
||||||
# "test_combinations" to "combinations".
|
|
||||||
srcs = ["combinations.py"],
|
srcs = ["combinations.py"],
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
":test_combinations",
|
"//tensorflow/python:framework_combinations",
|
||||||
"//tensorflow/python:framework_ops",
|
"//tensorflow/python:framework_test_combinations_lib",
|
||||||
"//tensorflow/python:util",
|
|
||||||
"//tensorflow/python/eager:context",
|
"//tensorflow/python/eager:context",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
py_test(
|
|
||||||
name = "test_combinations_test",
|
|
||||||
srcs = ["test_combinations_test.py"],
|
|
||||||
python_version = "PY2",
|
|
||||||
deps = [
|
|
||||||
":test_combinations",
|
|
||||||
"//tensorflow/python/eager:test",
|
|
||||||
"//tensorflow/python/keras:backend",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "strategy_combinations",
|
name = "strategy_combinations",
|
||||||
srcs = ["strategy_combinations.py"],
|
srcs = ["strategy_combinations.py"],
|
||||||
|
@ -12,10 +12,10 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""This module customizes `test_combinations` for Tensorflow.
|
"""This module customizes `test_combinations` for `tf.distribute.Strategy`.
|
||||||
|
|
||||||
Additionally it provides `generate()`, `combine()` and `times()` with Tensorflow
|
Additionally it provides `generate()`, `combine()` and `times()` with
|
||||||
customizations as a default.
|
`tf.distribute.Strategy` customizations as a default.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
@ -25,9 +25,9 @@ from __future__ import print_function
|
|||||||
import functools
|
import functools
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from tensorflow.python.distribute import test_combinations
|
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import combinations as framework_combinations
|
||||||
|
from tensorflow.python.framework import test_combinations
|
||||||
|
|
||||||
|
|
||||||
# TODO(rchao): Rename `distribution` parameter to `strategy` or
|
# TODO(rchao): Rename `distribution` parameter to `strategy` or
|
||||||
@ -156,28 +156,6 @@ class TPUCombination(NamedTPUCombination):
|
|||||||
] + NamedTPUCombination.parameter_modifiers(self)
|
] + NamedTPUCombination.parameter_modifiers(self)
|
||||||
|
|
||||||
|
|
||||||
class EagerGraphCombination(test_combinations.TestCombination):
|
|
||||||
"""Run the test in Graph or Eager mode. Graph is the default.
|
|
||||||
|
|
||||||
The optional `mode` parameter controls the test's execution mode. Its
|
|
||||||
accepted values are "graph" or "eager" literals.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def context_managers(self, kwargs):
|
|
||||||
# TODO(isaprykin): Switch the default to eager.
|
|
||||||
mode = kwargs.pop("mode", "graph")
|
|
||||||
if mode == "eager":
|
|
||||||
return [context.eager_mode()]
|
|
||||||
elif mode == "graph":
|
|
||||||
return [ops.Graph().as_default(), context.graph_mode()]
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
"'mode' has to be either 'eager' or 'graph' and not {}".format(mode))
|
|
||||||
|
|
||||||
def parameter_modifiers(self):
|
|
||||||
return [test_combinations.OptionalParameter("mode")]
|
|
||||||
|
|
||||||
|
|
||||||
class NamedDistribution(object):
|
class NamedDistribution(object):
|
||||||
"""Wraps a `tf.distribute.Strategy` and adds a name for test titles."""
|
"""Wraps a `tf.distribute.Strategy` and adds a name for test titles."""
|
||||||
|
|
||||||
@ -205,10 +183,11 @@ class NamedDistribution(object):
|
|||||||
return self._name
|
return self._name
|
||||||
|
|
||||||
|
|
||||||
|
_defaults = framework_combinations.generate.keywords["test_combinations"]
|
||||||
|
|
||||||
generate = functools.partial(
|
generate = functools.partial(
|
||||||
test_combinations.generate,
|
framework_combinations.generate,
|
||||||
test_combinations=(EagerGraphCombination(), GPUCombination(),
|
test_combinations=_defaults + (GPUCombination(), TPUCombination()))
|
||||||
TPUCombination()))
|
combine = framework_combinations.combine
|
||||||
combine = test_combinations.combine
|
times = framework_combinations.times
|
||||||
times = test_combinations.times
|
NamedObject = framework_combinations.NamedObject
|
||||||
NamedObject = test_combinations.NamedObject
|
|
||||||
|
59
tensorflow/python/framework/combinations.py
Normal file
59
tensorflow/python/framework/combinations.py
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""This module customizes `test_combinations` for Tensorflow.
|
||||||
|
|
||||||
|
Additionally it provides `generate()`, `combine()` and `times()` with Tensorflow
|
||||||
|
customizations as a default.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import functools
|
||||||
|
|
||||||
|
from tensorflow.python.framework import test_combinations
|
||||||
|
from tensorflow.python.eager import context
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
|
|
||||||
|
|
||||||
|
class EagerGraphCombination(test_combinations.TestCombination):
|
||||||
|
"""Run the test in Graph or Eager mode. Graph is the default.
|
||||||
|
|
||||||
|
The optional `mode` parameter controls the test's execution mode. Its
|
||||||
|
accepted values are "graph" or "eager" literals.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def context_managers(self, kwargs):
|
||||||
|
# TODO(isaprykin): Switch the default to eager.
|
||||||
|
mode = kwargs.pop("mode", "graph")
|
||||||
|
if mode == "eager":
|
||||||
|
return [context.eager_mode()]
|
||||||
|
elif mode == "graph":
|
||||||
|
return [ops.Graph().as_default(), context.graph_mode()]
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"'mode' has to be either 'eager' or 'graph' and not {}".format(mode))
|
||||||
|
|
||||||
|
def parameter_modifiers(self):
|
||||||
|
return [test_combinations.OptionalParameter("mode")]
|
||||||
|
|
||||||
|
|
||||||
|
generate = functools.partial(
|
||||||
|
test_combinations.generate,
|
||||||
|
test_combinations=(EagerGraphCombination(),))
|
||||||
|
combine = test_combinations.combine
|
||||||
|
times = test_combinations.times
|
||||||
|
NamedObject = test_combinations.NamedObject
|
@ -12,7 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""Tests for some testing utils from strategy_test_lib."""
|
"""Tests generating test combinations."""
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
@ -22,7 +22,7 @@ from collections import OrderedDict
|
|||||||
|
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
|
|
||||||
from tensorflow.python.distribute import test_combinations as combinations
|
from tensorflow.python.framework import test_combinations as combinations
|
||||||
from tensorflow.python.eager import test
|
from tensorflow.python.eager import test
|
||||||
|
|
||||||
|
|
@ -23,7 +23,7 @@ import os
|
|||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
|
|
||||||
from tensorflow.python import tf2
|
from tensorflow.python import tf2
|
||||||
from tensorflow.python.distribute import combinations
|
from tensorflow.python.framework import combinations
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user