Enable resnet50 tfrt target. The actual tests are largely disabled.

PiperOrigin-RevId: 308748769
Change-Id: I65a2795bca7fcec898511a6d53d46acdc0d3e75d
This commit is contained in:
Kibeom Kim 2020-04-27 19:58:19 -07:00 committed by TensorFlower Gardener
parent 75f58fcdde
commit 4f0dd967b1
2 changed files with 25 additions and 4 deletions

View File

@ -46,6 +46,7 @@ cuda_py_test(
"oss_serial", "oss_serial",
"v1only", "v1only",
], ],
tfrt_enabled = True,
deps = [ deps = [
":resnet50", ":resnet50",
":resnet50_test_util", ":resnet50_test_util",

View File

@ -31,6 +31,7 @@ from tensorflow.python.eager import context
from tensorflow.python.eager import tape from tensorflow.python.eager import tape
from tensorflow.python.eager.benchmarks.resnet50 import resnet50 from tensorflow.python.eager.benchmarks.resnet50 import resnet50
from tensorflow.python.eager.benchmarks.resnet50 import resnet50_test_util from tensorflow.python.eager.benchmarks.resnet50 import resnet50_test_util
from tensorflow.python.framework import test_util
def compute_gradients(model, images, labels, num_replicas=1): def compute_gradients(model, images, labels, num_replicas=1):
@ -103,18 +104,23 @@ class ResNet50Test(tf.test.TestCase):
context.async_wait() context.async_wait()
self.assertEqual((2, 1000), output.shape) self.assertEqual((2, 1000), output.shape)
@test_util.disable_tfrt('b/154858769')
def test_apply(self): def test_apply(self):
self._apply(defun=False) self._apply(defun=False)
@test_util.disable_tfrt('b/154858769')
def test_apply_async(self): def test_apply_async(self):
self._apply(defun=False, execution_mode=context.ASYNC) self._apply(defun=False, execution_mode=context.ASYNC)
@test_util.disable_tfrt('Graph is not supported yet.')
def test_apply_with_defun(self): def test_apply_with_defun(self):
self._apply(defun=True) self._apply(defun=True)
@test_util.disable_tfrt('Graph is not supported yet.')
def test_apply_with_defun_async(self): def test_apply_with_defun_async(self):
self._apply(defun=True, execution_mode=context.ASYNC) self._apply(defun=True, execution_mode=context.ASYNC)
@test_util.disable_tfrt('b/154858769')
def test_apply_no_top(self): def test_apply_no_top(self):
device, data_format = resnet50_test_util.device_and_data_format() device, data_format = resnet50_test_util.device_and_data_format()
model = resnet50.ResNet50(data_format, include_top=False) model = resnet50.ResNet50(data_format, include_top=False)
@ -125,6 +131,7 @@ class ResNet50Test(tf.test.TestCase):
if data_format == 'channels_first' else (2, 1, 1, 2048)) if data_format == 'channels_first' else (2, 1, 1, 2048))
self.assertEqual(output_shape, output.shape) self.assertEqual(output_shape, output.shape)
@test_util.disable_tfrt('b/154858769')
def test_apply_with_pooling(self): def test_apply_with_pooling(self):
device, data_format = resnet50_test_util.device_and_data_format() device, data_format = resnet50_test_util.device_and_data_format()
model = resnet50.ResNet50(data_format, include_top=False, pooling='avg') model = resnet50.ResNet50(data_format, include_top=False, pooling='avg')
@ -133,6 +140,7 @@ class ResNet50Test(tf.test.TestCase):
output = model(images, training=False) output = model(images, training=False)
self.assertEqual((2, 2048), output.shape) self.assertEqual((2, 2048), output.shape)
@test_util.disable_tfrt('b/154858769')
def test_apply_no_average_pooling(self): def test_apply_no_average_pooling(self):
device, data_format = resnet50_test_util.device_and_data_format() device, data_format = resnet50_test_util.device_and_data_format()
model = resnet50.ResNet50( model = resnet50.ResNet50(
@ -144,6 +152,7 @@ class ResNet50Test(tf.test.TestCase):
(2, 7, 7, 2048)) (2, 7, 7, 2048))
self.assertEqual(output_shape, output.shape) self.assertEqual(output_shape, output.shape)
@test_util.disable_tfrt('b/154858769')
def test_apply_block3_strides(self): def test_apply_block3_strides(self):
device, data_format = resnet50_test_util.device_and_data_format() device, data_format = resnet50_test_util.device_and_data_format()
model = resnet50.ResNet50( model = resnet50.ResNet50(
@ -155,6 +164,7 @@ class ResNet50Test(tf.test.TestCase):
(2, 1, 1, 2048)) (2, 1, 1, 2048))
self.assertEqual(output_shape, output.shape) self.assertEqual(output_shape, output.shape)
@test_util.disable_tfrt('b/154858769')
def test_apply_retrieve_intermediates(self): def test_apply_retrieve_intermediates(self):
device, data_format = resnet50_test_util.device_and_data_format() device, data_format = resnet50_test_util.device_and_data_format()
model = resnet50.ResNet50( model = resnet50.ResNet50(
@ -209,12 +219,15 @@ class ResNet50Test(tf.test.TestCase):
self.assertEqual(len(events), 2) self.assertEqual(len(events), 2)
self.assertEqual(events[1].summary.value[0].tag, 'loss') self.assertEqual(events[1].summary.value[0].tag, 'loss')
@test_util.disable_tfrt('b/154858769')
def test_train(self): def test_train(self):
self._test_train() self._test_train()
@test_util.disable_tfrt('b/154858769')
def test_train_async(self): def test_train_async(self):
self._test_train(execution_mode=context.ASYNC) self._test_train(execution_mode=context.ASYNC)
@test_util.disable_tfrt('b/154858769')
def test_no_garbage(self): def test_no_garbage(self):
device, data_format = resnet50_test_util.device_and_data_format() device, data_format = resnet50_test_util.device_and_data_format()
model = resnet50.ResNet50(data_format) model = resnet50.ResNet50(data_format)
@ -318,9 +331,12 @@ class ResNet50Benchmarks(tf.test.Benchmark):
def benchmark_eager_apply_async(self): def benchmark_eager_apply_async(self):
self._benchmark_eager_apply( self._benchmark_eager_apply(
'eager_apply_async', resnet50_test_util.device_and_data_format(), 'eager_apply_async',
defun=False, execution_mode=context.ASYNC) resnet50_test_util.device_and_data_format(),
defun=False,
execution_mode=context.ASYNC)
@test_util.disable_tfrt('Graph is not supported yet.')
def benchmark_eager_apply_with_defun(self): def benchmark_eager_apply_with_defun(self):
self._benchmark_eager_apply( self._benchmark_eager_apply(
'eager_apply_with_defun', 'eager_apply_with_defun',
@ -380,6 +396,7 @@ class ResNet50Benchmarks(tf.test.Benchmark):
defun=False, defun=False,
execution_mode=context.ASYNC) execution_mode=context.ASYNC)
@test_util.disable_tfrt('Graph is not supported yet.')
def benchmark_eager_train_with_defun(self): def benchmark_eager_train_with_defun(self):
self._benchmark_eager_train( self._benchmark_eager_train(
'eager_train_with_defun', MockIterator, 'eager_train_with_defun', MockIterator,
@ -393,9 +410,12 @@ class ResNet50Benchmarks(tf.test.Benchmark):
return iter(ds) return iter(ds)
self._benchmark_eager_train( self._benchmark_eager_train(
'eager_train_dataset', make_iterator, 'eager_train_dataset',
resnet50_test_util.device_and_data_format(), defun=False) make_iterator,
resnet50_test_util.device_and_data_format(),
defun=False)
@test_util.disable_tfrt('Graph is not supported yet.')
def benchmark_eager_train_datasets_with_defun(self): def benchmark_eager_train_datasets_with_defun(self):
def make_iterator(tensors): def make_iterator(tensors):