Fix resnet50_test.py to be compatible with TF 2.

PiperOrigin-RevId: 298948797
Change-Id: I422c1e3ae2758ee2b4eafcd1db865f77505885ea
This commit is contained in:
A. Unique TensorFlower 2020-03-04 14:58:41 -08:00 committed by TensorFlower Gardener
parent 763cd9836c
commit 2ee97ae55a
2 changed files with 9 additions and 9 deletions

View File

@ -286,8 +286,8 @@ class ResNet50(tf.keras.Model):
if pooling == 'avg': if pooling == 'avg':
self.global_pooling = functools.partial( self.global_pooling = functools.partial(
tf.reduce_mean, tf.reduce_mean,
reduction_indices=reduction_indices, axis=reduction_indices,
keep_dims=False) keepdims=False)
elif pooling == 'max': elif pooling == 'max':
self.global_pooling = functools.partial( self.global_pooling = functools.partial(
tf.reduce_max, reduction_indices=reduction_indices, keep_dims=False) tf.reduce_max, reduction_indices=reduction_indices, keep_dims=False)

View File

@ -63,10 +63,10 @@ def _events_from_file(filepath):
Returns: Returns:
A list of all tf.compat.v1.Event protos in the event file. A list of all tf.compat.v1.Event protos in the event file.
""" """
records = list(tf.python_io.tf_record_iterator(filepath)) records = list(tf.compat.v1.python_io.tf_record_iterator(filepath))
result = [] result = []
for r in records: for r in records:
event = tf.Event() event = tf.compat.v1.Event()
event.ParseFromString(r) event.ParseFromString(r)
result.append(event) result.append(event)
return result return result
@ -193,13 +193,13 @@ class ResNet50Test(tf.test.TestCase):
device, data_format = resnet50_test_util.device_and_data_format() device, data_format = resnet50_test_util.device_and_data_format()
model = resnet50.ResNet50(data_format) model = resnet50.ResNet50(data_format)
tf.compat.v2.summary.experimental.set_step( tf.compat.v2.summary.experimental.set_step(
tf.train.get_or_create_global_step()) tf.compat.v1.train.get_or_create_global_step())
logdir = tempfile.mkdtemp() logdir = tempfile.mkdtemp()
with tf.compat.v2.summary.create_file_writer( with tf.compat.v2.summary.create_file_writer(
logdir, max_queue=0, logdir, max_queue=0,
name='t0').as_default(), tf.compat.v2.summary.record_if(True): name='t0').as_default(), tf.compat.v2.summary.record_if(True):
with tf.device(device), context.execution_mode(execution_mode): with tf.device(device), context.execution_mode(execution_mode):
optimizer = tf.train.GradientDescentOptimizer(0.1) optimizer = tf.compat.v1.train.GradientDescentOptimizer(0.1)
images, labels = resnet50_test_util.random_batch(2, data_format) images, labels = resnet50_test_util.random_batch(2, data_format)
apply_gradients(model, optimizer, apply_gradients(model, optimizer,
compute_gradients(model, images, labels)) compute_gradients(model, images, labels))
@ -218,7 +218,7 @@ class ResNet50Test(tf.test.TestCase):
def test_no_garbage(self): def test_no_garbage(self):
device, data_format = resnet50_test_util.device_and_data_format() device, data_format = resnet50_test_util.device_and_data_format()
model = resnet50.ResNet50(data_format) model = resnet50.ResNet50(data_format)
optimizer = tf.train.GradientDescentOptimizer(0.1) optimizer = tf.compat.v1.train.GradientDescentOptimizer(0.1)
with tf.device(device): with tf.device(device):
images, labels = resnet50_test_util.random_batch(2, data_format) images, labels = resnet50_test_util.random_batch(2, data_format)
gc.disable() gc.disable()
@ -338,7 +338,7 @@ class ResNet50Benchmarks(tf.test.Benchmark):
(images, labels) = resnet50_test_util.random_batch( (images, labels) = resnet50_test_util.random_batch(
batch_size, data_format) batch_size, data_format)
model = resnet50.ResNet50(data_format) model = resnet50.ResNet50(data_format)
optimizer = tf.train.GradientDescentOptimizer(0.1) optimizer = tf.compat.v1.train.GradientDescentOptimizer(0.1)
apply_grads = apply_gradients apply_grads = apply_gradients
if defun: if defun:
model.call = tf.function(model.call) model.call = tf.function(model.call)
@ -409,5 +409,5 @@ class ResNet50Benchmarks(tf.test.Benchmark):
if __name__ == '__main__': if __name__ == '__main__':
tf.enable_eager_execution() tf.compat.v1.enable_eager_execution()
tf.test.main() tf.test.main()