Conv2DTranspose uses constant height and width when possible.
PiperOrigin-RevId: 306335881 Change-Id: I23f790e6e123e0c1318e3834c830fff646d1876c
This commit is contained in:
parent
3af55c9619
commit
a70f444d99
@ -1182,7 +1182,18 @@ class Conv2DTranspose(Conv2D):
|
||||
else:
|
||||
h_axis, w_axis = 1, 2
|
||||
|
||||
height, width = inputs_shape[h_axis], inputs_shape[w_axis]
|
||||
# Use the constant height and weight when possible.
|
||||
# TODO(scottzhu): Extract this into a utility function that can be applied
|
||||
# to all convolutional layers, which currently lost the static shape
|
||||
# information due to tf.shape().
|
||||
height, width = None, None
|
||||
if inputs.shape.rank is not None:
|
||||
dims = inputs.shape.as_list()
|
||||
height = dims[h_axis]
|
||||
width = dims[w_axis]
|
||||
height = height if height is not None else inputs_shape[h_axis]
|
||||
width = width if width is not None else inputs_shape[w_axis]
|
||||
|
||||
kernel_h, kernel_w = self.kernel_size
|
||||
stride_h, stride_w = self.strides
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user