Use combinations to test eager/graph mode and TF v1/v2 for concatenate_test.py
PiperOrigin-RevId: 261198136
This commit is contained in:
parent
a7448908b2
commit
76680cac29
@ -17,20 +17,21 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.util import nest
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class ConcatenateTest(test_base.DatasetTestBase):
|
||||
class ConcatenateTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testConcatenateDataset(self):
|
||||
input_components = (
|
||||
np.tile(np.array([[1], [2], [3], [4]]), 20),
|
||||
@ -64,6 +65,7 @@ class ConcatenateTest(test_base.DatasetTestBase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(get_next())
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testConcatenateDatasetDifferentShape(self):
|
||||
input_components = (
|
||||
np.tile(np.array([[1], [2], [3], [4]]), 20),
|
||||
@ -94,6 +96,7 @@ class ConcatenateTest(test_base.DatasetTestBase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(get_next())
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testConcatenateDatasetDifferentStructure(self):
|
||||
input_components = (
|
||||
np.tile(np.array([[1], [2], [3], [4]]), 5),
|
||||
@ -110,6 +113,7 @@ class ConcatenateTest(test_base.DatasetTestBase):
|
||||
with self.assertRaisesRegexp(TypeError, "have different types"):
|
||||
input_dataset.concatenate(dataset_to_concatenate)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testConcatenateDatasetDifferentKeys(self):
|
||||
input_components = {
|
||||
"foo": np.array([[1], [2], [3], [4]]),
|
||||
@ -127,6 +131,7 @@ class ConcatenateTest(test_base.DatasetTestBase):
|
||||
with self.assertRaisesRegexp(TypeError, "have different types"):
|
||||
input_dataset.concatenate(dataset_to_concatenate)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testConcatenateDatasetDifferentType(self):
|
||||
input_components = (
|
||||
np.tile(np.array([[1], [2], [3], [4]]), 5),
|
||||
|
Loading…
Reference in New Issue
Block a user