Remove spurious unpack/shape from fully_connected if not needed.
Change: 139483354
This commit is contained in:
parent
adf0dc7c15
commit
7d0573f0d7
@ -1296,9 +1296,6 @@ def fully_connected(inputs,
|
||||
static_shape = inputs_shape.as_list()
|
||||
static_shape[-1] = num_outputs
|
||||
|
||||
out_shape = array_ops.unpack(array_ops.shape(inputs))
|
||||
out_shape[-1] = num_outputs
|
||||
|
||||
weights_shape = [num_input_units, num_outputs]
|
||||
weights_collections = utils.get_variable_collections(
|
||||
variables_collections, 'weights')
|
||||
@ -1310,6 +1307,8 @@ def fully_connected(inputs,
|
||||
collections=weights_collections,
|
||||
trainable=trainable)
|
||||
if len(static_shape) > 2:
|
||||
out_shape = array_ops.unpack(array_ops.shape(inputs))
|
||||
out_shape[-1] = num_outputs
|
||||
# Reshape inputs
|
||||
inputs = array_ops.reshape(inputs, [-1, num_input_units])
|
||||
outputs = standard_ops.matmul(inputs, weights)
|
||||
|
Loading…
Reference in New Issue
Block a user