Update to use tf.shape to get the shape of the tensor, from review comment
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
parent
d5598f1609
commit
1f44a4cd05
@ -736,10 +736,13 @@ class BatchNormalizationBase(Layer):
|
|||||||
if self.virtual_batch_size is not None:
|
if self.virtual_batch_size is not None:
|
||||||
# Virtual batches (aka ghost batches) can be simulated by reshaping the
|
# Virtual batches (aka ghost batches) can be simulated by reshaping the
|
||||||
# Tensor and reusing the existing batch norm implementation
|
# Tensor and reusing the existing batch norm implementation
|
||||||
original_shape = [
|
original_shape = array_ops.shape(inputs)
|
||||||
d if d is not None else -1 for d in inputs.shape.as_list()]
|
original_shape = array_ops.concat([
|
||||||
original_shape = [-1] + original_shape[1:]
|
constant_op.constant([-1]),
|
||||||
expanded_shape = [self.virtual_batch_size, -1] + original_shape[1:]
|
original_shape[1:]], axis=0)
|
||||||
|
expanded_shape = array_ops.concat([
|
||||||
|
constant_op.constant([self.virtual_batch_size, -1]),
|
||||||
|
original_shape[1:]], axis=0)
|
||||||
|
|
||||||
# Will cause errors if virtual_batch_size does not divide the batch size
|
# Will cause errors if virtual_batch_size does not divide the batch size
|
||||||
inputs = array_ops.reshape(inputs, expanded_shape)
|
inputs = array_ops.reshape(inputs, expanded_shape)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user