From 7d0573f0d7b551257dc013a98eaaead5f509c4d0 Mon Sep 17 00:00:00 2001 From: Ben Lee <blee@google.com> Date: Thu, 17 Nov 2016 11:15:51 -0800 Subject: [PATCH] Remove spurious unpack/shape from fully_connected if not needed. Change: 139483354 --- tensorflow/contrib/layers/python/layers/layers.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index 633d4f9c06e..6b2bdb9970b 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -1296,9 +1296,6 @@ def fully_connected(inputs, static_shape = inputs_shape.as_list() static_shape[-1] = num_outputs - out_shape = array_ops.unpack(array_ops.shape(inputs)) - out_shape[-1] = num_outputs - weights_shape = [num_input_units, num_outputs] weights_collections = utils.get_variable_collections( variables_collections, 'weights') @@ -1310,6 +1307,8 @@ def fully_connected(inputs, collections=weights_collections, trainable=trainable) if len(static_shape) > 2: + out_shape = array_ops.unpack(array_ops.shape(inputs)) + out_shape[-1] = num_outputs # Reshape inputs inputs = array_ops.reshape(inputs, [-1, num_input_units]) outputs = standard_ops.matmul(inputs, weights)