Fix resnet50_test.py to be compatible with TF 2.
PiperOrigin-RevId: 298948797 Change-Id: I422c1e3ae2758ee2b4eafcd1db865f77505885ea
This commit is contained in:
parent
763cd9836c
commit
2ee97ae55a
tensorflow/python/eager/benchmarks/resnet50
@ -286,8 +286,8 @@ class ResNet50(tf.keras.Model):
|
||||
if pooling == 'avg':
|
||||
self.global_pooling = functools.partial(
|
||||
tf.reduce_mean,
|
||||
reduction_indices=reduction_indices,
|
||||
keep_dims=False)
|
||||
axis=reduction_indices,
|
||||
keepdims=False)
|
||||
elif pooling == 'max':
|
||||
self.global_pooling = functools.partial(
|
||||
tf.reduce_max, reduction_indices=reduction_indices, keep_dims=False)
|
||||
|
@ -63,10 +63,10 @@ def _events_from_file(filepath):
|
||||
Returns:
|
||||
A list of all tf.compat.v1.Event protos in the event file.
|
||||
"""
|
||||
records = list(tf.python_io.tf_record_iterator(filepath))
|
||||
records = list(tf.compat.v1.python_io.tf_record_iterator(filepath))
|
||||
result = []
|
||||
for r in records:
|
||||
event = tf.Event()
|
||||
event = tf.compat.v1.Event()
|
||||
event.ParseFromString(r)
|
||||
result.append(event)
|
||||
return result
|
||||
@ -193,13 +193,13 @@ class ResNet50Test(tf.test.TestCase):
|
||||
device, data_format = resnet50_test_util.device_and_data_format()
|
||||
model = resnet50.ResNet50(data_format)
|
||||
tf.compat.v2.summary.experimental.set_step(
|
||||
tf.train.get_or_create_global_step())
|
||||
tf.compat.v1.train.get_or_create_global_step())
|
||||
logdir = tempfile.mkdtemp()
|
||||
with tf.compat.v2.summary.create_file_writer(
|
||||
logdir, max_queue=0,
|
||||
name='t0').as_default(), tf.compat.v2.summary.record_if(True):
|
||||
with tf.device(device), context.execution_mode(execution_mode):
|
||||
optimizer = tf.train.GradientDescentOptimizer(0.1)
|
||||
optimizer = tf.compat.v1.train.GradientDescentOptimizer(0.1)
|
||||
images, labels = resnet50_test_util.random_batch(2, data_format)
|
||||
apply_gradients(model, optimizer,
|
||||
compute_gradients(model, images, labels))
|
||||
@ -218,7 +218,7 @@ class ResNet50Test(tf.test.TestCase):
|
||||
def test_no_garbage(self):
|
||||
device, data_format = resnet50_test_util.device_and_data_format()
|
||||
model = resnet50.ResNet50(data_format)
|
||||
optimizer = tf.train.GradientDescentOptimizer(0.1)
|
||||
optimizer = tf.compat.v1.train.GradientDescentOptimizer(0.1)
|
||||
with tf.device(device):
|
||||
images, labels = resnet50_test_util.random_batch(2, data_format)
|
||||
gc.disable()
|
||||
@ -338,7 +338,7 @@ class ResNet50Benchmarks(tf.test.Benchmark):
|
||||
(images, labels) = resnet50_test_util.random_batch(
|
||||
batch_size, data_format)
|
||||
model = resnet50.ResNet50(data_format)
|
||||
optimizer = tf.train.GradientDescentOptimizer(0.1)
|
||||
optimizer = tf.compat.v1.train.GradientDescentOptimizer(0.1)
|
||||
apply_grads = apply_gradients
|
||||
if defun:
|
||||
model.call = tf.function(model.call)
|
||||
@ -409,5 +409,5 @@ class ResNet50Benchmarks(tf.test.Benchmark):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.enable_eager_execution()
|
||||
tf.compat.v1.enable_eager_execution()
|
||||
tf.test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user