Fix performance regression involving trainable check in batchnorm.
The regression made it so a tf `and` op would be used for something that just required a python check, which in turn would make an if statement build a tf.cond instead of a python if. This change makes it just use a python if. PiperOrigin-RevId: 312345759 Change-Id: I568c9c992287bfc3e693f34b7b51bd7f35388f34
This commit is contained in:
parent
82143c1ad8
commit
935c55c590
@ -712,9 +712,10 @@ class BatchNormalizationBase(Layer):
|
|||||||
if self._USE_V2_BEHAVIOR:
|
if self._USE_V2_BEHAVIOR:
|
||||||
if isinstance(training, int):
|
if isinstance(training, int):
|
||||||
training = bool(training)
|
training = bool(training)
|
||||||
# When the layer is not trainable, it overrides the value passed from
|
if not self.trainable:
|
||||||
# model.
|
# When the layer is not trainable, it overrides the value passed from
|
||||||
training = math_ops.logical_and(training, self.trainable)
|
# model.
|
||||||
|
training = False
|
||||||
return training
|
return training
|
||||||
|
|
||||||
def call(self, inputs, training=None):
|
def call(self, inputs, training=None):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user