From 9f848734fc8e10cf7b99121c2ccbed9249665ae8 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 11 Sep 2017 17:47:06 -0700 Subject: [PATCH] Tweak layer a bit to be eager friendly. PiperOrigin-RevId: 168312865 --- tensorflow/contrib/layers/python/layers/layers.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index 33c31262664..e90793ba333 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -28,6 +28,7 @@ from tensorflow.contrib.framework.python.ops import add_arg_scope from tensorflow.contrib.framework.python.ops import variables from tensorflow.contrib.layers.python.layers import initializers from tensorflow.contrib.layers.python.layers import utils +from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import function from tensorflow.python.framework import ops @@ -2583,7 +2584,8 @@ def softmax(logits, scope=None): logits_2d = array_ops.reshape(logits, [-1, num_logits]) predictions = nn.softmax(logits_2d) predictions = array_ops.reshape(predictions, array_ops.shape(logits)) - predictions.set_shape(logits.get_shape()) + if context.in_graph_mode(): + predictions.set_shape(logits.get_shape()) return predictions