Enable resnet50 tfrt target. The actual tests are largely disabled.

PiperOrigin-RevId: 308748769
Change-Id: I65a2795bca7fcec898511a6d53d46acdc0d3e75d
This commit is contained in:
Kibeom Kim 2020-04-27 19:58:19 -07:00 committed by TensorFlower Gardener
parent 75f58fcdde
commit 4f0dd967b1
2 changed files with 25 additions and 4 deletions

View File

@ -46,6 +46,7 @@ cuda_py_test(
"oss_serial",
"v1only",
],
tfrt_enabled = True,
deps = [
":resnet50",
":resnet50_test_util",

View File

@ -31,6 +31,7 @@ from tensorflow.python.eager import context
from tensorflow.python.eager import tape
from tensorflow.python.eager.benchmarks.resnet50 import resnet50
from tensorflow.python.eager.benchmarks.resnet50 import resnet50_test_util
from tensorflow.python.framework import test_util
def compute_gradients(model, images, labels, num_replicas=1):
@ -103,18 +104,23 @@ class ResNet50Test(tf.test.TestCase):
context.async_wait()
self.assertEqual((2, 1000), output.shape)
@test_util.disable_tfrt('b/154858769')
def test_apply(self):
self._apply(defun=False)
@test_util.disable_tfrt('b/154858769')
def test_apply_async(self):
self._apply(defun=False, execution_mode=context.ASYNC)
@test_util.disable_tfrt('Graph is not supported yet.')
def test_apply_with_defun(self):
self._apply(defun=True)
@test_util.disable_tfrt('Graph is not supported yet.')
def test_apply_with_defun_async(self):
self._apply(defun=True, execution_mode=context.ASYNC)
@test_util.disable_tfrt('b/154858769')
def test_apply_no_top(self):
device, data_format = resnet50_test_util.device_and_data_format()
model = resnet50.ResNet50(data_format, include_top=False)
@ -125,6 +131,7 @@ class ResNet50Test(tf.test.TestCase):
if data_format == 'channels_first' else (2, 1, 1, 2048))
self.assertEqual(output_shape, output.shape)
@test_util.disable_tfrt('b/154858769')
def test_apply_with_pooling(self):
device, data_format = resnet50_test_util.device_and_data_format()
model = resnet50.ResNet50(data_format, include_top=False, pooling='avg')
@ -133,6 +140,7 @@ class ResNet50Test(tf.test.TestCase):
output = model(images, training=False)
self.assertEqual((2, 2048), output.shape)
@test_util.disable_tfrt('b/154858769')
def test_apply_no_average_pooling(self):
device, data_format = resnet50_test_util.device_and_data_format()
model = resnet50.ResNet50(
@ -144,6 +152,7 @@ class ResNet50Test(tf.test.TestCase):
(2, 7, 7, 2048))
self.assertEqual(output_shape, output.shape)
@test_util.disable_tfrt('b/154858769')
def test_apply_block3_strides(self):
device, data_format = resnet50_test_util.device_and_data_format()
model = resnet50.ResNet50(
@ -155,6 +164,7 @@ class ResNet50Test(tf.test.TestCase):
(2, 1, 1, 2048))
self.assertEqual(output_shape, output.shape)
@test_util.disable_tfrt('b/154858769')
def test_apply_retrieve_intermediates(self):
device, data_format = resnet50_test_util.device_and_data_format()
model = resnet50.ResNet50(
@ -209,12 +219,15 @@ class ResNet50Test(tf.test.TestCase):
self.assertEqual(len(events), 2)
self.assertEqual(events[1].summary.value[0].tag, 'loss')
@test_util.disable_tfrt('b/154858769')
def test_train(self):
self._test_train()
@test_util.disable_tfrt('b/154858769')
def test_train_async(self):
self._test_train(execution_mode=context.ASYNC)
@test_util.disable_tfrt('b/154858769')
def test_no_garbage(self):
device, data_format = resnet50_test_util.device_and_data_format()
model = resnet50.ResNet50(data_format)
@ -318,9 +331,12 @@ class ResNet50Benchmarks(tf.test.Benchmark):
def benchmark_eager_apply_async(self):
self._benchmark_eager_apply(
'eager_apply_async', resnet50_test_util.device_and_data_format(),
defun=False, execution_mode=context.ASYNC)
'eager_apply_async',
resnet50_test_util.device_and_data_format(),
defun=False,
execution_mode=context.ASYNC)
@test_util.disable_tfrt('Graph is not supported yet.')
def benchmark_eager_apply_with_defun(self):
self._benchmark_eager_apply(
'eager_apply_with_defun',
@ -380,6 +396,7 @@ class ResNet50Benchmarks(tf.test.Benchmark):
defun=False,
execution_mode=context.ASYNC)
@test_util.disable_tfrt('Graph is not supported yet.')
def benchmark_eager_train_with_defun(self):
self._benchmark_eager_train(
'eager_train_with_defun', MockIterator,
@ -393,9 +410,12 @@ class ResNet50Benchmarks(tf.test.Benchmark):
return iter(ds)
self._benchmark_eager_train(
'eager_train_dataset', make_iterator,
resnet50_test_util.device_and_data_format(), defun=False)
'eager_train_dataset',
make_iterator,
resnet50_test_util.device_and_data_format(),
defun=False)
@test_util.disable_tfrt('Graph is not supported yet.')
def benchmark_eager_train_datasets_with_defun(self):
def make_iterator(tensors):