Tweak layer a bit to be eager friendly.
PiperOrigin-RevId: 168312865
This commit is contained in:
parent
60f15462be
commit
9f848734fc
@ -28,6 +28,7 @@ from tensorflow.contrib.framework.python.ops import add_arg_scope
|
||||
from tensorflow.contrib.framework.python.ops import variables
|
||||
from tensorflow.contrib.layers.python.layers import initializers
|
||||
from tensorflow.contrib.layers.python.layers import utils
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import function
|
||||
from tensorflow.python.framework import ops
|
||||
@ -2583,7 +2584,8 @@ def softmax(logits, scope=None):
|
||||
logits_2d = array_ops.reshape(logits, [-1, num_logits])
|
||||
predictions = nn.softmax(logits_2d)
|
||||
predictions = array_ops.reshape(predictions, array_ops.shape(logits))
|
||||
predictions.set_shape(logits.get_shape())
|
||||
if context.in_graph_mode():
|
||||
predictions.set_shape(logits.get_shape())
|
||||
return predictions
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user