Tweak layer a bit to be eager friendly.

PiperOrigin-RevId: 168312865
This commit is contained in:
A. Unique TensorFlower 2017-09-11 17:47:06 -07:00 committed by TensorFlower Gardener
parent 60f15462be
commit 9f848734fc

View File

@ -28,6 +28,7 @@ from tensorflow.contrib.framework.python.ops import add_arg_scope
from tensorflow.contrib.framework.python.ops import variables from tensorflow.contrib.framework.python.ops import variables
from tensorflow.contrib.layers.python.layers import initializers from tensorflow.contrib.layers.python.layers import initializers
from tensorflow.contrib.layers.python.layers import utils from tensorflow.contrib.layers.python.layers import utils
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function from tensorflow.python.framework import function
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
@ -2583,7 +2584,8 @@ def softmax(logits, scope=None):
logits_2d = array_ops.reshape(logits, [-1, num_logits]) logits_2d = array_ops.reshape(logits, [-1, num_logits])
predictions = nn.softmax(logits_2d) predictions = nn.softmax(logits_2d)
predictions = array_ops.reshape(predictions, array_ops.shape(logits)) predictions = array_ops.reshape(predictions, array_ops.shape(logits))
predictions.set_shape(logits.get_shape()) if context.in_graph_mode():
predictions.set_shape(logits.get_shape())
return predictions return predictions