Enable resnet50 tfrt target. The actual tests are largely disabled.
PiperOrigin-RevId: 308748769 Change-Id: I65a2795bca7fcec898511a6d53d46acdc0d3e75d
This commit is contained in:
parent
75f58fcdde
commit
4f0dd967b1
@ -46,6 +46,7 @@ cuda_py_test(
|
|||||||
"oss_serial",
|
"oss_serial",
|
||||||
"v1only",
|
"v1only",
|
||||||
],
|
],
|
||||||
|
tfrt_enabled = True,
|
||||||
deps = [
|
deps = [
|
||||||
":resnet50",
|
":resnet50",
|
||||||
":resnet50_test_util",
|
":resnet50_test_util",
|
||||||
|
@ -31,6 +31,7 @@ from tensorflow.python.eager import context
|
|||||||
from tensorflow.python.eager import tape
|
from tensorflow.python.eager import tape
|
||||||
from tensorflow.python.eager.benchmarks.resnet50 import resnet50
|
from tensorflow.python.eager.benchmarks.resnet50 import resnet50
|
||||||
from tensorflow.python.eager.benchmarks.resnet50 import resnet50_test_util
|
from tensorflow.python.eager.benchmarks.resnet50 import resnet50_test_util
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
|
|
||||||
|
|
||||||
def compute_gradients(model, images, labels, num_replicas=1):
|
def compute_gradients(model, images, labels, num_replicas=1):
|
||||||
@ -103,18 +104,23 @@ class ResNet50Test(tf.test.TestCase):
|
|||||||
context.async_wait()
|
context.async_wait()
|
||||||
self.assertEqual((2, 1000), output.shape)
|
self.assertEqual((2, 1000), output.shape)
|
||||||
|
|
||||||
|
@test_util.disable_tfrt('b/154858769')
|
||||||
def test_apply(self):
|
def test_apply(self):
|
||||||
self._apply(defun=False)
|
self._apply(defun=False)
|
||||||
|
|
||||||
|
@test_util.disable_tfrt('b/154858769')
|
||||||
def test_apply_async(self):
|
def test_apply_async(self):
|
||||||
self._apply(defun=False, execution_mode=context.ASYNC)
|
self._apply(defun=False, execution_mode=context.ASYNC)
|
||||||
|
|
||||||
|
@test_util.disable_tfrt('Graph is not supported yet.')
|
||||||
def test_apply_with_defun(self):
|
def test_apply_with_defun(self):
|
||||||
self._apply(defun=True)
|
self._apply(defun=True)
|
||||||
|
|
||||||
|
@test_util.disable_tfrt('Graph is not supported yet.')
|
||||||
def test_apply_with_defun_async(self):
|
def test_apply_with_defun_async(self):
|
||||||
self._apply(defun=True, execution_mode=context.ASYNC)
|
self._apply(defun=True, execution_mode=context.ASYNC)
|
||||||
|
|
||||||
|
@test_util.disable_tfrt('b/154858769')
|
||||||
def test_apply_no_top(self):
|
def test_apply_no_top(self):
|
||||||
device, data_format = resnet50_test_util.device_and_data_format()
|
device, data_format = resnet50_test_util.device_and_data_format()
|
||||||
model = resnet50.ResNet50(data_format, include_top=False)
|
model = resnet50.ResNet50(data_format, include_top=False)
|
||||||
@ -125,6 +131,7 @@ class ResNet50Test(tf.test.TestCase):
|
|||||||
if data_format == 'channels_first' else (2, 1, 1, 2048))
|
if data_format == 'channels_first' else (2, 1, 1, 2048))
|
||||||
self.assertEqual(output_shape, output.shape)
|
self.assertEqual(output_shape, output.shape)
|
||||||
|
|
||||||
|
@test_util.disable_tfrt('b/154858769')
|
||||||
def test_apply_with_pooling(self):
|
def test_apply_with_pooling(self):
|
||||||
device, data_format = resnet50_test_util.device_and_data_format()
|
device, data_format = resnet50_test_util.device_and_data_format()
|
||||||
model = resnet50.ResNet50(data_format, include_top=False, pooling='avg')
|
model = resnet50.ResNet50(data_format, include_top=False, pooling='avg')
|
||||||
@ -133,6 +140,7 @@ class ResNet50Test(tf.test.TestCase):
|
|||||||
output = model(images, training=False)
|
output = model(images, training=False)
|
||||||
self.assertEqual((2, 2048), output.shape)
|
self.assertEqual((2, 2048), output.shape)
|
||||||
|
|
||||||
|
@test_util.disable_tfrt('b/154858769')
|
||||||
def test_apply_no_average_pooling(self):
|
def test_apply_no_average_pooling(self):
|
||||||
device, data_format = resnet50_test_util.device_and_data_format()
|
device, data_format = resnet50_test_util.device_and_data_format()
|
||||||
model = resnet50.ResNet50(
|
model = resnet50.ResNet50(
|
||||||
@ -144,6 +152,7 @@ class ResNet50Test(tf.test.TestCase):
|
|||||||
(2, 7, 7, 2048))
|
(2, 7, 7, 2048))
|
||||||
self.assertEqual(output_shape, output.shape)
|
self.assertEqual(output_shape, output.shape)
|
||||||
|
|
||||||
|
@test_util.disable_tfrt('b/154858769')
|
||||||
def test_apply_block3_strides(self):
|
def test_apply_block3_strides(self):
|
||||||
device, data_format = resnet50_test_util.device_and_data_format()
|
device, data_format = resnet50_test_util.device_and_data_format()
|
||||||
model = resnet50.ResNet50(
|
model = resnet50.ResNet50(
|
||||||
@ -155,6 +164,7 @@ class ResNet50Test(tf.test.TestCase):
|
|||||||
(2, 1, 1, 2048))
|
(2, 1, 1, 2048))
|
||||||
self.assertEqual(output_shape, output.shape)
|
self.assertEqual(output_shape, output.shape)
|
||||||
|
|
||||||
|
@test_util.disable_tfrt('b/154858769')
|
||||||
def test_apply_retrieve_intermediates(self):
|
def test_apply_retrieve_intermediates(self):
|
||||||
device, data_format = resnet50_test_util.device_and_data_format()
|
device, data_format = resnet50_test_util.device_and_data_format()
|
||||||
model = resnet50.ResNet50(
|
model = resnet50.ResNet50(
|
||||||
@ -209,12 +219,15 @@ class ResNet50Test(tf.test.TestCase):
|
|||||||
self.assertEqual(len(events), 2)
|
self.assertEqual(len(events), 2)
|
||||||
self.assertEqual(events[1].summary.value[0].tag, 'loss')
|
self.assertEqual(events[1].summary.value[0].tag, 'loss')
|
||||||
|
|
||||||
|
@test_util.disable_tfrt('b/154858769')
|
||||||
def test_train(self):
|
def test_train(self):
|
||||||
self._test_train()
|
self._test_train()
|
||||||
|
|
||||||
|
@test_util.disable_tfrt('b/154858769')
|
||||||
def test_train_async(self):
|
def test_train_async(self):
|
||||||
self._test_train(execution_mode=context.ASYNC)
|
self._test_train(execution_mode=context.ASYNC)
|
||||||
|
|
||||||
|
@test_util.disable_tfrt('b/154858769')
|
||||||
def test_no_garbage(self):
|
def test_no_garbage(self):
|
||||||
device, data_format = resnet50_test_util.device_and_data_format()
|
device, data_format = resnet50_test_util.device_and_data_format()
|
||||||
model = resnet50.ResNet50(data_format)
|
model = resnet50.ResNet50(data_format)
|
||||||
@ -318,9 +331,12 @@ class ResNet50Benchmarks(tf.test.Benchmark):
|
|||||||
|
|
||||||
def benchmark_eager_apply_async(self):
|
def benchmark_eager_apply_async(self):
|
||||||
self._benchmark_eager_apply(
|
self._benchmark_eager_apply(
|
||||||
'eager_apply_async', resnet50_test_util.device_and_data_format(),
|
'eager_apply_async',
|
||||||
defun=False, execution_mode=context.ASYNC)
|
resnet50_test_util.device_and_data_format(),
|
||||||
|
defun=False,
|
||||||
|
execution_mode=context.ASYNC)
|
||||||
|
|
||||||
|
@test_util.disable_tfrt('Graph is not supported yet.')
|
||||||
def benchmark_eager_apply_with_defun(self):
|
def benchmark_eager_apply_with_defun(self):
|
||||||
self._benchmark_eager_apply(
|
self._benchmark_eager_apply(
|
||||||
'eager_apply_with_defun',
|
'eager_apply_with_defun',
|
||||||
@ -380,6 +396,7 @@ class ResNet50Benchmarks(tf.test.Benchmark):
|
|||||||
defun=False,
|
defun=False,
|
||||||
execution_mode=context.ASYNC)
|
execution_mode=context.ASYNC)
|
||||||
|
|
||||||
|
@test_util.disable_tfrt('Graph is not supported yet.')
|
||||||
def benchmark_eager_train_with_defun(self):
|
def benchmark_eager_train_with_defun(self):
|
||||||
self._benchmark_eager_train(
|
self._benchmark_eager_train(
|
||||||
'eager_train_with_defun', MockIterator,
|
'eager_train_with_defun', MockIterator,
|
||||||
@ -393,9 +410,12 @@ class ResNet50Benchmarks(tf.test.Benchmark):
|
|||||||
return iter(ds)
|
return iter(ds)
|
||||||
|
|
||||||
self._benchmark_eager_train(
|
self._benchmark_eager_train(
|
||||||
'eager_train_dataset', make_iterator,
|
'eager_train_dataset',
|
||||||
resnet50_test_util.device_and_data_format(), defun=False)
|
make_iterator,
|
||||||
|
resnet50_test_util.device_and_data_format(),
|
||||||
|
defun=False)
|
||||||
|
|
||||||
|
@test_util.disable_tfrt('Graph is not supported yet.')
|
||||||
def benchmark_eager_train_datasets_with_defun(self):
|
def benchmark_eager_train_datasets_with_defun(self):
|
||||||
|
|
||||||
def make_iterator(tensors):
|
def make_iterator(tensors):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user