Cleaning up safe_div implementations

PiperOrigin-RevId: 222124095

Author: Pavithra Vijay (2018-11-19 13:01:19 -08:00)
Committed by: TensorFlower Gardener
Parent: 9a83d2111f
Commit: c5695df38c

6 changed files with 42 additions and 182 deletions

File 1 of 6

@@ -22,7 +22,6 @@ from __future__ import division
 from __future__ import print_function

 from tensorflow.contrib.framework.python.ops import add_arg_scope
-from tensorflow.python.compat import compat
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
@@ -67,34 +66,6 @@ def _scale_losses(losses, weights):
   return math_ops.reduce_sum(reduced_losses)


-def _safe_div(numerator, denominator, name="value"):
-  """Computes a safe divide which returns 0 if the denominator is zero.
-
-  Note that the function contains an additional conditional check that is
-  necessary for avoiding situations where the loss is zero causing NaNs to
-  creep into the gradient computation.
-
-  Args:
-    numerator: An arbitrary `Tensor`.
-    denominator: A `Tensor` whose shape matches `numerator` and whose values are
-      assumed to be non-negative.
-    name: An optional name for the returned op.
-
-  Returns:
-    The element-wise value of the numerator divided by the denominator.
-  """
-  if compat.forward_compatible(2018, 11, 1):
-    return math_ops.div_no_nan(numerator, denominator, name=name)
-  return array_ops.where(
-      math_ops.greater(denominator, 0),
-      math_ops.div(numerator,
-                   array_ops.where(
-                       math_ops.equal(denominator, 0),
-                       array_ops.ones_like(denominator), denominator)),
-      array_ops.zeros_like(numerator),
-      name=name)
-
-
 def _safe_mean(losses, num_present):
   """Computes a safe mean of the losses.
@@ -107,7 +78,7 @@ def _safe_mean(losses, num_present):
     then zero is returned.
   """
   total_loss = math_ops.reduce_sum(losses)
-  return _safe_div(total_loss, num_present, name="value")
+  return math_ops.div_no_nan(total_loss, num_present, name="value")


 @deprecated("2016-12-30", "Use tf.losses.compute_weighted_loss instead.")
@@ -612,12 +583,12 @@ def mean_pairwise_squared_error(predictions,
         math_ops.square(diffs), reduction_indices=reduction_indices)
     num_present_per_batch = _num_present(diffs, weights, per_batch=True)
-    term1 = 2.0 * _safe_div(sum_squares_diff_per_batch,
-                            num_present_per_batch,
-                            name="value")
+    term1 = 2.0 * math_ops.div_no_nan(
+        sum_squares_diff_per_batch, num_present_per_batch, name="value")

     sum_diff = math_ops.reduce_sum(diffs, reduction_indices=reduction_indices)
-    term2 = 2.0 * _safe_div(math_ops.square(sum_diff),
+    term2 = 2.0 * math_ops.div_no_nan(
+        math_ops.square(sum_diff),
         math_ops.square(num_present_per_batch),
         name="value")

File 2 of 6

@@ -24,7 +24,6 @@ from __future__ import print_function
 import collections as collections_lib

-from tensorflow.python.compat import compat
 from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -46,32 +45,6 @@ from tensorflow.python.util.deprecation import deprecated
 _EPSILON = 1e-7


-def _safe_div(numerator, denominator):
-  """Computes a safe divide which returns 0 if the denominator is zero.
-
-  Note that the function contains an additional conditional check that is
-  necessary for avoiding situations where the loss is zero causing NaNs to
-  creep into the gradient computation.
-
-  Args:
-    numerator: An arbitrary `Tensor`.
-    denominator: A `Tensor` whose shape matches `numerator` and whose values are
-      assumed to be non-negative.
-
-  Returns:
-    The element-wise value of the numerator divided by the denominator.
-  """
-  if compat.forward_compatible(2018, 11, 1):
-    return math_ops.div_no_nan(numerator, denominator)
-  return array_ops.where(
-      math_ops.greater(denominator, 0),
-      math_ops.div(numerator,
-                   array_ops.where(
-                       math_ops.equal(denominator, 0),
-                       array_ops.ones_like(denominator), denominator)),
-      array_ops.zeros_like(numerator))
-
-
 @deprecated(None, 'Please switch to tf.metrics.true_positives. Note that the '
             'order of the labels and predictions arguments has been switched.')
 def streaming_true_positives(predictions,
@@ -3247,24 +3220,20 @@ def streaming_covariance(predictions,
       # We update the means by Delta=Error*BatchCount/(BatchCount+PrevCount)
       # batch_mean_prediction is E[x_B] in the update equation
-      batch_mean_prediction = _safe_div(
-          math_ops.reduce_sum(weighted_predictions),
-          batch_count)
-      delta_mean_prediction = _safe_div(
-          (batch_mean_prediction - mean_prediction) * batch_count,
-          update_count)
+      batch_mean_prediction = math_ops.div_no_nan(
+          math_ops.reduce_sum(weighted_predictions), batch_count)
+      delta_mean_prediction = math_ops.div_no_nan(
+          (batch_mean_prediction - mean_prediction) * batch_count, update_count)
       update_mean_prediction = state_ops.assign_add(mean_prediction,
                                                     delta_mean_prediction)
       # prev_mean_prediction is E[x_A] in the update equation
       prev_mean_prediction = update_mean_prediction - delta_mean_prediction

       # batch_mean_label is E[y_B] in the update equation
-      batch_mean_label = _safe_div(
-          math_ops.reduce_sum(weighted_labels),
-          batch_count)
-      delta_mean_label = _safe_div(
-          (batch_mean_label - mean_label) * batch_count,
-          update_count)
+      batch_mean_label = math_ops.div_no_nan(
+          math_ops.reduce_sum(weighted_labels), batch_count)
+      delta_mean_label = math_ops.div_no_nan(
+          (batch_mean_label - mean_label) * batch_count, update_count)
       update_mean_label = state_ops.assign_add(mean_label, delta_mean_label)
       # prev_mean_label is E[y_A] in the update equation
       prev_mean_label = update_mean_label - delta_mean_label
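The update rule named in the comments above is the standard incremental mean: folding a batch B into a running mean over A previous samples via delta = (E[x_B] - mean) * batch_count / update_count, where update_count already includes the new batch. A quick NumPy check of that recurrence (a sketch, not the TF graph code):

```python
import numpy as np

rng = np.random.default_rng(0)
data = rng.normal(size=100)

mean, count = 0.0, 0.0
for batch in np.split(data, 5):
  batch_count = float(batch.size)
  batch_mean = batch.mean()                           # E[x_B]
  count += batch_count                                # update_count = prev + batch
  mean += (batch_mean - mean) * batch_count / count   # delta from the comment
assert np.isclose(mean, data.mean())                  # matches the full-data mean
```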
@@ -3926,9 +3895,8 @@ def cohen_kappa(labels,
     po_sum = math_ops.reduce_sum(po)
     total = math_ops.reduce_sum(pe_row)
     pe_sum = math_ops.reduce_sum(
-        _safe_div(
-            math_ops.to_double(pe_row * pe_col),
-            math_ops.to_double(total)))
+        math_ops.div_no_nan(
+            math_ops.to_double(pe_row * pe_col), math_ops.to_double(total)))
     po_sum, pe_sum, total = (math_ops.to_double(po_sum),
                              math_ops.to_double(pe_sum),
                              math_ops.to_double(total))
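The quantity being summed here is the chance-agreement term of Cohen's kappa: with `pe_row`/`pe_col` the row and column marginals of the confusion matrix, `pe_sum = sum(pe_row * pe_col / total)`, and the statistic works out to `(po_sum - pe_sum) / (total - pe_sum)`, the count form of (p_o - p_e) / (1 - p_e). A small NumPy illustration with a hypothetical confusion matrix:

```python
import numpy as np

conf = np.array([[20.0, 5.0],
                 [10.0, 15.0]])              # hypothetical 2-class confusion matrix
po_sum = np.trace(conf)                      # observed agreement (diagonal mass) = 35
pe_row, pe_col = conf.sum(axis=1), conf.sum(axis=0)
total = conf.sum()
pe_sum = np.sum(pe_row * pe_col / total)     # expected agreement by chance = 25
kappa = (po_sum - pe_sum) / (total - pe_sum) # (35 - 25) / (50 - 25) = 0.4
```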

File 3 of 6

@@ -654,7 +654,7 @@ def weighted_masked_objective(fn):
       score_array = math_ops.multiply(score_array, weights)
       score_array = math_ops.reduce_sum(score_array)
       weights = math_ops.reduce_sum(weights)
-      score_array = metrics_module.safe_div(score_array, weights)
+      score_array = math_ops.div_no_nan(score_array, weights)
     return K.mean(score_array)
   return weighted

File 4 of 6

@@ -27,7 +27,6 @@ import weakref
 from enum import Enum
 import six

-from tensorflow.python.compat import compat
 from tensorflow.python.eager import context
 from tensorflow.python.eager import function
 from tensorflow.python.framework import dtypes
@@ -173,32 +172,6 @@ def weakmethod(method):
   return inner


-def safe_div(numerator, denominator):
-  """Computes a safe divide which returns 0 if the denominator is zero.
-
-  Note that the function contains an additional conditional check that is
-  necessary for avoiding situations where the loss is zero causing NaNs to
-  creep into the gradient computation.
-
-  Args:
-    numerator: An arbitrary `Tensor`.
-    denominator: A `Tensor` whose shape matches `numerator` and whose values are
-      assumed to be non-negative.
-
-  Returns:
-    The element-wise value of the numerator divided by the denominator.
-  """
-  if compat.forward_compatible(2018, 11, 1):
-    return math_ops.div_no_nan(numerator, denominator)
-  return array_ops.where(
-      math_ops.greater(denominator, 0),
-      math_ops.div(numerator,
-                   array_ops.where(
-                       math_ops.equal(denominator, 0),
-                       array_ops.ones_like(denominator), denominator)),
-      array_ops.zeros_like(numerator))
-
-
 def squeeze_or_expand_dimensions(y_pred, y_true, sample_weight):
   """Squeeze or expand last dimension if needed.
@@ -697,7 +670,7 @@ class Mean(Metric):
     return ops.convert_to_tensor(update_count_op)

   def result(self):
-    return safe_div(self.total, self.count)
+    return math_ops.div_no_nan(self.total, self.count)


 class MeanMetricWrapper(Mean):
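With `div_no_nan`, `result()` is simply total/count and yields 0 rather than NaN before the first update. A minimal sketch of the metric's accumulator semantics (a hypothetical class, not the Keras implementation):

```python
import numpy as np

class RunningMean:
  """Tracks a weighted running mean with div_no_nan-style division."""

  def __init__(self):
    self.total = 0.0   # sum of weighted values
    self.count = 0.0   # sum of weights

  def update(self, values, weights=None):
    values = np.asarray(values, dtype=float)
    if weights is None:
      weights = np.ones_like(values)
    weights = np.broadcast_to(np.asarray(weights, dtype=float), values.shape)
    self.total += float(np.sum(values * weights))
    self.count += float(np.sum(weights))

  def result(self):
    # div_no_nan semantics: 0 instead of NaN when count == 0.
    return self.total / self.count if self.count else 0.0

m = RunningMean()
assert m.result() == 0.0   # no updates yet: 0, not NaN
m.update([1.0, 3.0])
assert m.result() == 2.0
```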

File 5 of 6

@@ -18,7 +18,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-from tensorflow.python.compat import compat
 from tensorflow.python.eager import context
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
@@ -92,34 +91,6 @@ class Reduction(ReductionV2):
       raise ValueError("Invalid Reduction Key %s." % key)


-def _safe_div(numerator, denominator, name="value"):
-  """Computes a safe divide which returns 0 if the denominator is zero.
-
-  Note that the function contains an additional conditional check that is
-  necessary for avoiding situations where the loss is zero causing NaNs to
-  creep into the gradient computation.
-
-  Args:
-    numerator: An arbitrary `Tensor`.
-    denominator: A `Tensor` whose shape matches `numerator` and whose values are
-      assumed to be non-negative.
-    name: An optional name for the returned op.
-
-  Returns:
-    The element-wise value of the numerator divided by the denominator.
-  """
-  if compat.forward_compatible(2018, 11, 1):
-    return math_ops.div_no_nan(numerator, denominator, name=name)
-  return array_ops.where(
-      math_ops.greater(denominator, 0),
-      math_ops.div(numerator,
-                   array_ops.where(
-                       math_ops.equal(denominator, 0),
-                       array_ops.ones_like(denominator), denominator)),
-      array_ops.zeros_like(numerator),
-      name=name)
-
-
 def _safe_mean(losses, num_present):
   """Computes a safe mean of the losses.
@@ -132,7 +103,7 @@ def _safe_mean(losses, num_present):
     then zero is returned.
   """
   total_loss = math_ops.reduce_sum(losses)
-  return _safe_div(total_loss, num_present)
+  return math_ops.div_no_nan(total_loss, num_present, name="value")


 def _num_present(losses, weights, per_batch=False):
@@ -620,18 +591,19 @@ def mean_pairwise_squared_error(
         keepdims=True)
     num_present_per_batch = _num_present(diffs, weights, per_batch=True)
-    term1 = 2.0 * _safe_div(
+    term1 = 2.0 * math_ops.div_no_nan(
         sum_squares_diff_per_batch,
-        math_ops.maximum(num_present_per_batch - 1, 0))
+        math_ops.maximum(num_present_per_batch - 1, 0),
+        name="value")

     sum_diff = math_ops.reduce_sum(
         diffs, reduction_indices=reduction_indices, keepdims=True)
-    term2 = 2.0 * _safe_div(
+    term2 = 2.0 * math_ops.div_no_nan(
         math_ops.square(sum_diff),
         math_ops.maximum(
             math_ops.multiply(num_present_per_batch,
-                              num_present_per_batch - 1),
-            0))
+                              num_present_per_batch - 1), 0),
+        name="value")

     weighted_losses = math_ops.multiply(term1 - term2, weights)
     loss = math_ops.reduce_sum(weighted_losses)
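The two terms assemble the mean pairwise squared error without materializing all pairs: for per-batch diffs d with n present elements, 2*sum(d^2)/(n-1) - 2*(sum d)^2/(n*(n-1)) equals the average of (d_i - d_j)^2 over all pairs i != j. A NumPy check of that identity (illustrative values):

```python
import numpy as np

d = np.array([0.3, -1.2, 0.8, 2.0])   # per-batch diffs
n = d.size
term1 = 2.0 * np.sum(d**2) / (n - 1)
term2 = 2.0 * np.sum(d)**2 / (n * (n - 1))
pairwise = np.mean([(d[i] - d[j])**2
                    for i in range(n) for j in range(n) if i != j])
assert np.isclose(term1 - term2, pairwise)
```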

File 6 of 6

@@ -18,7 +18,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-from tensorflow.python.compat import compat
 from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -213,26 +212,6 @@ def _maybe_expand_labels(labels, predictions):
         lambda: array_ops.expand_dims(labels, -1, name=scope), lambda: labels)


-def _safe_div(numerator, denominator, name):
-  """Divides two tensors element-wise, returning 0 if the denominator is <= 0.
-
-  Args:
-    numerator: A real `Tensor`.
-    denominator: A real `Tensor`, with dtype matching `numerator`.
-    name: Name for the returned op.
-
-  Returns:
-    0 if `denominator` <= 0, else `numerator` / `denominator`
-  """
-  if compat.forward_compatible(2018, 11, 1):
-    return math_ops.div_no_nan(numerator, denominator, name=name)
-  t = math_ops.truediv(numerator, denominator)
-  zero = array_ops.zeros_like(t, dtype=denominator.dtype)
-  condition = math_ops.greater(denominator, zero)
-  zero = math_ops.cast(zero, t.dtype)
-  return array_ops.where(condition, t, zero, name=name)
-
-
 def _safe_scalar_div(numerator, denominator, name):
   """Divides two values, returning 0 if the denominator is 0.
@@ -246,7 +225,7 @@ def _safe_scalar_div(numerator, denominator, name):
   """
   numerator.get_shape().with_rank_at_most(1)
   denominator.get_shape().with_rank_at_most(1)
-  return _safe_div(numerator, denominator, name=name)
+  return math_ops.div_no_nan(numerator, denominator, name=name)


 def _streaming_confusion_matrix(labels, predictions, num_classes, weights=None):
@@ -401,13 +380,12 @@ def mean(values,
     update_count_op = state_ops.assign_add(count, num_values)

     def compute_mean(_, t, c):
-      return _safe_div(t, math_ops.maximum(c, 0), name='value')
+      return math_ops.div_no_nan(t, math_ops.maximum(c, 0), name='value')

     mean_t = _aggregate_across_replicas(
         metrics_collections, compute_mean, total, count)

-    update_op = _safe_div(update_total_op,
-                          math_ops.maximum(update_count_op, 0),
-                          name='update_op')
+    update_op = math_ops.div_no_nan(
+        update_total_op, math_ops.maximum(update_count_op, 0), name='update_op')
     if updates_collections:
       ops.add_to_collections(updates_collections, update_op)
@@ -779,19 +757,19 @@ def auc(labels,
     """
     dtp = tp[:num_thresholds - 1] - tp[1:]
     p = tp + fp
-    prec_slope = _safe_div(
+    prec_slope = math_ops.div_no_nan(
         dtp,
         math_ops.maximum(p[:num_thresholds - 1] - p[1:], 0),
         name='prec_slope')
     intercept = tp[1:] - math_ops.multiply(prec_slope, p[1:])
     safe_p_ratio = array_ops.where(
         math_ops.logical_and(p[:num_thresholds - 1] > 0, p[1:] > 0),
-        _safe_div(p[:num_thresholds - 1],
+        math_ops.div_no_nan(
+            p[:num_thresholds - 1],
             math_ops.maximum(p[1:], 0),
-            name='recall_relative_ratio'),
-        array_ops.ones_like(p[1:]))
+            name='recall_relative_ratio'), array_ops.ones_like(p[1:]))

     return math_ops.reduce_sum(
-        _safe_div(
+        math_ops.div_no_nan(
             prec_slope * (dtp + intercept * math_ops.log(safe_p_ratio)),
             math_ops.maximum(tp[1:] + fn[1:], 0),
             name='pr_auc_increment'),
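The `where`/`div_no_nan` combination above guards the precision-recall interpolation: the ratio of positive counts at adjacent thresholds defaults to 1 whenever either count is 0, so `log(safe_p_ratio)` is 0 and an empty threshold bucket contributes nothing to the AUC sum. The same guard in NumPy terms (a sketch of the pattern, not the metric itself):

```python
import numpy as np

p_hi = np.array([4.0, 2.0, 0.0])   # positive counts, lower threshold (p[:num_thresholds - 1])
p_lo = np.array([2.0, 0.0, 0.0])   # positive counts, higher threshold (p[1:])
ratio = np.where(
    (p_hi > 0) & (p_lo > 0),
    # div_no_nan analogue: 0 where the denominator is 0.
    np.divide(p_hi, p_lo, out=np.zeros_like(p_hi), where=p_lo != 0),
    np.ones_like(p_lo))
# ratio == [2., 1., 1.]; log(1.) == 0. drops the guarded buckets from the sum.
```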
@@ -1074,7 +1052,7 @@ def mean_per_class_accuracy(labels,
     update_count_op = state_ops.scatter_add(count, labels, is_correct)

     def compute_mean_accuracy(_, count, total):
-      per_class_accuracy = _safe_div(
+      per_class_accuracy = math_ops.div_no_nan(
           count, math_ops.maximum(total, 0), name=None)
       mean_accuracy_v = math_ops.reduce_mean(
           per_class_accuracy, name='mean_accuracy')
@@ -1083,9 +1061,8 @@
     mean_accuracy_v = _aggregate_across_replicas(
         metrics_collections, compute_mean_accuracy, count, total)

-    update_op = _safe_div(update_count_op,
-                          math_ops.maximum(update_total_op, 0),
-                          name='update_op')
+    update_op = math_ops.div_no_nan(
+        update_count_op, math_ops.maximum(update_total_op, 0), name='update_op')
     if updates_collections:
       ops.add_to_collections(updates_collections, update_op)
@@ -1394,15 +1371,14 @@ def mean_tensor(values,
     with ops.control_dependencies([values]):
       update_count_op = state_ops.assign_add(count, num_values)

-    compute_mean = lambda _, t, c: _safe_div(
+    compute_mean = lambda _, t, c: math_ops.div_no_nan(  # pylint: disable=g-long-lambda
         t, math_ops.maximum(c, 0), name='value')

     mean_t = _aggregate_across_replicas(
         metrics_collections, compute_mean, total, count)

-    update_op = _safe_div(update_total_op,
-                          math_ops.maximum(update_count_op, 0),
-                          name='update_op')
+    update_op = math_ops.div_no_nan(
+        update_total_op, math_ops.maximum(update_count_op, 0), name='update_op')
     if updates_collections:
       ops.add_to_collections(updates_collections, update_op)