Adding add_loss, get_losses, get_total_loss to loss_ops.py

Change: 123279011
This commit is contained in:
A. Unique TensorFlower 2016-05-25 18:25:51 -08:00 committed by TensorFlower Gardener
parent 710edb74e6
commit 0f79008185
2 changed files with 61 additions and 1 deletions

View File

@ -19,7 +19,10 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.losses.python.losses.loss_ops import absolute_difference
from tensorflow.contrib.losses.python.losses.loss_ops import add_loss
from tensorflow.contrib.losses.python.losses.loss_ops import cosine_distance
from tensorflow.contrib.losses.python.losses.loss_ops import get_losses
from tensorflow.contrib.losses.python.losses.loss_ops import get_total_loss
from tensorflow.contrib.losses.python.losses.loss_ops import log
from tensorflow.contrib.losses.python.losses.loss_ops import sigmoid_cross_entropy
from tensorflow.contrib.losses.python.losses.loss_ops import softmax_cross_entropy

View File

@ -104,9 +104,11 @@ weighted average over the individual prediction errors:
weight = tf.div(weight, tf.size(weight))
loss = tf.contrib.losses.sum_of_squares(predictions, depths, weight)
@@absolute_difference
@@add_loss
@@cosine_distance
@@get_losses
@@get_total_loss
@@log
@@sigmoid_cross_entropy
@@softmax_cross_entropy
@ -252,6 +254,61 @@ def _num_present(losses, weight, per_batch=False):
return num_per_batch if per_batch else math_ops.reduce_sum(num_per_batch)
def add_loss(loss):
"""Adds a externally defined loss to collection of losses.
Args:
loss: A loss `Tensor`.
"""
ops.add_to_collection(ops.GraphKeys.LOSSES, loss)
def get_losses(scope=None):
"""Gets the list of loss variables.
Args:
scope: an optional scope for filtering the losses to return.
Returns:
a list of loss variables.
"""
return ops.get_collection(ops.GraphKeys.LOSSES, scope)
def get_regularization_losses(scope=None):
"""Gets the regularization losses.
Args:
scope: an optional scope for filtering the losses to return.
Returns:
A list of loss variables.
"""
return ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES, scope)
def get_total_loss(add_regularization_losses=True, name="total_loss"):
"""Returns a tensor whose value represents the total loss.
Notice that the function adds the given losses to the regularization losses.
Args:
add_regularization_losses: A boolean indicating whether or not to use the
regularization losses in the sum.
name: The name of the returned tensor.
Returns:
A `Tensor` whose value represents the total loss.
Raises:
ValueError: if `losses` is not iterable.
"""
losses = get_losses()
if add_regularization_losses:
losses += get_regularization_losses()
return math_ops.add_n(losses, name=name)
def absolute_difference(predictions, targets, weight=1.0, scope=None):
"""Adds an Absolute Difference loss to the training procedure.