From 8ed969c39d6179364c79c2564922350a3e8d62b1 Mon Sep 17 00:00:00 2001
From: Scott Zhu
Date: Tue, 24 Mar 2020 12:14:00 -0700
Subject: [PATCH] Remove the optimizer v2 instance check and replace it with a
 v1 object check.

OptimizerV2 instances don't share a base class with the v1 Optimizer.

PiperOrigin-RevId: 302719205
Change-Id: I225c891e59e93d3b5386fa3682fb1bdc97e770de
---
 tensorflow/python/tpu/tpu_optimizer.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/tensorflow/python/tpu/tpu_optimizer.py b/tensorflow/python/tpu/tpu_optimizer.py
index e233bbb007c..c6c710c9ccc 100644
--- a/tensorflow/python/tpu/tpu_optimizer.py
+++ b/tensorflow/python/tpu/tpu_optimizer.py
@@ -21,7 +21,6 @@
 from __future__ import print_function
 
 from tensorflow.python.framework import ops
-from tensorflow.python.keras.optimizer_v2 import optimizer_v2
 from tensorflow.python.ops.losses import losses
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.tpu import tpu_function
@@ -55,10 +54,11 @@ class CrossShardOptimizer(optimizer.Optimizer):
     """
     if reduction not in (losses.Reduction.SUM, losses.Reduction.MEAN):
      raise ValueError("Unsupported reduction: %s." % reduction)
-    if isinstance(opt, optimizer_v2.OptimizerV2):
+    if not isinstance(opt, optimizer.Optimizer):
       raise TypeError(
-          "CrossShardOptimizer does not work with OptimizerV2. If you are "
-          "using TPUStrategy, OptimizerV2 will sum gradients across replicas."
+          "CrossShardOptimizer only works with tf.compat.v1.train.Optimizer "
+          "and not OptimizerV2. If you are using TPUStrategy, OptimizerV2 "
+          "will sum gradients across replicas. "
           "If you are using TPUEstimator, you may instead sum your gradients "
           "with: grads = [tf.compat.v1.tpu.cross_replica_sum(g) for g in grads]"
           ". If you want to average your gradients, rescale your loss with: "
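
--
Note: a minimal sketch (not part of the patch) of why the check was inverted,
assuming a TF 1.x compat environment where tf.keras optimizers are OptimizerV2;
the variable names are illustrative:

    import tensorflow.compat.v1 as tf
    tf.disable_v2_behavior()

    v1_opt = tf.train.GradientDescentOptimizer(learning_rate=0.1)
    v2_opt = tf.keras.optimizers.SGD(learning_rate=0.1)  # an OptimizerV2

    # OptimizerV2 does not inherit from the v1 Optimizer base class, so a
    # positive v1 isinstance check rejects both OptimizerV2 instances and
    # arbitrary non-optimizer objects, which the old negative check missed.
    print(isinstance(v1_opt, tf.train.Optimizer))  # True  -> accepted
    print(isinstance(v2_opt, tf.train.Optimizer))  # False -> TypeError

    tf.tpu.CrossShardOptimizer(v1_opt)    # constructs fine
    # tf.tpu.CrossShardOptimizer(v2_opt)  # would raise the TypeError above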