Eliminate the tf.where call used when updating the running mean and variance for batch norm in Keras. This unlocks the speedup promised by fusing the moving-average updates into the CPU and GPU kernels.

Results for ResNet50 with batch size 32 in eager mode on a GTX 1080:
Before: 85 images/s (fp32), 81 images/s (fp16)
After: 101 images/s (fp32), 99 images/s (fp16)

PiperOrigin-RevId: 298927568
Change-Id: I2941bff5d19c7fdccad78bbb9c5df2fdcd2fc36a
parent 8f399978bd
commit f7d4c7ffd5
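For context, a minimal sketch of the update pattern this change removes versus the one it enables; this is not the Keras source, and moving_mean, batch_mean, and inputs_size are illustrative stand-ins for the layer's state:

import tensorflow as tf

momentum = 0.99
moving_mean = tf.Variable([0.0, 0.0])
batch_mean = tf.constant([1.0, 2.0])   # statistics of the current batch
inputs_size = tf.constant(32)          # element count of the current batch

# Before: every moving-statistic update was routed through tf.where to guard
# the empty-batch case, adding an elementwise select on the training hot path.
new_value = moving_mean * momentum + batch_mean * (1.0 - momentum)
moving_mean.assign(tf.where(inputs_size > 0, new_value, moving_mean))

# After: fused_batch_norm blends the moving average in-kernel (via its
# exponential_avg_factor argument), so the variable update is a plain assign;
# the empty-batch case is handled by one cond around the whole training op.
moving_mean.assign(batch_mean)  # batch_mean arrives already blended in-kernel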
@@ -514,11 +514,9 @@ class BatchNormalizationBase(Layer):
                                          K.zeros_like(update_delta))
         return state_ops.assign_sub(variable, update_delta, name=scope)
 
-  def _assign_new_value(self, variable, value, inputs_size=None):
+  def _assign_new_value(self, variable, value):
     with K.name_scope('AssignNewValue') as scope:
       with ops.colocate_with(variable):
-        if inputs_size is not None:
-          value = array_ops.where(inputs_size > 0, value, variable)
         return state_ops.assign(variable, value, name=scope)
 
   def _fused_batch_norm(self, inputs, training):
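With the guard removed, the helper reduces to a bare assignment. A rough standalone equivalent, assuming a tf.Variable target (the real method also pins the op to the variable's device via ops.colocate_with and runs under a Keras name scope):

import tensorflow as tf

def assign_new_value(variable, value):
  # Plain write: no tf.where and no data-dependent select on the update path.
  return variable.assign(value, name='AssignNewValue')

var = tf.Variable([0.0, 0.0])
assign_new_value(var, tf.constant([1.0, 2.0]))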
@@ -569,6 +567,9 @@ class BatchNormalizationBase(Layer):
           data_format=self._data_format,
           exponential_avg_factor=exponential_avg_factor)
 
+    def _fused_batch_norm_training_empty():
+      return inputs, self.moving_mean, self.moving_variance
+
     def _fused_batch_norm_inference():
       return nn.fused_batch_norm(
           inputs,
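The new _fused_batch_norm_training_empty branch exists because batch statistics are undefined for a zero-element batch. A hypothetical standalone illustration (the variables stand in for the layer's state):

import tensorflow as tf

inputs = tf.zeros([0, 4])  # an empty batch: no rows to compute stats from
moving_mean = tf.Variable(tf.zeros([4]))
moving_variance = tf.Variable(tf.ones([4]))

def training_empty():
  # Mirrors the branch above: pass inputs through untouched and report the
  # existing moving statistics as this step's stats, leaving them unchanged.
  return inputs, moving_mean.value(), moving_variance.value()

output, mean, variance = training_empty()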
@@ -580,8 +581,14 @@ class BatchNormalizationBase(Layer):
           is_training=False,
           data_format=self._data_format)
 
-    output, mean, variance = tf_utils.smart_cond(
-        training, _fused_batch_norm_training, _fused_batch_norm_inference)
+    train_op = _fused_batch_norm_training
+    if compat.forward_compatible(2020, 3, 6) and inputs_size is not None:
+      train_op = lambda: tf_utils.smart_cond(inputs_size > 0,
+                                             _fused_batch_norm_training,
+                                             _fused_batch_norm_training_empty)
+
+    output, mean, variance = tf_utils.smart_cond(training, train_op,
+                                                 _fused_batch_norm_inference)
     variance = _maybe_add_or_remove_bessels_correction(variance, remove=True)
 
     training_value = tf_utils.constant_value(training)
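This hoists the size check out of a per-element tf.where and into a single cond around the whole training op. A sketch of the same dispatch with public TF APIs (the internal tf_utils.smart_cond behaves like tf.cond but collapses to a direct Python call when the predicate is a static bool):

import tensorflow as tf

def train_branch():
  return tf.constant(1.0)  # stands in for the fused training op

def train_empty_branch():
  return tf.constant(0.0)  # stands in for the empty-batch fallback

inputs_size = tf.constant(32)
# One cond per step replaces a tf.where inside every moving-average update.
result = tf.cond(inputs_size > 0, train_branch, train_empty_branch)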
@@ -596,7 +603,7 @@ class BatchNormalizationBase(Layer):
       def mean_update():
         """Update self.moving_mean with the most recent data point."""
         if compat.forward_compatible(2020, 3, 6):
-          return self._assign_new_value(self.moving_mean, mean, inputs_size)
+          return self._assign_new_value(self.moving_mean, mean)
         else:
           return self._assign_moving_average(self.moving_mean, mean, momentum,
                                              inputs_size)
@@ -604,8 +611,7 @@ class BatchNormalizationBase(Layer):
       def variance_update():
         """Update self.moving_variance with the most recent data point."""
         if compat.forward_compatible(2020, 3, 6):
-          return self._assign_new_value(self.moving_variance, variance,
-                                        inputs_size)
+          return self._assign_new_value(self.moving_variance, variance)
         else:
           return self._assign_moving_average(self.moving_variance, variance,
                                              momentum, inputs_size)
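Both update functions gate the plain-assign path behind compat.forward_compatible(2020, 3, 6), which returns True once the binary's forward-compatibility horizon passes that date; until then the old _assign_moving_average path is kept. A sketch using the public alias (the string values are illustrative only):

import tensorflow as tf

# True once the forward-compatibility window covers 2020-03-06, i.e. once it
# is safe to assume consumers understand the exponential_avg_factor attribute
# of the fused batch-norm op.
if tf.compat.forward_compatible(2020, 3, 6):
  path = 'plain assign; moving average computed inside the fused kernel'
else:
  path = 'explicit _assign_moving_average update in Python'
print(path)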