Fix resnet50_test.py to be compatible with TF 2.
PiperOrigin-RevId: 298948797 Change-Id: I422c1e3ae2758ee2b4eafcd1db865f77505885ea
This commit is contained in:
parent
763cd9836c
commit
2ee97ae55a
@ -286,8 +286,8 @@ class ResNet50(tf.keras.Model):
|
|||||||
if pooling == 'avg':
|
if pooling == 'avg':
|
||||||
self.global_pooling = functools.partial(
|
self.global_pooling = functools.partial(
|
||||||
tf.reduce_mean,
|
tf.reduce_mean,
|
||||||
reduction_indices=reduction_indices,
|
axis=reduction_indices,
|
||||||
keep_dims=False)
|
keepdims=False)
|
||||||
elif pooling == 'max':
|
elif pooling == 'max':
|
||||||
self.global_pooling = functools.partial(
|
self.global_pooling = functools.partial(
|
||||||
tf.reduce_max, reduction_indices=reduction_indices, keep_dims=False)
|
tf.reduce_max, reduction_indices=reduction_indices, keep_dims=False)
|
||||||
|
@ -63,10 +63,10 @@ def _events_from_file(filepath):
|
|||||||
Returns:
|
Returns:
|
||||||
A list of all tf.compat.v1.Event protos in the event file.
|
A list of all tf.compat.v1.Event protos in the event file.
|
||||||
"""
|
"""
|
||||||
records = list(tf.python_io.tf_record_iterator(filepath))
|
records = list(tf.compat.v1.python_io.tf_record_iterator(filepath))
|
||||||
result = []
|
result = []
|
||||||
for r in records:
|
for r in records:
|
||||||
event = tf.Event()
|
event = tf.compat.v1.Event()
|
||||||
event.ParseFromString(r)
|
event.ParseFromString(r)
|
||||||
result.append(event)
|
result.append(event)
|
||||||
return result
|
return result
|
||||||
@ -193,13 +193,13 @@ class ResNet50Test(tf.test.TestCase):
|
|||||||
device, data_format = resnet50_test_util.device_and_data_format()
|
device, data_format = resnet50_test_util.device_and_data_format()
|
||||||
model = resnet50.ResNet50(data_format)
|
model = resnet50.ResNet50(data_format)
|
||||||
tf.compat.v2.summary.experimental.set_step(
|
tf.compat.v2.summary.experimental.set_step(
|
||||||
tf.train.get_or_create_global_step())
|
tf.compat.v1.train.get_or_create_global_step())
|
||||||
logdir = tempfile.mkdtemp()
|
logdir = tempfile.mkdtemp()
|
||||||
with tf.compat.v2.summary.create_file_writer(
|
with tf.compat.v2.summary.create_file_writer(
|
||||||
logdir, max_queue=0,
|
logdir, max_queue=0,
|
||||||
name='t0').as_default(), tf.compat.v2.summary.record_if(True):
|
name='t0').as_default(), tf.compat.v2.summary.record_if(True):
|
||||||
with tf.device(device), context.execution_mode(execution_mode):
|
with tf.device(device), context.execution_mode(execution_mode):
|
||||||
optimizer = tf.train.GradientDescentOptimizer(0.1)
|
optimizer = tf.compat.v1.train.GradientDescentOptimizer(0.1)
|
||||||
images, labels = resnet50_test_util.random_batch(2, data_format)
|
images, labels = resnet50_test_util.random_batch(2, data_format)
|
||||||
apply_gradients(model, optimizer,
|
apply_gradients(model, optimizer,
|
||||||
compute_gradients(model, images, labels))
|
compute_gradients(model, images, labels))
|
||||||
@ -218,7 +218,7 @@ class ResNet50Test(tf.test.TestCase):
|
|||||||
def test_no_garbage(self):
|
def test_no_garbage(self):
|
||||||
device, data_format = resnet50_test_util.device_and_data_format()
|
device, data_format = resnet50_test_util.device_and_data_format()
|
||||||
model = resnet50.ResNet50(data_format)
|
model = resnet50.ResNet50(data_format)
|
||||||
optimizer = tf.train.GradientDescentOptimizer(0.1)
|
optimizer = tf.compat.v1.train.GradientDescentOptimizer(0.1)
|
||||||
with tf.device(device):
|
with tf.device(device):
|
||||||
images, labels = resnet50_test_util.random_batch(2, data_format)
|
images, labels = resnet50_test_util.random_batch(2, data_format)
|
||||||
gc.disable()
|
gc.disable()
|
||||||
@ -338,7 +338,7 @@ class ResNet50Benchmarks(tf.test.Benchmark):
|
|||||||
(images, labels) = resnet50_test_util.random_batch(
|
(images, labels) = resnet50_test_util.random_batch(
|
||||||
batch_size, data_format)
|
batch_size, data_format)
|
||||||
model = resnet50.ResNet50(data_format)
|
model = resnet50.ResNet50(data_format)
|
||||||
optimizer = tf.train.GradientDescentOptimizer(0.1)
|
optimizer = tf.compat.v1.train.GradientDescentOptimizer(0.1)
|
||||||
apply_grads = apply_gradients
|
apply_grads = apply_gradients
|
||||||
if defun:
|
if defun:
|
||||||
model.call = tf.function(model.call)
|
model.call = tf.function(model.call)
|
||||||
@ -409,5 +409,5 @@ class ResNet50Benchmarks(tf.test.Benchmark):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
tf.enable_eager_execution()
|
tf.compat.v1.enable_eager_execution()
|
||||||
tf.test.main()
|
tf.test.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user