Conv2DTranspose uses constant height and width when possible.
PiperOrigin-RevId: 306335881 Change-Id: I23f790e6e123e0c1318e3834c830fff646d1876c
This commit is contained in:
parent
3af55c9619
commit
a70f444d99
@ -1182,7 +1182,18 @@ class Conv2DTranspose(Conv2D):
|
|||||||
else:
|
else:
|
||||||
h_axis, w_axis = 1, 2
|
h_axis, w_axis = 1, 2
|
||||||
|
|
||||||
height, width = inputs_shape[h_axis], inputs_shape[w_axis]
|
# Use the constant height and weight when possible.
|
||||||
|
# TODO(scottzhu): Extract this into a utility function that can be applied
|
||||||
|
# to all convolutional layers, which currently lost the static shape
|
||||||
|
# information due to tf.shape().
|
||||||
|
height, width = None, None
|
||||||
|
if inputs.shape.rank is not None:
|
||||||
|
dims = inputs.shape.as_list()
|
||||||
|
height = dims[h_axis]
|
||||||
|
width = dims[w_axis]
|
||||||
|
height = height if height is not None else inputs_shape[h_axis]
|
||||||
|
width = width if width is not None else inputs_shape[w_axis]
|
||||||
|
|
||||||
kernel_h, kernel_w = self.kernel_size
|
kernel_h, kernel_w = self.kernel_size
|
||||||
stride_h, stride_w = self.strides
|
stride_h, stride_w = self.strides
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user