Fix performance regression involving trainable check in batchnorm.

The regression made it so a tf `and` op would be used for something that just required a python check, which in turn would make an if statement build a tf.cond instead of a python if.

This change makes it just use a python if.

PiperOrigin-RevId: 312345759
Change-Id: I568c9c992287bfc3e693f34b7b51bd7f35388f34
This commit is contained in:
Tomer Kaftan 2020-05-19 13:48:55 -07:00 committed by TensorFlower Gardener
parent 82143c1ad8
commit 935c55c590

View File

@ -712,9 +712,10 @@ class BatchNormalizationBase(Layer):
if self._USE_V2_BEHAVIOR:
if isinstance(training, int):
training = bool(training)
# When the layer is not trainable, it overrides the value passed from
# model.
training = math_ops.logical_and(training, self.trainable)
if not self.trainable:
# When the layer is not trainable, it overrides the value passed from
# model.
training = False
return training
def call(self, inputs, training=None):