Avoid creating a placeholder_with_default in BatchNormalization when it is not
necessary, since placeholder_with_default is not yet supported by TPU compilation.
PiperOrigin-RevId: 192502020
Commit: 744a5cc092
Parent: 371d5132a5
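
The motivation, in code: in graph mode, Keras' K.learning_phase() is backed by a
placeholder_with_default op, and merely evaluating it instantiates that op in the
graph. The sketch below is a hand-written illustration using the TF 1.x API (not
the TF source); the default-False learning-phase placeholder mirrors what the
Keras backend creates lazily:

import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
  # A scalar boolean placeholder that defaults to False (inference mode)
  # unless explicitly fed -- the shape of op K.learning_phase() produces.
  learning_phase = tf.placeholder_with_default(
      False, shape=(), name='keras_learning_phase')
  print(learning_phase.op.type)  # PlaceholderWithDefault

TPU compilation at the time of this commit could not handle
PlaceholderWithDefault, so BatchNormalization should only reach for the learning
phase when the caller leaves training unset.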
@@ -489,6 +489,7 @@ class BatchNormalization(Layer):
     return (r, d, new_mean, new_variance)
 
   def call(self, inputs, training=None):
+    original_training_value = training
     if training is None:
       training = K.learning_phase()
 
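The first hunk records whether the caller passed an explicit training value
before the None fallback overwrites it. A simplified sketch of the pattern (the
resolve_training helper is hypothetical, not the TF source):

from tensorflow.keras import backend as K

def resolve_training(training):
  # Remember the caller's original argument before any fallback.
  original_training_value = training
  if training is None:
    # Fall back to the global learning phase only when necessary.
    training = K.learning_phase()
  return training, original_training_value

training, original = resolve_training(True)
assert original is not None  # explicit caller: no learning-phase dependency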
@@ -512,7 +513,7 @@ class BatchNormalization(Layer):
       # Currently never reaches here since fused_batch_norm does not support
       # virtual batching
       outputs = undo_virtual_batching(outputs)
-    if not context.executing_eagerly() and training is K.learning_phase():
+    if not context.executing_eagerly() and original_training_value is None:
       outputs._uses_learning_phase = True  # pylint: disable=protected-access
     return outputs
 
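The second hunk (and the identical change in the next one, on the non-fused
path) is where the placeholder is actually avoided: the old check compares
training against K.learning_phase() by identity, which calls K.learning_phase()
unconditionally and thereby creates the placeholder_with_default even when the
caller passed an explicit training value. Checking original_training_value is
None needs no backend call at all. A side-by-side sketch (hypothetical function
names, TF 1.x assumed):

from tensorflow.keras import backend as K

def uses_learning_phase_old(training):
  # Calling K.learning_phase() here instantiates the placeholder_with_default
  # as a side effect, even when `training` was passed explicitly.
  return training is K.learning_phase()

def uses_learning_phase_new(original_training_value):
  # Pure Python check; never touches the backend graph.
  return original_training_value is None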
@@ -628,7 +629,7 @@ class BatchNormalization(Layer):
 
     if self.virtual_batch_size is not None:
       outputs = undo_virtual_batching(outputs)
-    if not context.executing_eagerly() and training is K.learning_phase():
+    if not context.executing_eagerly() and original_training_value is None:
       outputs._uses_learning_phase = True  # pylint: disable=protected-access
     return outputs