Conv2DTranspose uses constant height and width when possible.

PiperOrigin-RevId: 306335881
Change-Id: I23f790e6e123e0c1318e3834c830fff646d1876c
This commit is contained in:
Jingyue Wu 2020-04-13 16:43:25 -07:00 committed by TensorFlower Gardener
parent 3af55c9619
commit a70f444d99

View File

@ -1182,7 +1182,18 @@ class Conv2DTranspose(Conv2D):
else: else:
h_axis, w_axis = 1, 2 h_axis, w_axis = 1, 2
height, width = inputs_shape[h_axis], inputs_shape[w_axis] # Use the constant height and weight when possible.
# TODO(scottzhu): Extract this into a utility function that can be applied
# to all convolutional layers, which currently lost the static shape
# information due to tf.shape().
height, width = None, None
if inputs.shape.rank is not None:
dims = inputs.shape.as_list()
height = dims[h_axis]
width = dims[w_axis]
height = height if height is not None else inputs_shape[h_axis]
width = width if width is not None else inputs_shape[w_axis]
kernel_h, kernel_w = self.kernel_size kernel_h, kernel_w = self.kernel_size
stride_h, stride_w = self.strides stride_h, stride_w = self.strides