Merge pull request #39131 from yongtang:32380-BatchNormalization-virtual_batch_size-None
PiperOrigin-RevId: 309995310 Change-Id: I2d86a93834cabcc80f0f9a28d87985e9aa7b6c9c
This commit is contained in:
commit
9dce54f00b
@ -736,8 +736,14 @@ class BatchNormalizationBase(Layer):
|
|||||||
if self.virtual_batch_size is not None:
|
if self.virtual_batch_size is not None:
|
||||||
# Virtual batches (aka ghost batches) can be simulated by reshaping the
|
# Virtual batches (aka ghost batches) can be simulated by reshaping the
|
||||||
# Tensor and reusing the existing batch norm implementation
|
# Tensor and reusing the existing batch norm implementation
|
||||||
original_shape = [-1] + inputs.shape.as_list()[1:]
|
original_shape = array_ops.shape(inputs)
|
||||||
expanded_shape = [self.virtual_batch_size, -1] + original_shape[1:]
|
original_shape = array_ops.concat(
|
||||||
|
[constant_op.constant([-1]), original_shape[1:]], axis=0)
|
||||||
|
expanded_shape = array_ops.concat([
|
||||||
|
constant_op.constant([self.virtual_batch_size, -1]),
|
||||||
|
original_shape[1:]
|
||||||
|
],
|
||||||
|
axis=0)
|
||||||
|
|
||||||
# Will cause errors if virtual_batch_size does not divide the batch size
|
# Will cause errors if virtual_batch_size does not divide the batch size
|
||||||
inputs = array_ops.reshape(inputs, expanded_shape)
|
inputs = array_ops.reshape(inputs, expanded_shape)
|
||||||
|
@ -354,6 +354,13 @@ class BatchNormalizationV2Test(keras_parameterized.TestCase):
|
|||||||
# Updates should be tracked in a `wrap_function`.
|
# Updates should be tracked in a `wrap_function`.
|
||||||
self.assertLen(layer.updates, 2)
|
self.assertLen(layer.updates, 2)
|
||||||
|
|
||||||
|
@keras_parameterized.run_all_keras_modes
|
||||||
|
def test_basic_batchnorm_v2_none_shape_and_virtual_batch_size(self):
|
||||||
|
# Test case for GitHub issue for 32380
|
||||||
|
norm = normalization_v2.BatchNormalization(virtual_batch_size=8)
|
||||||
|
inp = keras.layers.Input(shape=(None, None, 3))
|
||||||
|
_ = norm(inp)
|
||||||
|
|
||||||
|
|
||||||
def _run_batchnorm_correctness_test(layer, dtype='float32', fused=False):
|
def _run_batchnorm_correctness_test(layer, dtype='float32', fused=False):
|
||||||
model = keras.models.Sequential()
|
model = keras.models.Sequential()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user