In TF-TFRT integration, C API will get dtype from underlying fallback tensor directly if the tfrt dtype is Unsupported. This is used to support dtypes that are not natively implemented in TFRT (e.g. DT_RESOURCE).
Enable a few resnet50 tests. PiperOrigin-RevId: 312162457 Change-Id: Iece6d621120e8b20d0a0fe7b271a76dc29caa924
This commit is contained in:
parent
756e66db61
commit
1a07ecf852
@ -104,7 +104,6 @@ class ResNet50Test(tf.test.TestCase):
|
||||
context.async_wait()
|
||||
self.assertEqual((2, 1000), output.shape)
|
||||
|
||||
@test_util.disable_tfrt('b/155260334')
|
||||
def test_apply(self):
|
||||
self._apply(defun=False)
|
||||
|
||||
@ -121,7 +120,6 @@ class ResNet50Test(tf.test.TestCase):
|
||||
def test_apply_with_defun_async(self):
|
||||
self._apply(defun=True, execution_mode=context.ASYNC)
|
||||
|
||||
@test_util.disable_tfrt('b/155260334')
|
||||
def test_apply_no_top(self):
|
||||
device, data_format = resnet50_test_util.device_and_data_format()
|
||||
model = resnet50.ResNet50(data_format, include_top=False)
|
||||
@ -132,7 +130,6 @@ class ResNet50Test(tf.test.TestCase):
|
||||
if data_format == 'channels_first' else (2, 1, 1, 2048))
|
||||
self.assertEqual(output_shape, output.shape)
|
||||
|
||||
@test_util.disable_tfrt('b/155260334')
|
||||
def test_apply_with_pooling(self):
|
||||
device, data_format = resnet50_test_util.device_and_data_format()
|
||||
model = resnet50.ResNet50(data_format, include_top=False, pooling='avg')
|
||||
@ -141,7 +138,6 @@ class ResNet50Test(tf.test.TestCase):
|
||||
output = model(images, training=False)
|
||||
self.assertEqual((2, 2048), output.shape)
|
||||
|
||||
@test_util.disable_tfrt('b/155260334')
|
||||
def test_apply_no_average_pooling(self):
|
||||
device, data_format = resnet50_test_util.device_and_data_format()
|
||||
model = resnet50.ResNet50(
|
||||
@ -153,7 +149,6 @@ class ResNet50Test(tf.test.TestCase):
|
||||
(2, 7, 7, 2048))
|
||||
self.assertEqual(output_shape, output.shape)
|
||||
|
||||
@test_util.disable_tfrt('b/155260334')
|
||||
def test_apply_block3_strides(self):
|
||||
device, data_format = resnet50_test_util.device_and_data_format()
|
||||
model = resnet50.ResNet50(
|
||||
@ -165,7 +160,6 @@ class ResNet50Test(tf.test.TestCase):
|
||||
(2, 1, 1, 2048))
|
||||
self.assertEqual(output_shape, output.shape)
|
||||
|
||||
@test_util.disable_tfrt('b/155260334')
|
||||
def test_apply_retrieve_intermediates(self):
|
||||
device, data_format = resnet50_test_util.device_and_data_format()
|
||||
model = resnet50.ResNet50(
|
||||
@ -220,7 +214,6 @@ class ResNet50Test(tf.test.TestCase):
|
||||
self.assertEqual(len(events), 2)
|
||||
self.assertEqual(events[1].summary.value[0].tag, 'loss')
|
||||
|
||||
@test_util.disable_tfrt('b/155260334')
|
||||
def test_train(self):
|
||||
self._test_train()
|
||||
|
||||
@ -228,7 +221,6 @@ class ResNet50Test(tf.test.TestCase):
|
||||
def test_train_async(self):
|
||||
self._test_train(execution_mode=context.ASYNC)
|
||||
|
||||
@test_util.disable_tfrt('b/155260334')
|
||||
def test_no_garbage(self):
|
||||
device, data_format = resnet50_test_util.device_and_data_format()
|
||||
model = resnet50.ResNet50(data_format)
|
||||
|
Loading…
Reference in New Issue
Block a user