From 7d0573f0d7b551257dc013a98eaaead5f509c4d0 Mon Sep 17 00:00:00 2001
From: Ben Lee <blee@google.com>
Date: Thu, 17 Nov 2016 11:15:51 -0800
Subject: [PATCH] Remove spurious unpack/shape from fully_connected if not
 needed. Change: 139483354

---
 tensorflow/contrib/layers/python/layers/layers.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py
index 633d4f9c06e..6b2bdb9970b 100644
--- a/tensorflow/contrib/layers/python/layers/layers.py
+++ b/tensorflow/contrib/layers/python/layers/layers.py
@@ -1296,9 +1296,6 @@ def fully_connected(inputs,
     static_shape = inputs_shape.as_list()
     static_shape[-1] = num_outputs
 
-    out_shape = array_ops.unpack(array_ops.shape(inputs))
-    out_shape[-1] = num_outputs
-
     weights_shape = [num_input_units, num_outputs]
     weights_collections = utils.get_variable_collections(
         variables_collections, 'weights')
@@ -1310,6 +1307,8 @@ def fully_connected(inputs,
                                        collections=weights_collections,
                                        trainable=trainable)
     if len(static_shape) > 2:
+      out_shape = array_ops.unpack(array_ops.shape(inputs))
+      out_shape[-1] = num_outputs
       # Reshape inputs
       inputs = array_ops.reshape(inputs, [-1, num_input_units])
     outputs = standard_ops.matmul(inputs, weights)