Disable flaky TFRT tests.

PiperOrigin-RevId: 312415958
Change-Id: I5cdaab0a4c5c2e9cf09bcee61df3b008a98eac22
This commit is contained in:
Kibeom Kim 2020-05-19 21:38:42 -07:00 committed by TensorFlower Gardener
parent f8a797e13e
commit cd0322fa0e
1 changed files with 8 additions and 0 deletions

View File

@ -104,6 +104,7 @@ class ResNet50Test(tf.test.TestCase):
context.async_wait()
self.assertEqual((2, 1000), output.shape)
@test_util.disable_tfrt('Flaky test. b/157103729')
def test_apply(self):
self._apply(defun=False)
@ -120,6 +121,7 @@ class ResNet50Test(tf.test.TestCase):
def test_apply_with_defun_async(self):
self._apply(defun=True, execution_mode=context.ASYNC)
@test_util.disable_tfrt('Flaky test. b/157103729')
def test_apply_no_top(self):
device, data_format = resnet50_test_util.device_and_data_format()
model = resnet50.ResNet50(data_format, include_top=False)
@ -130,6 +132,7 @@ class ResNet50Test(tf.test.TestCase):
if data_format == 'channels_first' else (2, 1, 1, 2048))
self.assertEqual(output_shape, output.shape)
@test_util.disable_tfrt('Flaky test. b/157103729')
def test_apply_with_pooling(self):
device, data_format = resnet50_test_util.device_and_data_format()
model = resnet50.ResNet50(data_format, include_top=False, pooling='avg')
@ -138,6 +141,7 @@ class ResNet50Test(tf.test.TestCase):
output = model(images, training=False)
self.assertEqual((2, 2048), output.shape)
@test_util.disable_tfrt('Flaky test. b/157103729')
def test_apply_no_average_pooling(self):
device, data_format = resnet50_test_util.device_and_data_format()
model = resnet50.ResNet50(
@ -149,6 +153,7 @@ class ResNet50Test(tf.test.TestCase):
(2, 7, 7, 2048))
self.assertEqual(output_shape, output.shape)
@test_util.disable_tfrt('Flaky test. b/157103729')
def test_apply_block3_strides(self):
device, data_format = resnet50_test_util.device_and_data_format()
model = resnet50.ResNet50(
@ -160,6 +165,7 @@ class ResNet50Test(tf.test.TestCase):
(2, 1, 1, 2048))
self.assertEqual(output_shape, output.shape)
@test_util.disable_tfrt('Flaky test. b/157103729')
def test_apply_retrieve_intermediates(self):
device, data_format = resnet50_test_util.device_and_data_format()
model = resnet50.ResNet50(
@ -214,6 +220,7 @@ class ResNet50Test(tf.test.TestCase):
self.assertEqual(len(events), 2)
self.assertEqual(events[1].summary.value[0].tag, 'loss')
@test_util.disable_tfrt('Flaky test. b/157103729')
def test_train(self):
self._test_train()
@ -221,6 +228,7 @@ class ResNet50Test(tf.test.TestCase):
def test_train_async(self):
self._test_train(execution_mode=context.ASYNC)
@test_util.disable_tfrt('Flaky test. b/157103729')
def test_no_garbage(self):
device, data_format = resnet50_test_util.device_and_data_format()
model = resnet50.ResNet50(data_format)