Use combinations to test eager/graph mode and TF v1/v2 for concatenate_test.py

PiperOrigin-RevId: 261198136
This commit is contained in:
Andrew Audibert 2019-08-01 14:28:58 -07:00 committed by TensorFlower Gardener
parent a7448908b2
commit 76680cac29

View File

@ -17,20 +17,21 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.framework import combinations
from tensorflow.python.framework import errors
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class ConcatenateTest(test_base.DatasetTestBase):
class ConcatenateTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(test_base.default_test_combinations())
def testConcatenateDataset(self):
input_components = (
np.tile(np.array([[1], [2], [3], [4]]), 20),
@ -64,6 +65,7 @@ class ConcatenateTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next())
@combinations.generate(test_base.default_test_combinations())
def testConcatenateDatasetDifferentShape(self):
input_components = (
np.tile(np.array([[1], [2], [3], [4]]), 20),
@ -94,6 +96,7 @@ class ConcatenateTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next())
@combinations.generate(test_base.default_test_combinations())
def testConcatenateDatasetDifferentStructure(self):
input_components = (
np.tile(np.array([[1], [2], [3], [4]]), 5),
@ -110,6 +113,7 @@ class ConcatenateTest(test_base.DatasetTestBase):
with self.assertRaisesRegexp(TypeError, "have different types"):
input_dataset.concatenate(dataset_to_concatenate)
@combinations.generate(test_base.default_test_combinations())
def testConcatenateDatasetDifferentKeys(self):
input_components = {
"foo": np.array([[1], [2], [3], [4]]),
@ -127,6 +131,7 @@ class ConcatenateTest(test_base.DatasetTestBase):
with self.assertRaisesRegexp(TypeError, "have different types"):
input_dataset.concatenate(dataset_to_concatenate)
@combinations.generate(test_base.default_test_combinations())
def testConcatenateDatasetDifferentType(self):
input_components = (
np.tile(np.array([[1], [2], [3], [4]]), 5),