diff --git a/tensorflow/python/data/kernel_tests/concatenate_test.py b/tensorflow/python/data/kernel_tests/concatenate_test.py index 384fd289f16..bf726607681 100644 --- a/tensorflow/python/data/kernel_tests/concatenate_test.py +++ b/tensorflow/python/data/kernel_tests/concatenate_test.py @@ -17,20 +17,21 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized import numpy as np from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest +from tensorflow.python.framework import combinations from tensorflow.python.framework import errors from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import test_util from tensorflow.python.platform import test -@test_util.run_all_in_graph_and_eager_modes -class ConcatenateTest(test_base.DatasetTestBase): +class ConcatenateTest(test_base.DatasetTestBase, parameterized.TestCase): + @combinations.generate(test_base.default_test_combinations()) def testConcatenateDataset(self): input_components = ( np.tile(np.array([[1], [2], [3], [4]]), 20), @@ -64,6 +65,7 @@ class ConcatenateTest(test_base.DatasetTestBase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) + @combinations.generate(test_base.default_test_combinations()) def testConcatenateDatasetDifferentShape(self): input_components = ( np.tile(np.array([[1], [2], [3], [4]]), 20), @@ -94,6 +96,7 @@ class ConcatenateTest(test_base.DatasetTestBase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(get_next()) + @combinations.generate(test_base.default_test_combinations()) def testConcatenateDatasetDifferentStructure(self): input_components = ( np.tile(np.array([[1], [2], [3], [4]]), 5), @@ -110,6 +113,7 @@ class ConcatenateTest(test_base.DatasetTestBase): with self.assertRaisesRegexp(TypeError, "have different types"): input_dataset.concatenate(dataset_to_concatenate) + @combinations.generate(test_base.default_test_combinations()) def testConcatenateDatasetDifferentKeys(self): input_components = { "foo": np.array([[1], [2], [3], [4]]), @@ -127,6 +131,7 @@ class ConcatenateTest(test_base.DatasetTestBase): with self.assertRaisesRegexp(TypeError, "have different types"): input_dataset.concatenate(dataset_to_concatenate) + @combinations.generate(test_base.default_test_combinations()) def testConcatenateDatasetDifferentType(self): input_components = ( np.tile(np.array([[1], [2], [3], [4]]), 5),