Tweak layer a bit to be eager friendly.
PiperOrigin-RevId: 168312865
This commit is contained in:
parent
60f15462be
commit
9f848734fc
@ -28,6 +28,7 @@ from tensorflow.contrib.framework.python.ops import add_arg_scope
|
|||||||
from tensorflow.contrib.framework.python.ops import variables
|
from tensorflow.contrib.framework.python.ops import variables
|
||||||
from tensorflow.contrib.layers.python.layers import initializers
|
from tensorflow.contrib.layers.python.layers import initializers
|
||||||
from tensorflow.contrib.layers.python.layers import utils
|
from tensorflow.contrib.layers.python.layers import utils
|
||||||
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import function
|
from tensorflow.python.framework import function
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
@ -2583,7 +2584,8 @@ def softmax(logits, scope=None):
|
|||||||
logits_2d = array_ops.reshape(logits, [-1, num_logits])
|
logits_2d = array_ops.reshape(logits, [-1, num_logits])
|
||||||
predictions = nn.softmax(logits_2d)
|
predictions = nn.softmax(logits_2d)
|
||||||
predictions = array_ops.reshape(predictions, array_ops.shape(logits))
|
predictions = array_ops.reshape(predictions, array_ops.shape(logits))
|
||||||
predictions.set_shape(logits.get_shape())
|
if context.in_graph_mode():
|
||||||
|
predictions.set_shape(logits.get_shape())
|
||||||
return predictions
|
return predictions
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user