diff --git a/tensorflow/python/eager/benchmarks/resnet50/resnet50_test.py b/tensorflow/python/eager/benchmarks/resnet50/resnet50_test.py index 362fad1388c..30e2585e842 100644 --- a/tensorflow/python/eager/benchmarks/resnet50/resnet50_test.py +++ b/tensorflow/python/eager/benchmarks/resnet50/resnet50_test.py @@ -104,6 +104,7 @@ class ResNet50Test(tf.test.TestCase): context.async_wait() self.assertEqual((2, 1000), output.shape) + @test_util.disable_tfrt('Flaky test. b/157103729') def test_apply(self): self._apply(defun=False) @@ -120,6 +121,7 @@ class ResNet50Test(tf.test.TestCase): def test_apply_with_defun_async(self): self._apply(defun=True, execution_mode=context.ASYNC) + @test_util.disable_tfrt('Flaky test. b/157103729') def test_apply_no_top(self): device, data_format = resnet50_test_util.device_and_data_format() model = resnet50.ResNet50(data_format, include_top=False) @@ -130,6 +132,7 @@ class ResNet50Test(tf.test.TestCase): if data_format == 'channels_first' else (2, 1, 1, 2048)) self.assertEqual(output_shape, output.shape) + @test_util.disable_tfrt('Flaky test. b/157103729') def test_apply_with_pooling(self): device, data_format = resnet50_test_util.device_and_data_format() model = resnet50.ResNet50(data_format, include_top=False, pooling='avg') @@ -138,6 +141,7 @@ class ResNet50Test(tf.test.TestCase): output = model(images, training=False) self.assertEqual((2, 2048), output.shape) + @test_util.disable_tfrt('Flaky test. b/157103729') def test_apply_no_average_pooling(self): device, data_format = resnet50_test_util.device_and_data_format() model = resnet50.ResNet50( @@ -149,6 +153,7 @@ class ResNet50Test(tf.test.TestCase): (2, 7, 7, 2048)) self.assertEqual(output_shape, output.shape) + @test_util.disable_tfrt('Flaky test. b/157103729') def test_apply_block3_strides(self): device, data_format = resnet50_test_util.device_and_data_format() model = resnet50.ResNet50( @@ -160,6 +165,7 @@ class ResNet50Test(tf.test.TestCase): (2, 1, 1, 2048)) self.assertEqual(output_shape, output.shape) + @test_util.disable_tfrt('Flaky test. b/157103729') def test_apply_retrieve_intermediates(self): device, data_format = resnet50_test_util.device_and_data_format() model = resnet50.ResNet50( @@ -214,6 +220,7 @@ class ResNet50Test(tf.test.TestCase): self.assertEqual(len(events), 2) self.assertEqual(events[1].summary.value[0].tag, 'loss') + @test_util.disable_tfrt('Flaky test. b/157103729') def test_train(self): self._test_train() @@ -221,6 +228,7 @@ class ResNet50Test(tf.test.TestCase): def test_train_async(self): self._test_train(execution_mode=context.ASYNC) + @test_util.disable_tfrt('Flaky test. b/157103729') def test_no_garbage(self): device, data_format = resnet50_test_util.device_and_data_format() model = resnet50.ResNet50(data_format)