Enable resnet50 tfrt target. The actual tests are largely disabled.
PiperOrigin-RevId: 308748769 Change-Id: I65a2795bca7fcec898511a6d53d46acdc0d3e75d
This commit is contained in:
parent
75f58fcdde
commit
4f0dd967b1
@ -46,6 +46,7 @@ cuda_py_test(
|
||||
"oss_serial",
|
||||
"v1only",
|
||||
],
|
||||
tfrt_enabled = True,
|
||||
deps = [
|
||||
":resnet50",
|
||||
":resnet50_test_util",
|
||||
|
@ -31,6 +31,7 @@ from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import tape
|
||||
from tensorflow.python.eager.benchmarks.resnet50 import resnet50
|
||||
from tensorflow.python.eager.benchmarks.resnet50 import resnet50_test_util
|
||||
from tensorflow.python.framework import test_util
|
||||
|
||||
|
||||
def compute_gradients(model, images, labels, num_replicas=1):
|
||||
@ -103,18 +104,23 @@ class ResNet50Test(tf.test.TestCase):
|
||||
context.async_wait()
|
||||
self.assertEqual((2, 1000), output.shape)
|
||||
|
||||
@test_util.disable_tfrt('b/154858769')
|
||||
def test_apply(self):
|
||||
self._apply(defun=False)
|
||||
|
||||
@test_util.disable_tfrt('b/154858769')
|
||||
def test_apply_async(self):
|
||||
self._apply(defun=False, execution_mode=context.ASYNC)
|
||||
|
||||
@test_util.disable_tfrt('Graph is not supported yet.')
|
||||
def test_apply_with_defun(self):
|
||||
self._apply(defun=True)
|
||||
|
||||
@test_util.disable_tfrt('Graph is not supported yet.')
|
||||
def test_apply_with_defun_async(self):
|
||||
self._apply(defun=True, execution_mode=context.ASYNC)
|
||||
|
||||
@test_util.disable_tfrt('b/154858769')
|
||||
def test_apply_no_top(self):
|
||||
device, data_format = resnet50_test_util.device_and_data_format()
|
||||
model = resnet50.ResNet50(data_format, include_top=False)
|
||||
@ -125,6 +131,7 @@ class ResNet50Test(tf.test.TestCase):
|
||||
if data_format == 'channels_first' else (2, 1, 1, 2048))
|
||||
self.assertEqual(output_shape, output.shape)
|
||||
|
||||
@test_util.disable_tfrt('b/154858769')
|
||||
def test_apply_with_pooling(self):
|
||||
device, data_format = resnet50_test_util.device_and_data_format()
|
||||
model = resnet50.ResNet50(data_format, include_top=False, pooling='avg')
|
||||
@ -133,6 +140,7 @@ class ResNet50Test(tf.test.TestCase):
|
||||
output = model(images, training=False)
|
||||
self.assertEqual((2, 2048), output.shape)
|
||||
|
||||
@test_util.disable_tfrt('b/154858769')
|
||||
def test_apply_no_average_pooling(self):
|
||||
device, data_format = resnet50_test_util.device_and_data_format()
|
||||
model = resnet50.ResNet50(
|
||||
@ -144,6 +152,7 @@ class ResNet50Test(tf.test.TestCase):
|
||||
(2, 7, 7, 2048))
|
||||
self.assertEqual(output_shape, output.shape)
|
||||
|
||||
@test_util.disable_tfrt('b/154858769')
|
||||
def test_apply_block3_strides(self):
|
||||
device, data_format = resnet50_test_util.device_and_data_format()
|
||||
model = resnet50.ResNet50(
|
||||
@ -155,6 +164,7 @@ class ResNet50Test(tf.test.TestCase):
|
||||
(2, 1, 1, 2048))
|
||||
self.assertEqual(output_shape, output.shape)
|
||||
|
||||
@test_util.disable_tfrt('b/154858769')
|
||||
def test_apply_retrieve_intermediates(self):
|
||||
device, data_format = resnet50_test_util.device_and_data_format()
|
||||
model = resnet50.ResNet50(
|
||||
@ -209,12 +219,15 @@ class ResNet50Test(tf.test.TestCase):
|
||||
self.assertEqual(len(events), 2)
|
||||
self.assertEqual(events[1].summary.value[0].tag, 'loss')
|
||||
|
||||
@test_util.disable_tfrt('b/154858769')
|
||||
def test_train(self):
|
||||
self._test_train()
|
||||
|
||||
@test_util.disable_tfrt('b/154858769')
|
||||
def test_train_async(self):
|
||||
self._test_train(execution_mode=context.ASYNC)
|
||||
|
||||
@test_util.disable_tfrt('b/154858769')
|
||||
def test_no_garbage(self):
|
||||
device, data_format = resnet50_test_util.device_and_data_format()
|
||||
model = resnet50.ResNet50(data_format)
|
||||
@ -318,9 +331,12 @@ class ResNet50Benchmarks(tf.test.Benchmark):
|
||||
|
||||
def benchmark_eager_apply_async(self):
|
||||
self._benchmark_eager_apply(
|
||||
'eager_apply_async', resnet50_test_util.device_and_data_format(),
|
||||
defun=False, execution_mode=context.ASYNC)
|
||||
'eager_apply_async',
|
||||
resnet50_test_util.device_and_data_format(),
|
||||
defun=False,
|
||||
execution_mode=context.ASYNC)
|
||||
|
||||
@test_util.disable_tfrt('Graph is not supported yet.')
|
||||
def benchmark_eager_apply_with_defun(self):
|
||||
self._benchmark_eager_apply(
|
||||
'eager_apply_with_defun',
|
||||
@ -380,6 +396,7 @@ class ResNet50Benchmarks(tf.test.Benchmark):
|
||||
defun=False,
|
||||
execution_mode=context.ASYNC)
|
||||
|
||||
@test_util.disable_tfrt('Graph is not supported yet.')
|
||||
def benchmark_eager_train_with_defun(self):
|
||||
self._benchmark_eager_train(
|
||||
'eager_train_with_defun', MockIterator,
|
||||
@ -393,9 +410,12 @@ class ResNet50Benchmarks(tf.test.Benchmark):
|
||||
return iter(ds)
|
||||
|
||||
self._benchmark_eager_train(
|
||||
'eager_train_dataset', make_iterator,
|
||||
resnet50_test_util.device_and_data_format(), defun=False)
|
||||
'eager_train_dataset',
|
||||
make_iterator,
|
||||
resnet50_test_util.device_and_data_format(),
|
||||
defun=False)
|
||||
|
||||
@test_util.disable_tfrt('Graph is not supported yet.')
|
||||
def benchmark_eager_train_datasets_with_defun(self):
|
||||
|
||||
def make_iterator(tensors):
|
||||
|
Loading…
Reference in New Issue
Block a user