Fix performance regression involving trainable check in batchnorm.

The regression caused a tf `and` op to be used for something that only required a Python check, which in turn made an `if` statement build a tf.cond instead of a Python `if`. This change uses a plain Python `if` instead.

PiperOrigin-RevId: 312345759
Change-Id: I568c9c992287bfc3e693f34b7b51bd7f35388f34
commit 935c55c590
parent 82143c1ad8
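A minimal sketch (not part of the commit) of the mechanism, using the public tf.logical_and that backs math_ops.logical_and: the op always returns a Tensor even for Python-bool inputs, so a later `if` on the result must be traced as a tf.cond inside a tf.function rather than resolved in Python at trace time.

    import tensorflow as tf

    training = True    # Python bool passed by the caller
    trainable = True   # layer.trainable is a plain Python attribute

    # Old path: the result is a Tensor, not a Python bool, so any
    # subsequent `if training:` inside a tf.function builds a tf.cond.
    combined = tf.logical_and(training, trainable)
    print(type(combined))   # a tf.Tensor (EagerTensor when run eagerly)

    # New path: a Python check keeps `training` a plain bool, so later
    # `if training:` checks are resolved in Python during tracing.
    if not trainable:
        training = False
    print(type(training))   # <class 'bool'>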
@@ -712,9 +712,10 @@ class BatchNormalizationBase(Layer):
     if self._USE_V2_BEHAVIOR:
       if isinstance(training, int):
         training = bool(training)
-      # When the layer is not trainable, it overrides the value passed from
-      # model.
-      training = math_ops.logical_and(training, self.trainable)
+      if not self.trainable:
+        # When the layer is not trainable, it overrides the value passed from
+        # model.
+        training = False
     return training

   def call(self, inputs, training=None):
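The Python check is equivalent for the case the old code handled, and it stays safe even when `training` arrives as a symbolic tensor, because the new branch condition inspects only self.trainable, which is always a Python bool. A small sketch of the equivalence (resolve_training is a hypothetical helper, not from the commit):

    def resolve_training(training, trainable):
        # Mirrors the patched logic: branch only on the Python bool
        # `trainable`; never wrap `training` in a graph op.
        if not trainable:
            training = False
        return training

    # Matches logical_and(training, trainable) for Python bools...
    assert resolve_training(True, True) is True
    assert resolve_training(True, False) is False
    assert resolve_training(False, True) is False
    # ...and leaves a symbolic `training` tensor untouched when
    # trainable is True, instead of emitting a LogicalAnd op.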