Remove spurious unpack/shape from fully_connected if not needed.

Change: 139483354
This commit is contained in:
Ben Lee 2016-11-17 11:15:51 -08:00 committed by TensorFlower Gardener
parent adf0dc7c15
commit 7d0573f0d7

View File

@ -1296,9 +1296,6 @@ def fully_connected(inputs,
static_shape = inputs_shape.as_list()
static_shape[-1] = num_outputs
out_shape = array_ops.unpack(array_ops.shape(inputs))
out_shape[-1] = num_outputs
weights_shape = [num_input_units, num_outputs]
weights_collections = utils.get_variable_collections(
variables_collections, 'weights')
@ -1310,6 +1307,8 @@ def fully_connected(inputs,
collections=weights_collections,
trainable=trainable)
if len(static_shape) > 2:
out_shape = array_ops.unpack(array_ops.shape(inputs))
out_shape[-1] = num_outputs
# Reshape inputs
inputs = array_ops.reshape(inputs, [-1, num_input_units])
outputs = standard_ops.matmul(inputs, weights)