Remove spurious unpack/shape from fully_connected if not needed.
Change: 139483354
This commit is contained in:
parent
adf0dc7c15
commit
7d0573f0d7
@ -1296,9 +1296,6 @@ def fully_connected(inputs,
|
|||||||
static_shape = inputs_shape.as_list()
|
static_shape = inputs_shape.as_list()
|
||||||
static_shape[-1] = num_outputs
|
static_shape[-1] = num_outputs
|
||||||
|
|
||||||
out_shape = array_ops.unpack(array_ops.shape(inputs))
|
|
||||||
out_shape[-1] = num_outputs
|
|
||||||
|
|
||||||
weights_shape = [num_input_units, num_outputs]
|
weights_shape = [num_input_units, num_outputs]
|
||||||
weights_collections = utils.get_variable_collections(
|
weights_collections = utils.get_variable_collections(
|
||||||
variables_collections, 'weights')
|
variables_collections, 'weights')
|
||||||
@ -1310,6 +1307,8 @@ def fully_connected(inputs,
|
|||||||
collections=weights_collections,
|
collections=weights_collections,
|
||||||
trainable=trainable)
|
trainable=trainable)
|
||||||
if len(static_shape) > 2:
|
if len(static_shape) > 2:
|
||||||
|
out_shape = array_ops.unpack(array_ops.shape(inputs))
|
||||||
|
out_shape[-1] = num_outputs
|
||||||
# Reshape inputs
|
# Reshape inputs
|
||||||
inputs = array_ops.reshape(inputs, [-1, num_input_units])
|
inputs = array_ops.reshape(inputs, [-1, num_input_units])
|
||||||
outputs = standard_ops.matmul(inputs, weights)
|
outputs = standard_ops.matmul(inputs, weights)
|
||||||
|
Loading…
Reference in New Issue
Block a user