When not necessary, avoid the creation of a placeholder_with_default
in BN (not yet supported by TPU compilation).
PiperOrigin-RevId: 192502020
This commit is contained in:
parent
371d5132a5
commit
744a5cc092
@ -489,6 +489,7 @@ class BatchNormalization(Layer):
|
|||||||
return (r, d, new_mean, new_variance)
|
return (r, d, new_mean, new_variance)
|
||||||
|
|
||||||
def call(self, inputs, training=None):
|
def call(self, inputs, training=None):
|
||||||
|
original_training_value = training
|
||||||
if training is None:
|
if training is None:
|
||||||
training = K.learning_phase()
|
training = K.learning_phase()
|
||||||
|
|
||||||
@ -512,7 +513,7 @@ class BatchNormalization(Layer):
|
|||||||
# Currently never reaches here since fused_batch_norm does not support
|
# Currently never reaches here since fused_batch_norm does not support
|
||||||
# virtual batching
|
# virtual batching
|
||||||
outputs = undo_virtual_batching(outputs)
|
outputs = undo_virtual_batching(outputs)
|
||||||
if not context.executing_eagerly() and training is K.learning_phase():
|
if not context.executing_eagerly() and original_training_value is None:
|
||||||
outputs._uses_learning_phase = True # pylint: disable=protected-access
|
outputs._uses_learning_phase = True # pylint: disable=protected-access
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
@ -628,7 +629,7 @@ class BatchNormalization(Layer):
|
|||||||
|
|
||||||
if self.virtual_batch_size is not None:
|
if self.virtual_batch_size is not None:
|
||||||
outputs = undo_virtual_batching(outputs)
|
outputs = undo_virtual_batching(outputs)
|
||||||
if not context.executing_eagerly() and training is K.learning_phase():
|
if not context.executing_eagerly() and original_training_value is None:
|
||||||
outputs._uses_learning_phase = True # pylint: disable=protected-access
|
outputs._uses_learning_phase = True # pylint: disable=protected-access
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user