Disable flaky TFRT tests.
PiperOrigin-RevId: 312415958 Change-Id: I5cdaab0a4c5c2e9cf09bcee61df3b008a98eac22
This commit is contained in:
parent
f8a797e13e
commit
cd0322fa0e
|
@ -104,6 +104,7 @@ class ResNet50Test(tf.test.TestCase):
|
|||
context.async_wait()
|
||||
self.assertEqual((2, 1000), output.shape)
|
||||
|
||||
@test_util.disable_tfrt('Flaky test. b/157103729')
|
||||
def test_apply(self):
|
||||
self._apply(defun=False)
|
||||
|
||||
|
@ -120,6 +121,7 @@ class ResNet50Test(tf.test.TestCase):
|
|||
def test_apply_with_defun_async(self):
|
||||
self._apply(defun=True, execution_mode=context.ASYNC)
|
||||
|
||||
@test_util.disable_tfrt('Flaky test. b/157103729')
|
||||
def test_apply_no_top(self):
|
||||
device, data_format = resnet50_test_util.device_and_data_format()
|
||||
model = resnet50.ResNet50(data_format, include_top=False)
|
||||
|
@ -130,6 +132,7 @@ class ResNet50Test(tf.test.TestCase):
|
|||
if data_format == 'channels_first' else (2, 1, 1, 2048))
|
||||
self.assertEqual(output_shape, output.shape)
|
||||
|
||||
@test_util.disable_tfrt('Flaky test. b/157103729')
|
||||
def test_apply_with_pooling(self):
|
||||
device, data_format = resnet50_test_util.device_and_data_format()
|
||||
model = resnet50.ResNet50(data_format, include_top=False, pooling='avg')
|
||||
|
@ -138,6 +141,7 @@ class ResNet50Test(tf.test.TestCase):
|
|||
output = model(images, training=False)
|
||||
self.assertEqual((2, 2048), output.shape)
|
||||
|
||||
@test_util.disable_tfrt('Flaky test. b/157103729')
|
||||
def test_apply_no_average_pooling(self):
|
||||
device, data_format = resnet50_test_util.device_and_data_format()
|
||||
model = resnet50.ResNet50(
|
||||
|
@ -149,6 +153,7 @@ class ResNet50Test(tf.test.TestCase):
|
|||
(2, 7, 7, 2048))
|
||||
self.assertEqual(output_shape, output.shape)
|
||||
|
||||
@test_util.disable_tfrt('Flaky test. b/157103729')
|
||||
def test_apply_block3_strides(self):
|
||||
device, data_format = resnet50_test_util.device_and_data_format()
|
||||
model = resnet50.ResNet50(
|
||||
|
@ -160,6 +165,7 @@ class ResNet50Test(tf.test.TestCase):
|
|||
(2, 1, 1, 2048))
|
||||
self.assertEqual(output_shape, output.shape)
|
||||
|
||||
@test_util.disable_tfrt('Flaky test. b/157103729')
|
||||
def test_apply_retrieve_intermediates(self):
|
||||
device, data_format = resnet50_test_util.device_and_data_format()
|
||||
model = resnet50.ResNet50(
|
||||
|
@ -214,6 +220,7 @@ class ResNet50Test(tf.test.TestCase):
|
|||
self.assertEqual(len(events), 2)
|
||||
self.assertEqual(events[1].summary.value[0].tag, 'loss')
|
||||
|
||||
@test_util.disable_tfrt('Flaky test. b/157103729')
|
||||
def test_train(self):
|
||||
self._test_train()
|
||||
|
||||
|
@ -221,6 +228,7 @@ class ResNet50Test(tf.test.TestCase):
|
|||
def test_train_async(self):
|
||||
self._test_train(execution_mode=context.ASYNC)
|
||||
|
||||
@test_util.disable_tfrt('Flaky test. b/157103729')
|
||||
def test_no_garbage(self):
|
||||
device, data_format = resnet50_test_util.device_and_data_format()
|
||||
model = resnet50.ResNet50(data_format)
|
||||
|
|
Loading…
Reference in New Issue