Change arg order for {softmax,sparse_softmax,sigmoid}_cross_entropy_with_logits to be (labels, predictions), and force use of named args to avoid accidents.
Change: 143629623
This commit is contained in:
parent
d9541696b0
commit
333dc32ff7
tensorflow
contrib
distributions/python/ops
layers/python/layers
learn/python/learn
legacy_seq2seq/python/ops
linear_optimizer/python/ops
losses/python/losses
nn/python/ops
tensor_forest/hybrid/python
examples
image_retraining
tutorials/mnist
udacity
g3doc
api_docs/python
how_tos
tutorials/mnist
python
kernel_tests
ops
training
tools/docker/notebooks
@ -147,7 +147,7 @@ class Bernoulli(distribution.Distribution):
|
||||
distribution_util.same_dynamic_shape(logits, event),
|
||||
lambda: (logits, event),
|
||||
lambda: broadcast(logits, event))
|
||||
return -nn.sigmoid_cross_entropy_with_logits(logits, event)
|
||||
return -nn.sigmoid_cross_entropy_with_logits(labels=event, logits=logits)
|
||||
|
||||
def _prob(self, event):
|
||||
return math_ops.exp(self._log_prob(event))
|
||||
|
@ -202,7 +202,8 @@ class Categorical(distribution.Distribution):
|
||||
logits_shape = array_ops.shape(logits)[:-1]
|
||||
k *= array_ops.ones(logits_shape, dtype=k.dtype)
|
||||
k.set_shape(tensor_shape.TensorShape(logits.get_shape()[:-1]))
|
||||
return -nn_ops.sparse_softmax_cross_entropy_with_logits(logits, k)
|
||||
return -nn_ops.sparse_softmax_cross_entropy_with_logits(labels=k,
|
||||
logits=logits)
|
||||
|
||||
def _prob(self, k):
|
||||
return math_ops.exp(self._log_prob(k))
|
||||
@ -214,7 +215,8 @@ class Categorical(distribution.Distribution):
|
||||
logits_2d = array_ops.reshape(self.logits, [-1, self.num_classes])
|
||||
histogram_2d = nn_ops.softmax(logits_2d)
|
||||
ret = array_ops.reshape(
|
||||
nn_ops.softmax_cross_entropy_with_logits(logits_2d, histogram_2d),
|
||||
nn_ops.softmax_cross_entropy_with_logits(labels=histogram_2d,
|
||||
logits=logits_2d),
|
||||
self.batch_shape())
|
||||
ret.set_shape(self.get_batch_shape())
|
||||
return ret
|
||||
|
@ -215,7 +215,8 @@ class _OneHotCategorical(distribution.Distribution):
|
||||
else:
|
||||
logits_2d = array_ops.reshape(logits, [-1, self.num_classes])
|
||||
x_2d = array_ops.reshape(x, [-1, self.num_classes])
|
||||
ret = -nn_ops.softmax_cross_entropy_with_logits(logits_2d, x_2d)
|
||||
ret = -nn_ops.softmax_cross_entropy_with_logits(labels=x_2d,
|
||||
logits=logits_2d)
|
||||
ret = array_ops.reshape(ret, logits_shape)
|
||||
return ret
|
||||
|
||||
@ -229,7 +230,8 @@ class _OneHotCategorical(distribution.Distribution):
|
||||
logits_2d = array_ops.reshape(self.logits, [-1, self.num_classes])
|
||||
histogram_2d = nn_ops.softmax(logits_2d)
|
||||
ret = array_ops.reshape(
|
||||
nn_ops.softmax_cross_entropy_with_logits(logits_2d, histogram_2d),
|
||||
nn_ops.softmax_cross_entropy_with_logits(labels=histogram_2d,
|
||||
logits=logits_2d),
|
||||
self.batch_shape())
|
||||
ret.set_shape(self.get_batch_shape())
|
||||
return ret
|
||||
|
@ -491,7 +491,8 @@ def weighted_sum_from_feature_columns(columns_to_tensors,
|
||||
columns_to_tensors=columns_to_tensor,
|
||||
feature_columns=feature_columns,
|
||||
num_outputs=1)
|
||||
loss = tf.nn.sigmoid_cross_entropy_with_logits(logits, labels)
|
||||
loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels,
|
||||
logits=logits)
|
||||
```
|
||||
|
||||
Args:
|
||||
|
@ -406,8 +406,8 @@ def _log_loss_with_two_classes(logits, target):
|
||||
# sigmoid_cross_entropy_with_logits requires [batch_size, 1] target.
|
||||
if len(target.get_shape()) == 1:
|
||||
target = array_ops.expand_dims(target, dim=[1])
|
||||
loss_vec = nn.sigmoid_cross_entropy_with_logits(logits,
|
||||
math_ops.to_float(target))
|
||||
loss_vec = nn.sigmoid_cross_entropy_with_logits(
|
||||
labels=math_ops.to_float(target), logits=logits)
|
||||
return loss_vec
|
||||
|
||||
|
||||
@ -419,7 +419,8 @@ def _softmax_cross_entropy_loss(logits, target):
|
||||
# sparse_softmax_cross_entropy_with_logits requires [batch_size] target.
|
||||
if len(target.get_shape()) == 2:
|
||||
target = array_ops.squeeze(target, squeeze_dims=[1])
|
||||
loss_vec = nn.sparse_softmax_cross_entropy_with_logits(logits, target)
|
||||
loss_vec = nn.sparse_softmax_cross_entropy_with_logits(
|
||||
labels=target, logits=logits)
|
||||
return loss_vec
|
||||
|
||||
|
||||
|
@ -466,7 +466,7 @@ def _log_loss_with_two_classes(logits, labels):
|
||||
if len(labels.get_shape()) == 1:
|
||||
labels = array_ops.expand_dims(labels, dim=(1,))
|
||||
return nn.sigmoid_cross_entropy_with_logits(
|
||||
logits, math_ops.to_float(labels), name=name)
|
||||
labels=math_ops.to_float(labels), logits=logits, name=name)
|
||||
|
||||
|
||||
def _one_class_to_two_class_logits(logits):
|
||||
@ -669,7 +669,7 @@ def _softmax_cross_entropy_loss(logits, labels):
|
||||
if len(labels.get_shape()) == 2:
|
||||
labels = array_ops.squeeze(labels, squeeze_dims=(1,))
|
||||
return nn.sparse_softmax_cross_entropy_with_logits(
|
||||
logits, labels, name=name)
|
||||
labels=labels, logits=logits, name=name)
|
||||
|
||||
|
||||
class _MultiClassHead(_Head):
|
||||
@ -1461,7 +1461,7 @@ def _sigmoid_cross_entropy_loss(logits, labels):
|
||||
(logits, labels)) as name:
|
||||
# sigmoid_cross_entropy_with_logits requires [batch_size, n_classes] labels.
|
||||
return nn.sigmoid_cross_entropy_with_logits(
|
||||
logits, math_ops.to_float(labels), name=name)
|
||||
labels=math_ops.to_float(labels), logits=logits, name=name)
|
||||
|
||||
|
||||
def _float_weights_or_none(weights):
|
||||
|
@ -47,7 +47,7 @@ def sequence_classifier(decoding, labels, sampling_decoding=None, name=None):
|
||||
predictions, xent_list = [], []
|
||||
for i, pred in enumerate(decoding):
|
||||
xent_list.append(nn.softmax_cross_entropy_with_logits(
|
||||
pred, labels[i],
|
||||
labels=labels[i], logits=pred,
|
||||
name="sequence_loss/xent_raw{0}".format(i)))
|
||||
if sampling_decoding:
|
||||
predictions.append(nn.softmax(sampling_decoding[i]))
|
||||
|
@ -1062,7 +1062,7 @@ def sequence_loss_by_example(logits,
|
||||
# violates our general scalar strictness policy.
|
||||
target = array_ops.reshape(target, [-1])
|
||||
crossent = nn_ops.sparse_softmax_cross_entropy_with_logits(
|
||||
logits=logit, labels=target)
|
||||
labels=target, logits=logit)
|
||||
else:
|
||||
crossent = softmax_loss_function(target, logit)
|
||||
log_perp_list.append(crossent * weight)
|
||||
|
@ -456,10 +456,10 @@ class SdcaModel(object):
|
||||
dtypes.float64)
|
||||
|
||||
if self._options['loss_type'] == 'logistic_loss':
|
||||
return math_ops.reduce_sum(
|
||||
math_ops.multiply(
|
||||
sigmoid_cross_entropy_with_logits(predictions, labels),
|
||||
weights)) / math_ops.reduce_sum(weights)
|
||||
return math_ops.reduce_sum(math_ops.multiply(
|
||||
sigmoid_cross_entropy_with_logits(labels=labels,
|
||||
logits=predictions),
|
||||
weights)) / math_ops.reduce_sum(weights)
|
||||
|
||||
if self._options['loss_type'] in ['hinge_loss', 'smooth_hinge_loss']:
|
||||
# hinge_loss = max{0, 1 - y_i w*x} where y_i \in {-1, 1}. So, we need to
|
||||
|
@ -340,7 +340,8 @@ def sigmoid_cross_entropy(
|
||||
multi_class_labels = (multi_class_labels * (1 - label_smoothing) +
|
||||
0.5 * label_smoothing)
|
||||
|
||||
losses = nn.sigmoid_cross_entropy_with_logits(logits, multi_class_labels,
|
||||
losses = nn.sigmoid_cross_entropy_with_logits(labels=multi_class_labels,
|
||||
logits=logits,
|
||||
name="xentropy")
|
||||
return compute_weighted_loss(losses, weights, scope=scope)
|
||||
|
||||
@ -387,7 +388,8 @@ def softmax_cross_entropy(
|
||||
smooth_negatives = label_smoothing / num_classes
|
||||
onehot_labels = onehot_labels * smooth_positives + smooth_negatives
|
||||
|
||||
losses = nn.softmax_cross_entropy_with_logits(logits, onehot_labels,
|
||||
losses = nn.softmax_cross_entropy_with_logits(labels=onehot_labels,
|
||||
logits=logits,
|
||||
name="xentropy")
|
||||
return compute_weighted_loss(losses, weights, scope=scope)
|
||||
|
||||
@ -421,7 +423,8 @@ def sparse_softmax_cross_entropy(logits, labels, weights=1.0, scope=None):
|
||||
labels = array_ops.reshape(labels, shape=[array_ops.shape(labels)[0]])
|
||||
weights = array_ops.squeeze(weights)
|
||||
|
||||
losses = nn.sparse_softmax_cross_entropy_with_logits(logits, labels,
|
||||
losses = nn.sparse_softmax_cross_entropy_with_logits(labels=labels,
|
||||
logits=logits,
|
||||
name="xentropy")
|
||||
return compute_weighted_loss(losses, weights, scope=scope)
|
||||
|
||||
|
@ -65,7 +65,7 @@ def deprecated_flipped_softmax_cross_entropy_with_logits(logits,
|
||||
softmax cross entropy loss.
|
||||
"""
|
||||
return nn.softmax_cross_entropy_with_logits(
|
||||
logits=logits, labels=labels, dim=dim, name=name)
|
||||
labels=labels, logits=logits, dim=dim, name=name)
|
||||
|
||||
|
||||
# TODO(b/33392402): Formally deprecate this API.
|
||||
@ -119,7 +119,7 @@ def deprecated_flipped_sparse_softmax_cross_entropy_with_logits(logits,
|
||||
of the labels is not equal to the rank of the labels minus one.
|
||||
"""
|
||||
return nn.sparse_softmax_cross_entropy_with_logits(
|
||||
logits=logits, labels=labels, name=name)
|
||||
labels=labels, logits=logits, name=name)
|
||||
|
||||
|
||||
# TODO(b/33392402): Formally deprecate this API.
|
||||
@ -174,4 +174,4 @@ def deprecated_flipped_sigmoid_cross_entropy_with_logits(logits,
|
||||
ValueError: If `logits` and `targets` do not have the same shape.
|
||||
"""
|
||||
return nn.sigmoid_cross_entropy_with_logits(
|
||||
logits=logits, targets=targets, name=name)
|
||||
labels=targets, logits=logits, name=name)
|
||||
|
@ -117,8 +117,8 @@ class HybridModel(object):
|
||||
else:
|
||||
loss = math_ops.reduce_mean(
|
||||
nn_ops.sparse_softmax_cross_entropy_with_logits(
|
||||
self.training_inference_graph(data),
|
||||
array_ops.squeeze(math_ops.to_int32(labels))),
|
||||
labels=array_ops.squeeze(math_ops.to_int32(labels)),
|
||||
logits=self.training_inference_graph(data)),
|
||||
name="loss")
|
||||
if self.regularizer:
|
||||
loss += layers.apply_regularization(self.regularizer,
|
||||
|
@ -723,7 +723,7 @@ def add_final_training_ops(class_count, final_tensor_name, bottleneck_tensor):
|
||||
|
||||
with tf.name_scope('cross_entropy'):
|
||||
cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
|
||||
logits, ground_truth_input)
|
||||
labels=ground_truth_input, logits=logits)
|
||||
with tf.name_scope('total'):
|
||||
cross_entropy_mean = tf.reduce_mean(cross_entropy)
|
||||
tf.summary.scalar('cross_entropy', cross_entropy_mean)
|
||||
|
@ -95,9 +95,8 @@ def loss(logits, labels):
|
||||
"""
|
||||
labels = tf.to_int64(labels)
|
||||
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
|
||||
logits, labels, name='xentropy')
|
||||
loss = tf.reduce_mean(cross_entropy, name='xentropy_mean')
|
||||
return loss
|
||||
labels=labels, logits=logits, name='xentropy')
|
||||
return tf.reduce_mean(cross_entropy, name='xentropy_mean')
|
||||
|
||||
|
||||
def training(loss, learning_rate):
|
||||
|
@ -54,7 +54,8 @@ def main(_):
|
||||
#
|
||||
# So here we use tf.nn.softmax_cross_entropy_with_logits on the raw
|
||||
# outputs of 'y', and then average across the batch.
|
||||
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(y, y_))
|
||||
cross_entropy = tf.reduce_mean(
|
||||
tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
|
||||
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
|
||||
|
||||
sess = tf.InteractiveSession()
|
||||
|
@ -119,7 +119,7 @@ def train():
|
||||
# So here we use tf.nn.softmax_cross_entropy_with_logits on the
|
||||
# raw outputs of the nn_layer above, and then average across
|
||||
# the batch.
|
||||
diff = tf.nn.softmax_cross_entropy_with_logits(y, y_)
|
||||
diff = tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y)
|
||||
with tf.name_scope('total'):
|
||||
cross_entropy = tf.reduce_mean(diff)
|
||||
tf.summary.scalar('cross_entropy', cross_entropy)
|
||||
|
@ -271,7 +271,7 @@
|
||||
" # cross-entropy across all training examples: that's our loss.\n",
|
||||
" logits = tf.matmul(tf_train_dataset, weights) + biases\n",
|
||||
" loss = tf.reduce_mean(\n",
|
||||
" tf.nn.softmax_cross_entropy_with_logits(logits, tf_train_labels))\n",
|
||||
" tf.nn.softmax_cross_entropy_with_logits(labels=tf_train_labels, logits=logits))\n",
|
||||
" \n",
|
||||
" # Optimizer.\n",
|
||||
" # We are going to find the minimum of this loss using gradient descent.\n",
|
||||
@ -448,7 +448,7 @@
|
||||
" # Training computation.\n",
|
||||
" logits = tf.matmul(tf_train_dataset, weights) + biases\n",
|
||||
" loss = tf.reduce_mean(\n",
|
||||
" tf.nn.softmax_cross_entropy_with_logits(logits, tf_train_labels))\n",
|
||||
" tf.nn.softmax_cross_entropy_with_logits(labels=tf_train_labels, logits=logits))\n",
|
||||
" \n",
|
||||
" # Optimizer.\n",
|
||||
" optimizer = tf.train.GradientDescentOptimizer(0.5).minimize(loss)\n",
|
||||
|
@ -286,7 +286,7 @@
|
||||
" # Training computation.\n",
|
||||
" logits = model(tf_train_dataset)\n",
|
||||
" loss = tf.reduce_mean(\n",
|
||||
" tf.nn.softmax_cross_entropy_with_logits(logits, tf_train_labels))\n",
|
||||
" tf.nn.softmax_cross_entropy_with_logits(labels=tf_train_labels, logits=logits))\n",
|
||||
" \n",
|
||||
" # Optimizer.\n",
|
||||
" optimizer = tf.train.GradientDescentOptimizer(0.05).minimize(loss)\n",
|
||||
|
@ -576,7 +576,7 @@
|
||||
" logits = tf.nn.xw_plus_b(tf.concat_v2(outputs, 0), w, b)\n",
|
||||
" loss = tf.reduce_mean(\n",
|
||||
" tf.nn.softmax_cross_entropy_with_logits(\n",
|
||||
" logits, tf.concat_v2(train_labels, 0)))\n",
|
||||
" labels=tf.concat_v2(train_labels, 0), logits=logits))\n",
|
||||
"\n",
|
||||
" # Optimizer.\n",
|
||||
" global_step = tf.Variable(0)\n",
|
||||
|
@ -1768,7 +1768,8 @@ Example:
|
||||
columns_to_tensors=columns_to_tensor,
|
||||
feature_columns=feature_columns,
|
||||
num_outputs=1)
|
||||
loss = tf.nn.sigmoid_cross_entropy_with_logits(logits, labels)
|
||||
loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels,
|
||||
logits=logits)
|
||||
```
|
||||
|
||||
##### Args:
|
||||
|
@ -20,7 +20,8 @@ Example:
|
||||
columns_to_tensors=columns_to_tensor,
|
||||
feature_columns=feature_columns,
|
||||
num_outputs=1)
|
||||
loss = tf.nn.sigmoid_cross_entropy_with_logits(logits, labels)
|
||||
loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels,
|
||||
logits=logits)
|
||||
```
|
||||
|
||||
##### Args:
|
||||
|
@ -227,9 +227,8 @@ Here are some of the typical usage models:
|
||||
onehot_labels = tf.sparse_to_dense(
|
||||
concated, tf.pack([batch_size, 10]), 1.0, 0.0)
|
||||
logits = tf.get_collection("logits")[0]
|
||||
cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits,
|
||||
onehot_labels,
|
||||
name="xentropy")
|
||||
cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
|
||||
labels=onehot_labels, logits=logits, name="xentropy")
|
||||
loss = tf.reduce_mean(cross_entropy, name="xentropy_mean")
|
||||
|
||||
tf.summary.scalar('loss', loss)
|
||||
|
@ -135,7 +135,7 @@ with tf.name_scope('cross_entropy'):
|
||||
# So here we use tf.nn.softmax_cross_entropy_with_logits on the
|
||||
# raw outputs of the nn_layer above, and then average across
|
||||
# the batch.
|
||||
diff = tf.nn.softmax_cross_entropy_with_logits(y, y_)
|
||||
diff = tf.nn.softmax_cross_entropy_with_logits(targets=y_, logits=y)
|
||||
with tf.name_scope('total'):
|
||||
cross_entropy = tf.reduce_mean(diff)
|
||||
tf.summary.scalar('cross_entropy', cross_entropy)
|
||||
|
@ -173,7 +173,8 @@ between the target and the softmax activation function applied to the model's
|
||||
prediction. As in the beginners tutorial, we use the stable formulation:
|
||||
|
||||
```python
|
||||
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(y, y_))
|
||||
cross_entropy = tf.reduce_mean(
|
||||
tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
|
||||
```
|
||||
|
||||
Note that `tf.nn.softmax_cross_entropy_with_logits` internally applies the
|
||||
@ -394,7 +395,8 @@ Feel free to go ahead and run this code, but it does 20,000 training iterations
|
||||
and may take a while (possibly up to half an hour), depending on your processor.
|
||||
|
||||
```python
|
||||
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(y_conv, y_))
|
||||
cross_entropy = tf.reduce_mean(
|
||||
tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y_conv))
|
||||
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
|
||||
correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1))
|
||||
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
|
||||
|
@ -171,7 +171,7 @@ First, the values from the `labels_placeholder` are converted to 64-bit integers
|
||||
```python
|
||||
labels = tf.to_int64(labels)
|
||||
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
|
||||
logits, labels, name='xentropy')
|
||||
labels=labels, logits=logits, name='xentropy')
|
||||
```
|
||||
|
||||
It then uses [`tf.reduce_mean`](../../../api_docs/python/math_ops.md#reduce_mean)
|
||||
|
@ -141,25 +141,26 @@ class SparseXentTest(test.TestCase):
|
||||
with self.test_session(use_gpu=True):
|
||||
with self.assertRaisesRegexp(ValueError, ".*Rank mismatch:*"):
|
||||
nn_ops.sparse_softmax_cross_entropy_with_logits(
|
||||
[[0., 1.], [2., 3.], [2., 3.]], [[0, 2]])
|
||||
labels=[[0, 2]], logits=[[0., 1.], [2., 3.], [2., 3.]])
|
||||
|
||||
def testScalar(self):
|
||||
with self.test_session(use_gpu=True):
|
||||
with self.assertRaisesRegexp(ValueError, ".*Logits cannot be scalars*"):
|
||||
nn_ops.sparse_softmax_cross_entropy_with_logits(
|
||||
constant_op.constant(1.0), constant_op.constant(0))
|
||||
labels=constant_op.constant(0), logits=constant_op.constant(1.0))
|
||||
|
||||
def testLabelsPlaceholderScalar(self):
|
||||
with self.test_session(use_gpu=True):
|
||||
labels = array_ops.placeholder(np.int32)
|
||||
y = nn_ops.sparse_softmax_cross_entropy_with_logits([[7.]], labels)
|
||||
y = nn_ops.sparse_softmax_cross_entropy_with_logits(
|
||||
labels=labels, logits=[[7.]])
|
||||
with self.assertRaisesOpError("labels must be 1-D"):
|
||||
y.eval(feed_dict={labels: 0})
|
||||
|
||||
def testVector(self):
|
||||
with self.test_session(use_gpu=True):
|
||||
loss = nn_ops.sparse_softmax_cross_entropy_with_logits(
|
||||
constant_op.constant([1.0]), constant_op.constant(0))
|
||||
labels=constant_op.constant(0), logits=constant_op.constant([1.0]))
|
||||
self.assertAllClose(0.0, loss.eval())
|
||||
|
||||
def testFloat(self):
|
||||
@ -191,7 +192,8 @@ class SparseXentTest(test.TestCase):
|
||||
shape=[3, 4],
|
||||
dtype=dtypes.float64,
|
||||
name="f")
|
||||
x = nn_ops.sparse_softmax_cross_entropy_with_logits(f, l, name="xent")
|
||||
x = nn_ops.sparse_softmax_cross_entropy_with_logits(
|
||||
labels=l, logits=f, name="xent")
|
||||
err = gradient_checker.compute_gradient_error(f, [3, 4], x, [3])
|
||||
print("cross entropy gradient err = ", err)
|
||||
self.assertLess(err, 5e-8)
|
||||
@ -201,7 +203,8 @@ class SparseXentTest(test.TestCase):
|
||||
# manually reshape loss
|
||||
np_loss = np.reshape(np_loss, np.array(labels).shape)
|
||||
with self.test_session(use_gpu=True) as sess:
|
||||
loss = nn_ops.sparse_softmax_cross_entropy_with_logits(features, labels)
|
||||
loss = nn_ops.sparse_softmax_cross_entropy_with_logits(
|
||||
labels=labels, logits=features)
|
||||
backprop = loss.op.inputs[0].op.outputs[1]
|
||||
tf_loss, tf_backprop = sess.run([loss, backprop])
|
||||
self.assertAllCloseAccordingToType(np_loss, tf_loss)
|
||||
@ -225,7 +228,7 @@ class SparseXentTest(test.TestCase):
|
||||
labels = array_ops.placeholder(dtypes.int32, shape=[None, 1])
|
||||
logits = array_ops.placeholder(dtypes.float32, shape=[None, 3])
|
||||
ce = nn_ops.sparse_softmax_cross_entropy_with_logits(
|
||||
logits, array_ops.squeeze(labels))
|
||||
labels=array_ops.squeeze(labels), logits=logits)
|
||||
labels_v2 = np.zeros((1, 1), dtype=np.int32)
|
||||
logits_v2 = np.random.randn(1, 3)
|
||||
sess.run([ce], feed_dict={labels: labels_v2, logits: logits_v2})
|
||||
@ -243,7 +246,7 @@ def _sparse_vs_dense_xent_benchmark_dense(labels, logits):
|
||||
array_ops.stack([length]), 1.0, 0.0)
|
||||
target = array_ops.reshape(target, array_ops.stack([-1, num_entries]))
|
||||
crossent = nn_ops.softmax_cross_entropy_with_logits(
|
||||
logits, target, name="SequenceLoss/CrossEntropy")
|
||||
labels=target, logits=logits, name="SequenceLoss/CrossEntropy")
|
||||
crossent_sum = math_ops.reduce_sum(crossent)
|
||||
grads = gradients_impl.gradients([crossent_sum], [logits])[0]
|
||||
|
||||
|
@ -57,7 +57,7 @@ class XentTest(test.TestCase):
|
||||
np_loss, _ = self._npXent(np_features, np_labels, dim=dim)
|
||||
with self.test_session(use_gpu=use_gpu) as sess:
|
||||
loss = nn_ops.softmax_cross_entropy_with_logits(
|
||||
np_features, np_labels, dim=dim)
|
||||
labels=np_labels, logits=np_features, dim=dim)
|
||||
tf_loss = sess.run(loss)
|
||||
print("np_loss:", np_loss)
|
||||
print("tf_loss:", tf_loss)
|
||||
@ -166,7 +166,8 @@ class XentTest(test.TestCase):
|
||||
shape=[3, 4],
|
||||
dtype=dtypes.float64,
|
||||
name="f")
|
||||
x = nn_ops.softmax_cross_entropy_with_logits(f, l, name="xent")
|
||||
x = nn_ops.softmax_cross_entropy_with_logits(labels=l, logits=f,
|
||||
name="xent")
|
||||
err = gradient_checker.compute_gradient_error(f, [3, 4], x, [3])
|
||||
print("cross entropy gradient err = ", err)
|
||||
self.assertLess(err, 5e-8)
|
||||
|
@ -267,7 +267,7 @@ class MiniMNISTTest(test.TestCase):
|
||||
dtype=dtypes.float64,
|
||||
name="labels")
|
||||
cost = nn_ops.softmax_cross_entropy_with_logits(
|
||||
logits, labels, name="cost")
|
||||
labels=labels, logits=logits, name="cost")
|
||||
|
||||
# Test the gradients.
|
||||
err = gradient_checker.compute_gradient_error(
|
||||
|
@ -559,7 +559,8 @@ def sigmoid_cross_entropy(
|
||||
multi_class_labels = (multi_class_labels * (1 - label_smoothing) +
|
||||
0.5 * label_smoothing)
|
||||
|
||||
losses = nn.sigmoid_cross_entropy_with_logits(logits, multi_class_labels,
|
||||
losses = nn.sigmoid_cross_entropy_with_logits(labels=multi_class_labels,
|
||||
logits=logits,
|
||||
name="xentropy")
|
||||
return compute_weighted_loss(losses, weights, scope, loss_collection)
|
||||
|
||||
@ -613,7 +614,8 @@ def softmax_cross_entropy(
|
||||
smooth_negatives = label_smoothing / num_classes
|
||||
onehot_labels = onehot_labels * smooth_positives + smooth_negatives
|
||||
|
||||
losses = nn.softmax_cross_entropy_with_logits(logits, onehot_labels,
|
||||
losses = nn.softmax_cross_entropy_with_logits(labels=onehot_labels,
|
||||
logits=logits,
|
||||
name="xentropy")
|
||||
return compute_weighted_loss(losses, weights, scope, loss_collection)
|
||||
|
||||
@ -653,7 +655,8 @@ def sparse_softmax_cross_entropy(labels, logits, weights=1.0, scope=None,
|
||||
[logits, labels, weights]) as scope:
|
||||
labels = array_ops.reshape(labels, shape=[array_ops.shape(labels)[0]])
|
||||
|
||||
losses = nn.sparse_softmax_cross_entropy_with_logits(logits, labels,
|
||||
losses = nn.sparse_softmax_cross_entropy_with_logits(labels=labels,
|
||||
logits=logits,
|
||||
name="xentropy")
|
||||
# Reshape losses to [batch_size, 1] to be consistent with weights.
|
||||
losses = array_ops.reshape(losses, shape=[array_ops.shape(losses)[0], 1])
|
||||
|
@ -96,7 +96,9 @@ def log_poisson_loss(targets, log_input, compute_full_loss=False, name=None):
|
||||
return result
|
||||
|
||||
|
||||
def sigmoid_cross_entropy_with_logits(logits, targets, name=None):
|
||||
def sigmoid_cross_entropy_with_logits(_sentinel=None, # pylint: disable=invalid-name
|
||||
labels=None, logits=None,
|
||||
name=None):
|
||||
"""Computes sigmoid cross entropy given `logits`.
|
||||
|
||||
Measures the probability error in discrete classification tasks in which each
|
||||
@ -104,7 +106,7 @@ def sigmoid_cross_entropy_with_logits(logits, targets, name=None):
|
||||
perform multilabel classification where a picture can contain both an elephant
|
||||
and a dog at the same time.
|
||||
|
||||
For brevity, let `x = logits`, `z = targets`. The logistic loss is
|
||||
For brevity, let `x = logits`, `z = labels`. The logistic loss is
|
||||
|
||||
z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
|
||||
= z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
|
||||
@ -124,11 +126,12 @@ def sigmoid_cross_entropy_with_logits(logits, targets, name=None):
|
||||
|
||||
max(x, 0) - x * z + log(1 + exp(-abs(x)))
|
||||
|
||||
`logits` and `targets` must have the same type and shape.
|
||||
`logits` and `labels` must have the same type and shape.
|
||||
|
||||
Args:
|
||||
_sentinel: Used to prevent positional parameters. Internal, do not use.
|
||||
labels: A `Tensor` of the same type and shape as `logits`.
|
||||
logits: A `Tensor` of type `float32` or `float64`.
|
||||
targets: A `Tensor` of the same type and shape as `logits`.
|
||||
name: A name for the operation (optional).
|
||||
|
||||
Returns:
|
||||
@ -136,16 +139,21 @@ def sigmoid_cross_entropy_with_logits(logits, targets, name=None):
|
||||
logistic losses.
|
||||
|
||||
Raises:
|
||||
ValueError: If `logits` and `targets` do not have the same shape.
|
||||
ValueError: If `logits` and `labels` do not have the same shape.
|
||||
"""
|
||||
with ops.name_scope(name, "logistic_loss", [logits, targets]) as name:
|
||||
# pylint: disable=protected-access
|
||||
nn_ops._ensure_xent_args("sigmoid_cross_entropy_with_logits",
|
||||
_sentinel, labels, logits)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
with ops.name_scope(name, "logistic_loss", [logits, labels]) as name:
|
||||
logits = ops.convert_to_tensor(logits, name="logits")
|
||||
targets = ops.convert_to_tensor(targets, name="targets")
|
||||
labels = ops.convert_to_tensor(labels, name="labels")
|
||||
try:
|
||||
targets.get_shape().merge_with(logits.get_shape())
|
||||
labels.get_shape().merge_with(logits.get_shape())
|
||||
except ValueError:
|
||||
raise ValueError("logits and targets must have the same shape (%s vs %s)"
|
||||
% (logits.get_shape(), targets.get_shape()))
|
||||
raise ValueError("logits and labels must have the same shape (%s vs %s)"
|
||||
% (logits.get_shape(), labels.get_shape()))
|
||||
|
||||
# The logistic loss formula from above is
|
||||
# x - x * z + log(1 + exp(-x))
|
||||
@ -159,7 +167,7 @@ def sigmoid_cross_entropy_with_logits(logits, targets, name=None):
|
||||
cond = (logits >= zeros)
|
||||
relu_logits = array_ops.where(cond, logits, zeros)
|
||||
neg_abs_logits = array_ops.where(cond, -logits, logits)
|
||||
return math_ops.add(relu_logits - logits * targets,
|
||||
return math_ops.add(relu_logits - logits * labels,
|
||||
math_ops.log1p(math_ops.exp(neg_abs_logits)),
|
||||
name=name)
|
||||
|
||||
@ -1095,7 +1103,7 @@ def nce_loss(weights,
|
||||
partition_strategy=partition_strategy,
|
||||
name=name)
|
||||
sampled_losses = sigmoid_cross_entropy_with_logits(
|
||||
logits, labels, name="sampled_losses")
|
||||
labels=labels, logits=logits, name="sampled_losses")
|
||||
# sampled_losses is batch_size x {true_loss, sampled_losses...}
|
||||
# We sum out true and sampled losses.
|
||||
return _sum_rows(sampled_losses)
|
||||
@ -1170,6 +1178,7 @@ def sampled_softmax_loss(weights,
|
||||
remove_accidental_hits=remove_accidental_hits,
|
||||
partition_strategy=partition_strategy,
|
||||
name=name)
|
||||
sampled_losses = nn_ops.softmax_cross_entropy_with_logits(logits, labels)
|
||||
sampled_losses = nn_ops.softmax_cross_entropy_with_logits(labels=labels,
|
||||
logits=logits)
|
||||
# sampled_losses is a [batch_size] tensor.
|
||||
return sampled_losses
|
||||
|
@ -14,7 +14,6 @@
|
||||
# ==============================================================================
|
||||
"""Wrappers for primitive Neural Net (NN) Operations."""
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
@ -1047,7 +1046,7 @@ def conv2d_transpose(value,
|
||||
raise ValueError("data_format has to be either NCHW or NHWC.")
|
||||
value = ops.convert_to_tensor(value, name="value")
|
||||
filter = ops.convert_to_tensor(filter, name="filter")
|
||||
axis = 3 if data_format=="NHWC" else 1
|
||||
axis = 3 if data_format == "NHWC" else 1
|
||||
if not value.get_shape()[axis].is_compatible_with(filter.get_shape()[3]):
|
||||
raise ValueError("input channels does not match filter's input channels, "
|
||||
"{} != {}".format(value.get_shape()[3], filter.get_shape(
|
||||
@ -1528,7 +1527,18 @@ def log_softmax(logits, dim=-1, name=None):
|
||||
return _softmax(logits, gen_nn_ops._log_softmax, dim, name)
|
||||
|
||||
|
||||
def softmax_cross_entropy_with_logits(logits, labels, dim=-1, name=None):
|
||||
def _ensure_xent_args(name, sentinel, labels, logits):
|
||||
# Make sure that all arguments were passed as named arguments.
|
||||
if sentinel is not None:
|
||||
raise ValueError("Only call `%s` with "
|
||||
"named arguments (labels=..., logits=..., ...)" % name)
|
||||
if labels is None or logits is None:
|
||||
raise ValueError("Both labels and logits must be provided.")
|
||||
|
||||
|
||||
def softmax_cross_entropy_with_logits(_sentinel=None, # pylint: disable=invalid-name
|
||||
labels=None, logits=None,
|
||||
dim=-1, name=None):
|
||||
"""Computes softmax cross entropy between `logits` and `labels`.
|
||||
|
||||
Measures the probability error in discrete classification tasks in which the
|
||||
@ -1551,9 +1561,13 @@ def softmax_cross_entropy_with_logits(logits, labels, dim=-1, name=None):
|
||||
`logits` and `labels` must have the same shape `[batch_size, num_classes]`
|
||||
and the same dtype (either `float16`, `float32`, or `float64`).
|
||||
|
||||
**Note that to avoid confusion, it is required to pass only named arguments to
|
||||
this function.**
|
||||
|
||||
Args:
|
||||
logits: Unscaled log probabilities.
|
||||
_sentinel: Used to prevent positional parameters. Internal, do not use.
|
||||
labels: Each row `labels[i]` must be a valid probability distribution.
|
||||
logits: Unscaled log probabilities.
|
||||
dim: The class dimension. Defaulted to -1 which is the last dimension.
|
||||
name: A name for the operation (optional).
|
||||
|
||||
@ -1561,6 +1575,9 @@ def softmax_cross_entropy_with_logits(logits, labels, dim=-1, name=None):
|
||||
A 1-D `Tensor` of length `batch_size` of the same type as `logits` with the
|
||||
softmax cross entropy loss.
|
||||
"""
|
||||
_ensure_xent_args("softmax_cross_entropy_with_logits", _sentinel,
|
||||
labels, logits)
|
||||
|
||||
# TODO(pcmurray) Raise an error when the labels do not sum to 1. Note: This
|
||||
# could break users who call this with bad labels, but disregard the bad
|
||||
# results.
|
||||
@ -1569,7 +1586,7 @@ def softmax_cross_entropy_with_logits(logits, labels, dim=-1, name=None):
|
||||
labels = ops.convert_to_tensor(labels)
|
||||
precise_logits = math_ops.cast(logits, dtypes.float32) if (
|
||||
logits.dtype == dtypes.float16) else logits
|
||||
# Labels and logits must be of the same type
|
||||
# labels and logits must be of the same type
|
||||
labels = math_ops.cast(labels, precise_logits.dtype)
|
||||
input_rank = array_ops.rank(precise_logits)
|
||||
# For shape inference.
|
||||
@ -1618,7 +1635,9 @@ def softmax_cross_entropy_with_logits(logits, labels, dim=-1, name=None):
|
||||
return cost
|
||||
|
||||
|
||||
def sparse_softmax_cross_entropy_with_logits(logits, labels, name=None):
|
||||
def sparse_softmax_cross_entropy_with_logits(_sentinel=None, # pylint: disable=invalid-name
|
||||
labels=None, logits=None,
|
||||
name=None):
|
||||
"""Computes sparse softmax cross entropy between `logits` and `labels`.
|
||||
|
||||
Measures the probability error in discrete classification tasks in which the
|
||||
@ -1640,14 +1659,18 @@ def sparse_softmax_cross_entropy_with_logits(logits, labels, name=None):
|
||||
A common use case is to have logits of shape `[batch_size, num_classes]` and
|
||||
labels of shape `[batch_size]`. But higher dimensions are supported.
|
||||
|
||||
**Note that to avoid confusion, it is required to pass only named arguments to
|
||||
this function.**
|
||||
|
||||
Args:
|
||||
logits: Unscaled log probabilities of rank `r` and shape
|
||||
`[d_0, d_1, ..., d_{r-2}, num_classes]` and dtype `float32` or `float64`.
|
||||
_sentinel: Used to prevent positional parameters. Internal, do not use.
|
||||
labels: `Tensor` of shape `[d_0, d_1, ..., d_{r-2}]` and dtype `int32` or
|
||||
`int64`. Each entry in `labels` must be an index in `[0, num_classes)`.
|
||||
Other values will raise an exception when this op is run on CPU, and
|
||||
return `NaN` for corresponding corresponding loss and gradient rows
|
||||
on GPU.
|
||||
logits: Unscaled log probabilities of rank `r` and shape
|
||||
`[d_0, d_1, ..., d_{r-2}, num_classes]` and dtype `float32` or `float64`.
|
||||
name: A name for the operation (optional).
|
||||
|
||||
Returns:
|
||||
@ -1658,6 +1681,9 @@ def sparse_softmax_cross_entropy_with_logits(logits, labels, name=None):
|
||||
ValueError: If logits are scalars (need to have rank >= 1) or if the rank
|
||||
of the labels is not equal to the rank of the labels minus one.
|
||||
"""
|
||||
_ensure_xent_args("sparse_softmax_cross_entropy_with_logits", _sentinel,
|
||||
labels, logits)
|
||||
|
||||
# TODO(pcmurray) Raise an error when the label is not an index in
|
||||
# [0, num_classes). Note: This could break users who call this with bad
|
||||
# labels, but disregard the bad results.
|
||||
@ -1679,8 +1705,8 @@ def sparse_softmax_cross_entropy_with_logits(logits, labels, name=None):
|
||||
if logits.get_shape().ndims is not None and (
|
||||
labels_static_shape.ndims is not None and
|
||||
labels_static_shape.ndims != logits.get_shape().ndims - 1):
|
||||
raise ValueError("Rank mismatch: Rank of labels (received %s) should equal "
|
||||
"rank of logits minus 1 (received %s)." %
|
||||
raise ValueError("Rank mismatch: Rank of labels (received %s) should "
|
||||
"equal rank of logits minus 1 (received %s)." %
|
||||
(labels_static_shape.ndims, logits.get_shape().ndims))
|
||||
# Check if no reshapes are required.
|
||||
if logits.get_shape().ndims == 2:
|
||||
@ -1857,8 +1883,7 @@ def xw_plus_b_v1(x, weights, biases, name=None): # pylint: disable=invalid-name
|
||||
return bias_add_v1(mm, biases, name=name)
|
||||
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
def dropout(x, keep_prob, noise_shape=None, seed=None, name=None):
|
||||
def dropout(x, keep_prob, noise_shape=None, seed=None, name=None): # pylint: disable=invalid-name
|
||||
"""Computes dropout.
|
||||
|
||||
With probability `keep_prob`, outputs the input element scaled up by
|
||||
@ -2082,5 +2107,3 @@ def erosion2d(value, kernel, strides, rates, padding, name=None):
|
||||
rates=rates,
|
||||
padding=padding,
|
||||
name=name))
|
||||
|
||||
# pylint: enable=invalid-name
|
||||
|
@ -57,7 +57,7 @@ class SigmoidCrossEntropyWithLogitsTest(test.TestCase):
|
||||
with self.test_session():
|
||||
logits, targets, _ = self._Inputs()
|
||||
loss = nn_impl.sigmoid_cross_entropy_with_logits(
|
||||
logits, targets, name="mylogistic")
|
||||
labels=targets, logits=logits, name="mylogistic")
|
||||
self.assertEqual("mylogistic", loss.op.name)
|
||||
|
||||
def testLogisticOutput(self):
|
||||
@ -65,7 +65,8 @@ class SigmoidCrossEntropyWithLogitsTest(test.TestCase):
|
||||
for dtype in [dtypes.float32, dtypes.float16]:
|
||||
with self.test_session(use_gpu=use_gpu):
|
||||
logits, targets, losses = self._Inputs(dtype=dtype)
|
||||
loss = nn_impl.sigmoid_cross_entropy_with_logits(logits, targets)
|
||||
loss = nn_impl.sigmoid_cross_entropy_with_logits(
|
||||
labels=targets, logits=logits)
|
||||
np_loss = np.array(losses).astype(np.float32)
|
||||
tf_loss = loss.eval()
|
||||
self.assertAllClose(np_loss, tf_loss, atol=0.001)
|
||||
@ -75,7 +76,8 @@ class SigmoidCrossEntropyWithLogitsTest(test.TestCase):
|
||||
for dtype in [dtypes.float32, dtypes.float16]:
|
||||
with self.test_session(use_gpu=use_gpu):
|
||||
logits, targets, losses = self._Inputs(dtype=dtype, sizes=[2, 2, 2])
|
||||
loss = nn_impl.sigmoid_cross_entropy_with_logits(logits, targets)
|
||||
loss = nn_impl.sigmoid_cross_entropy_with_logits(
|
||||
labels=targets, logits=logits)
|
||||
np_loss = np.array(losses).astype(np.float32)
|
||||
tf_loss = loss.eval()
|
||||
self.assertAllClose(np_loss, tf_loss, atol=0.001)
|
||||
@ -84,7 +86,8 @@ class SigmoidCrossEntropyWithLogitsTest(test.TestCase):
|
||||
sizes = [4, 2]
|
||||
with self.test_session():
|
||||
logits, targets, _ = self._Inputs(sizes=sizes)
|
||||
loss = nn_impl.sigmoid_cross_entropy_with_logits(logits, targets)
|
||||
loss = nn_impl.sigmoid_cross_entropy_with_logits(
|
||||
labels=targets, logits=logits)
|
||||
err = gradient_checker.compute_gradient_error(logits, sizes, loss, sizes)
|
||||
print("logistic loss gradient err = ", err)
|
||||
self.assertLess(err, 1e-7)
|
||||
@ -93,13 +96,15 @@ class SigmoidCrossEntropyWithLogitsTest(test.TestCase):
|
||||
with self.test_session():
|
||||
logits = constant_op.constant([0.0, 0.0], dtype=dtypes.float64)
|
||||
targets = constant_op.constant([0.0, 1.0], dtype=dtypes.float64)
|
||||
loss = nn_impl.sigmoid_cross_entropy_with_logits(logits, targets)
|
||||
loss = nn_impl.sigmoid_cross_entropy_with_logits(
|
||||
labels=targets, logits=logits)
|
||||
grads = gradients_impl.gradients(loss, logits)[0].eval()
|
||||
self.assertAllClose(grads, [0.5, -0.5])
|
||||
|
||||
def testShapeError(self):
|
||||
with self.assertRaisesRegexp(ValueError, "must have the same shape"):
|
||||
nn_impl.sigmoid_cross_entropy_with_logits([[2, 1]], [1, 2, 3])
|
||||
nn_impl.sigmoid_cross_entropy_with_logits(labels=[1, 2, 3],
|
||||
logits=[[2, 1]])
|
||||
|
||||
|
||||
class WeightedCrossEntropyTest(test.TestCase):
|
||||
@ -128,15 +133,15 @@ class WeightedCrossEntropyTest(test.TestCase):
|
||||
with self.test_session():
|
||||
logits, targets, pos_weight, _ = self._Inputs()
|
||||
loss = nn_impl.weighted_cross_entropy_with_logits(
|
||||
targets, logits, pos_weight, name="mybce")
|
||||
targets=targets, logits=logits, pos_weight=pos_weight, name="mybce")
|
||||
self.assertEqual("mybce", loss.op.name)
|
||||
|
||||
def testOutput(self):
|
||||
for use_gpu in [True, False]:
|
||||
with self.test_session(use_gpu=use_gpu):
|
||||
logits, targets, pos_weight, losses = self._Inputs(dtype=dtypes.float32)
|
||||
loss = nn_impl.weighted_cross_entropy_with_logits(targets, logits,
|
||||
pos_weight)
|
||||
loss = nn_impl.weighted_cross_entropy_with_logits(
|
||||
targets=targets, logits=logits, pos_weight=pos_weight)
|
||||
np_loss = np.array(losses).astype(np.float32)
|
||||
tf_loss = loss.eval()
|
||||
self.assertAllClose(np_loss, tf_loss, atol=0.001)
|
||||
@ -146,8 +151,8 @@ class WeightedCrossEntropyTest(test.TestCase):
|
||||
with self.test_session(use_gpu=use_gpu):
|
||||
logits, targets, pos_weight, losses = self._Inputs(
|
||||
dtype=dtypes.float32, sizes=[2, 2, 2])
|
||||
loss = nn_impl.weighted_cross_entropy_with_logits(targets, logits,
|
||||
pos_weight)
|
||||
loss = nn_impl.weighted_cross_entropy_with_logits(
|
||||
targets=targets, logits=logits, pos_weight=pos_weight)
|
||||
np_loss = np.array(losses).astype(np.float32)
|
||||
tf_loss = loss.eval()
|
||||
self.assertAllClose(np_loss, tf_loss, atol=0.001)
|
||||
@ -156,15 +161,16 @@ class WeightedCrossEntropyTest(test.TestCase):
|
||||
sizes = [4, 2]
|
||||
with self.test_session():
|
||||
logits, targets, pos_weight, _ = self._Inputs(sizes=sizes)
|
||||
loss = nn_impl.weighted_cross_entropy_with_logits(targets, logits,
|
||||
pos_weight)
|
||||
loss = nn_impl.weighted_cross_entropy_with_logits(
|
||||
targets=targets, logits=logits, pos_weight=pos_weight)
|
||||
err = gradient_checker.compute_gradient_error(logits, sizes, loss, sizes)
|
||||
print("logistic loss gradient err = ", err)
|
||||
self.assertLess(err, 1e-7)
|
||||
|
||||
def testShapeError(self):
|
||||
with self.assertRaisesRegexp(ValueError, "must have the same shape"):
|
||||
nn_impl.weighted_cross_entropy_with_logits([1, 2, 3], [[2, 1]], 2.0)
|
||||
nn_impl.weighted_cross_entropy_with_logits(
|
||||
targets=[1, 2, 3], logits=[[2, 1]], pos_weight=2.0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -1614,7 +1614,7 @@ class MetaGraphTest(test.TestCase):
|
||||
concated, array_ops.stack([batch_size, 10]), 1.0, 0.0)
|
||||
logits = ops_lib.get_collection("logits")[0]
|
||||
cross_entropy = nn_ops.softmax_cross_entropy_with_logits(
|
||||
logits, onehot_labels, name="xentropy")
|
||||
labels=onehot_labels, logits=logits, name="xentropy")
|
||||
loss = math_ops.reduce_mean(cross_entropy, name="xentropy_mean")
|
||||
|
||||
summary.scalar("loss", loss)
|
||||
@ -1698,7 +1698,8 @@ class MetaGraphTest(test.TestCase):
|
||||
bias = variables.Variable(array_ops.zeros([10]), name="bias")
|
||||
logit = nn_ops.relu(math_ops.matmul(image, weights) + bias, name="logits")
|
||||
nn_ops.softmax(logit, name="prediction")
|
||||
cost = nn_ops.softmax_cross_entropy_with_logits(logit, label, name="cost")
|
||||
cost = nn_ops.softmax_cross_entropy_with_logits(labels=label,
|
||||
logits=logit, name="cost")
|
||||
adam.AdamOptimizer().minimize(cost, name="optimize")
|
||||
saver = saver_module.Saver()
|
||||
sess.run(variables.global_variables_initializer())
|
||||
@ -1726,7 +1727,8 @@ class MetaGraphTest(test.TestCase):
|
||||
bias = variables.Variable(array_ops.zeros([10]), name="bias")
|
||||
logit = nn_ops.relu(math_ops.matmul(image, weights) + bias)
|
||||
nn_ops.softmax(logit, name="prediction")
|
||||
cost = nn_ops.softmax_cross_entropy_with_logits(logit, label)
|
||||
cost = nn_ops.softmax_cross_entropy_with_logits(labels=label,
|
||||
logits=logit)
|
||||
adam.AdamOptimizer().minimize(cost, name="optimize")
|
||||
meta_graph_def = saver_module.export_meta_graph()
|
||||
|
||||
@ -1758,7 +1760,8 @@ class MetaGraphTest(test.TestCase):
|
||||
bias = variables.Variable(array_ops.zeros([10]), name="bias")
|
||||
logit = nn_ops.relu(math_ops.matmul(image, weights) + bias)
|
||||
nn_ops.softmax(logit, name="prediction")
|
||||
cost = nn_ops.softmax_cross_entropy_with_logits(logit, label)
|
||||
cost = nn_ops.softmax_cross_entropy_with_logits(labels=label,
|
||||
logits=logit)
|
||||
adam.AdamOptimizer().minimize(cost, name="optimize")
|
||||
meta_graph_def = saver_module.export_meta_graph(clear_devices=True)
|
||||
graph_io.write_graph(meta_graph_def, "/tmp", "meta_graph.pbtxt")
|
||||
|
@ -1206,7 +1206,7 @@
|
||||
"# Training computation: logits + cross-entropy loss.\n",
|
||||
"logits = model(train_data_node, True)\n",
|
||||
"loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(\n",
|
||||
" logits, train_labels_node))\n",
|
||||
" labels=train_labels_node, logits=logits))\n",
|
||||
"\n",
|
||||
"# L2 regularization for the fully connected parameters.\n",
|
||||
"regularizers = (tf.nn.l2_loss(fc1_weights) + tf.nn.l2_loss(fc1_biases) +\n",
|
||||
|
Loading…
Reference in New Issue
Block a user