Use the new V2 optimizer in resnet50 benchmark.

PiperOrigin-RevId: 313687026
Change-Id: If51837cddc1a7e4d2ef70c064788a8f4a7728a6a
This commit is contained in:
Xiao Yu 2020-05-28 17:00:12 -07:00 committed by TensorFlower Gardener
parent 547daed259
commit 347fe6ece4

View File

@ -355,7 +355,7 @@ class ResNet50Benchmarks(tf.test.Benchmark):
(images, labels) = resnet50_test_util.random_batch(
batch_size, data_format)
model = resnet50.ResNet50(data_format)
optimizer = tf.compat.v1.train.GradientDescentOptimizer(0.1)
optimizer = tf.keras.optimizers.SGD(0.1)
apply_grads = apply_gradients
if defun:
model.call = tf.function(model.call)