When not necessary, avoid the creation of a placeholder_with_default in BN (not yet supported by TPU compilation).

PiperOrigin-RevId: 192502020
This commit is contained in:
Francois Chollet 2018-04-11 13:46:03 -07:00 committed by TensorFlower Gardener
parent 371d5132a5
commit 744a5cc092

View File

@ -489,6 +489,7 @@ class BatchNormalization(Layer):
return (r, d, new_mean, new_variance)
def call(self, inputs, training=None):
original_training_value = training
if training is None:
training = K.learning_phase()
@ -512,7 +513,7 @@ class BatchNormalization(Layer):
# Currently never reaches here since fused_batch_norm does not support
# virtual batching
outputs = undo_virtual_batching(outputs)
if not context.executing_eagerly() and training is K.learning_phase():
if not context.executing_eagerly() and original_training_value is None:
outputs._uses_learning_phase = True # pylint: disable=protected-access
return outputs
@ -628,7 +629,7 @@ class BatchNormalization(Layer):
if self.virtual_batch_size is not None:
outputs = undo_virtual_batching(outputs)
if not context.executing_eagerly() and training is K.learning_phase():
if not context.executing_eagerly() and original_training_value is None:
outputs._uses_learning_phase = True # pylint: disable=protected-access
return outputs