Update tests under keras.utils to use combinations.

Change all test_util.run_all_in_graph_and_eager_modes to combination.

PiperOrigin-RevId: 301393990
Change-Id: I7084404a9a256a11804bd474d1383f9c36de7305
This commit is contained in:
Scott Zhu 2020-03-17 09:53:03 -07:00 committed by TensorFlower Gardener
parent c3c2104933
commit 9f0ff44f9f
5 changed files with 17 additions and 9 deletions

View File

@ -102,6 +102,7 @@ tf_py_test(
":model_subclassing_test_util", ":model_subclassing_test_util",
"//tensorflow/python:client_testlib", "//tensorflow/python:client_testlib",
"//tensorflow/python/keras", "//tensorflow/python/keras",
"//tensorflow/python/keras:combinations",
"//third_party/py/numpy", "//third_party/py/numpy",
"@absl_py//absl/testing:parameterized", "@absl_py//absl/testing:parameterized",
], ],

View File

@ -21,6 +21,7 @@ from __future__ import print_function
import copy import copy
import os import os
from absl.testing import parameterized
import numpy as np import numpy as np
from tensorflow.python import keras from tensorflow.python import keras
@ -29,6 +30,7 @@ from tensorflow.python.eager import context
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
from tensorflow.python.keras import combinations
from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import testing_utils from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.tests import model_subclassing_test_util as model_util from tensorflow.python.keras.tests import model_subclassing_test_util as model_util
@ -606,8 +608,8 @@ class GraphSpecificModelSubclassingTests(test.TestCase):
_ = model.evaluate([x1, x2], [y1, y2], verbose=0) _ = model.evaluate([x1, x2], [y1, y2], verbose=0)
@test_util.run_all_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
class CustomCallSignatureTests(test.TestCase): class CustomCallSignatureTests(test.TestCase, parameterized.TestCase):
def test_no_inputs_in_signature(self): def test_no_inputs_in_signature(self):
model = model_util.CustomCallModel() model = model_util.CustomCallModel()
@ -669,7 +671,7 @@ class CustomCallSignatureTests(test.TestCase):
arg = array_ops.ones([1]) arg = array_ops.ones([1])
model(arg, a=3) model(arg, a=3)
if not context.executing_eagerly(): if not context.executing_eagerly():
self.assertEqual(len(model.inputs), 1) self.assertLen(model.inputs, 1)
@test_util.assert_no_new_tensors @test_util.assert_no_new_tensors
@test_util.assert_no_garbage_created @test_util.assert_no_garbage_created

View File

@ -244,6 +244,7 @@ tf_py_test(
":tf_utils", ":tf_utils",
"//tensorflow/python:client_testlib", "//tensorflow/python:client_testlib",
"//tensorflow/python/keras", "//tensorflow/python/keras",
"//tensorflow/python/keras:combinations",
], ],
) )
@ -370,6 +371,7 @@ tf_py_test(
"//tensorflow/python:platform_test", "//tensorflow/python:platform_test",
"//tensorflow/python/eager:context", "//tensorflow/python/eager:context",
"//tensorflow/python/keras", "//tensorflow/python/keras",
"//tensorflow/python/keras:combinations",
"//tensorflow/python/ops/ragged:ragged_factory_ops", "//tensorflow/python/ops/ragged:ragged_factory_ops",
"//tensorflow/python/ops/ragged:ragged_tensor", "//tensorflow/python/ops/ragged:ragged_tensor",
"@absl_py//absl/testing:parameterized", "@absl_py//absl/testing:parameterized",

View File

@ -23,6 +23,7 @@ from absl.testing import parameterized
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
from tensorflow.python.keras import combinations
from tensorflow.python.keras.utils import metrics_utils from tensorflow.python.keras.utils import metrics_utils
from tensorflow.python.ops import script_ops from tensorflow.python.ops import script_ops
from tensorflow.python.ops.ragged import ragged_factory_ops from tensorflow.python.ops.ragged import ragged_factory_ops
@ -30,7 +31,7 @@ from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import googletest from tensorflow.python.platform import googletest
@test_util.run_all_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
class RaggedSizeOpTest(test_util.TensorFlowTestCase, parameterized.TestCase): class RaggedSizeOpTest(test_util.TensorFlowTestCase, parameterized.TestCase):
@parameterized.parameters([ @parameterized.parameters([
@ -249,8 +250,8 @@ class RaggedSizeOpTest(test_util.TensorFlowTestCase, parameterized.TestCase):
metrics_utils.ragged_assert_compatible_and_get_flat_values([x, y]) metrics_utils.ragged_assert_compatible_and_get_flat_values([x, y])
@test_util.run_all_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
class FilterTopKTest(test_util.TensorFlowTestCase): class FilterTopKTest(test_util.TensorFlowTestCase, parameterized.TestCase):
def test_one_dimensional(self): def test_one_dimensional(self):
x = constant_op.constant([.3, .1, .2, -.5, 42.]) x = constant_op.constant([.3, .1, .2, -.5, 42.])

View File

@ -18,18 +18,20 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python import keras from tensorflow.python import keras
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util from tensorflow.python.keras import combinations
from tensorflow.python.keras.utils import tf_utils from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import variables from tensorflow.python.ops import variables
from tensorflow.python.platform import test from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes @combinations.generate(combinations.combine(mode=['graph', 'eager']))
class TestIsSymbolicTensor(test.TestCase): class TestIsSymbolicTensor(test.TestCase, parameterized.TestCase):
def test_default_behavior(self): def test_default_behavior(self):
if context.executing_eagerly(): if context.executing_eagerly():