diff --git a/tensorflow/python/eager/benchmarks/resnet50/resnet50_test.py b/tensorflow/python/eager/benchmarks/resnet50/resnet50_test.py index 9d049a6d59d..34ceb56d129 100644 --- a/tensorflow/python/eager/benchmarks/resnet50/resnet50_test.py +++ b/tensorflow/python/eager/benchmarks/resnet50/resnet50_test.py @@ -104,7 +104,6 @@ class ResNet50Test(tf.test.TestCase): context.async_wait() self.assertEqual((2, 1000), output.shape) - @test_util.disable_tfrt('b/155260334') def test_apply(self): self._apply(defun=False) @@ -121,7 +120,6 @@ class ResNet50Test(tf.test.TestCase): def test_apply_with_defun_async(self): self._apply(defun=True, execution_mode=context.ASYNC) - @test_util.disable_tfrt('b/155260334') def test_apply_no_top(self): device, data_format = resnet50_test_util.device_and_data_format() model = resnet50.ResNet50(data_format, include_top=False) @@ -132,7 +130,6 @@ class ResNet50Test(tf.test.TestCase): if data_format == 'channels_first' else (2, 1, 1, 2048)) self.assertEqual(output_shape, output.shape) - @test_util.disable_tfrt('b/155260334') def test_apply_with_pooling(self): device, data_format = resnet50_test_util.device_and_data_format() model = resnet50.ResNet50(data_format, include_top=False, pooling='avg') @@ -141,7 +138,6 @@ class ResNet50Test(tf.test.TestCase): output = model(images, training=False) self.assertEqual((2, 2048), output.shape) - @test_util.disable_tfrt('b/155260334') def test_apply_no_average_pooling(self): device, data_format = resnet50_test_util.device_and_data_format() model = resnet50.ResNet50( @@ -153,7 +149,6 @@ class ResNet50Test(tf.test.TestCase): (2, 7, 7, 2048)) self.assertEqual(output_shape, output.shape) - @test_util.disable_tfrt('b/155260334') def test_apply_block3_strides(self): device, data_format = resnet50_test_util.device_and_data_format() model = resnet50.ResNet50( @@ -165,7 +160,6 @@ class ResNet50Test(tf.test.TestCase): (2, 1, 1, 2048)) self.assertEqual(output_shape, output.shape) - @test_util.disable_tfrt('b/155260334') def test_apply_retrieve_intermediates(self): device, data_format = resnet50_test_util.device_and_data_format() model = resnet50.ResNet50( @@ -220,7 +214,6 @@ class ResNet50Test(tf.test.TestCase): self.assertEqual(len(events), 2) self.assertEqual(events[1].summary.value[0].tag, 'loss') - @test_util.disable_tfrt('b/155260334') def test_train(self): self._test_train() @@ -228,7 +221,6 @@ class ResNet50Test(tf.test.TestCase): def test_train_async(self): self._test_train(execution_mode=context.ASYNC) - @test_util.disable_tfrt('b/155260334') def test_no_garbage(self): device, data_format = resnet50_test_util.device_and_data_format() model = resnet50.ResNet50(data_format)