Remove the optimizer v2 instance check and replace it with a v1 object check.
An Optimizer_v2 instance doesn't share the same base class as the v1 optimizer, so checking `not isinstance(opt, optimizer.Optimizer)` rejects it just as well, without needing to import optimizer_v2 at all.

PiperOrigin-RevId: 302719205
Change-Id: I225c891e59e93d3b5386fa3682fb1bdc97e770de
This commit is contained in:
parent 7ae1992664
commit 8ed969c39d
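The rationale is easy to demonstrate: the Keras OptimizerV2 class is not a subclass of the v1 base class tf.compat.v1.train.Optimizer, so inverting the check rejects every non-v1 object rather than only the one known v2 type. A minimal sketch, not part of this commit, assuming the TF compat endpoints:

# Minimal sketch: OptimizerV2 does not share the v1 Optimizer base class,
# which is why the inverted check below is sufficient.
import tensorflow as tf

v1_opt = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=0.01)
v2_opt = tf.keras.optimizers.SGD(learning_rate=0.01)  # an OptimizerV2

print(isinstance(v1_opt, tf.compat.v1.train.Optimizer))  # True
print(isinstance(v2_opt, tf.compat.v1.train.Optimizer))  # False

# Old guard: `isinstance(opt, optimizer_v2.OptimizerV2)` rejected only the
# one known v2 class. New guard: `not isinstance(opt, optimizer.Optimizer)`
# rejects OptimizerV2 and any other object that is not a v1 optimizer.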
@@ -21,7 +21,6 @@ from __future__ import print_function
 from tensorflow.python.framework import ops
-from tensorflow.python.keras.optimizer_v2 import optimizer_v2
 from tensorflow.python.ops.losses import losses
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.tpu import tpu_function
@@ -55,10 +54,11 @@ class CrossShardOptimizer(optimizer.Optimizer):
     """
     if reduction not in (losses.Reduction.SUM, losses.Reduction.MEAN):
       raise ValueError("Unsupported reduction: %s." % reduction)
-    if isinstance(opt, optimizer_v2.OptimizerV2):
+    if not isinstance(opt, optimizer.Optimizer):
       raise TypeError(
-          "CrossShardOptimizer does not work with OptimizerV2. If you are "
-          "using TPUStrategy, OptimizerV2 will sum gradients across replicas."
+          "CrossShardOptimizer only works with tf.training.Optimizer and not "
+          "Optimizer_v2. If you are using TPUStrategy, OptimizerV2 will sum "
+          "gradients across replicas."
           "If you are using TPUEstimator, you may instead sum your gradients "
           "with: grads = [tf.compat.v1.tpu.cross_replica_sum(g) for g in grads]"
           ". If you want to average your gradients, rescale your loss with: "
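For reference, a hypothetical graph-mode fragment of the workaround the new error message suggests for TPUEstimator users with an OptimizerV2; the function name and parameters are illustrative, not from this commit:

# Hypothetical sketch: sum gradients across replicas manually instead of
# wrapping an OptimizerV2 in CrossShardOptimizer.
import tensorflow as tf

def make_train_op(loss, variables, opt, global_batch_size):
  # To average instead of sum, rescale the loss before differentiating,
  # since cross_replica_sum adds the per-replica gradients together.
  loss = loss / global_batch_size
  grads = tf.gradients(loss, variables)  # v1 graph-mode gradients
  grads = [tf.compat.v1.tpu.cross_replica_sum(g) for g in grads]
  return opt.apply_gradients(zip(grads, variables))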