Cleaning up safe_div implementations
PiperOrigin-RevId: 222124095
This commit is contained in:
parent
9a83d2111f
commit
c5695df38c
@ -22,7 +22,6 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.framework.python.ops import add_arg_scope
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
@ -67,34 +66,6 @@ def _scale_losses(losses, weights):
|
||||
return math_ops.reduce_sum(reduced_losses)
|
||||
|
||||
|
||||
def _safe_div(numerator, denominator, name="value"):
|
||||
"""Computes a safe divide which returns 0 if the denominator is zero.
|
||||
|
||||
Note that the function contains an additional conditional check that is
|
||||
necessary for avoiding situations where the loss is zero causing NaNs to
|
||||
creep into the gradient computation.
|
||||
|
||||
Args:
|
||||
numerator: An arbitrary `Tensor`.
|
||||
denominator: A `Tensor` whose shape matches `numerator` and whose values are
|
||||
assumed to be non-negative.
|
||||
name: An optional name for the returned op.
|
||||
|
||||
Returns:
|
||||
The element-wise value of the numerator divided by the denominator.
|
||||
"""
|
||||
if compat.forward_compatible(2018, 11, 1):
|
||||
return math_ops.div_no_nan(numerator, denominator, name=name)
|
||||
return array_ops.where(
|
||||
math_ops.greater(denominator, 0),
|
||||
math_ops.div(numerator,
|
||||
array_ops.where(
|
||||
math_ops.equal(denominator, 0),
|
||||
array_ops.ones_like(denominator), denominator)),
|
||||
array_ops.zeros_like(numerator),
|
||||
name=name)
|
||||
|
||||
|
||||
def _safe_mean(losses, num_present):
|
||||
"""Computes a safe mean of the losses.
|
||||
|
||||
@ -107,7 +78,7 @@ def _safe_mean(losses, num_present):
|
||||
then zero is returned.
|
||||
"""
|
||||
total_loss = math_ops.reduce_sum(losses)
|
||||
return _safe_div(total_loss, num_present, name="value")
|
||||
return math_ops.div_no_nan(total_loss, num_present, name="value")
|
||||
|
||||
|
||||
@deprecated("2016-12-30", "Use tf.losses.compute_weighted_loss instead.")
|
||||
@ -612,14 +583,14 @@ def mean_pairwise_squared_error(predictions,
|
||||
math_ops.square(diffs), reduction_indices=reduction_indices)
|
||||
num_present_per_batch = _num_present(diffs, weights, per_batch=True)
|
||||
|
||||
term1 = 2.0 * _safe_div(sum_squares_diff_per_batch,
|
||||
num_present_per_batch,
|
||||
name="value")
|
||||
term1 = 2.0 * math_ops.div_no_nan(
|
||||
sum_squares_diff_per_batch, num_present_per_batch, name="value")
|
||||
|
||||
sum_diff = math_ops.reduce_sum(diffs, reduction_indices=reduction_indices)
|
||||
term2 = 2.0 * _safe_div(math_ops.square(sum_diff),
|
||||
math_ops.square(num_present_per_batch),
|
||||
name="value")
|
||||
term2 = 2.0 * math_ops.div_no_nan(
|
||||
math_ops.square(sum_diff),
|
||||
math_ops.square(num_present_per_batch),
|
||||
name="value")
|
||||
|
||||
loss = _scale_losses(term1 - term2, weights)
|
||||
|
||||
|
@ -24,7 +24,6 @@ from __future__ import print_function
|
||||
|
||||
import collections as collections_lib
|
||||
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
@ -46,32 +45,6 @@ from tensorflow.python.util.deprecation import deprecated
|
||||
_EPSILON = 1e-7
|
||||
|
||||
|
||||
def _safe_div(numerator, denominator):
|
||||
"""Computes a safe divide which returns 0 if the denominator is zero.
|
||||
|
||||
Note that the function contains an additional conditional check that is
|
||||
necessary for avoiding situations where the loss is zero causing NaNs to
|
||||
creep into the gradient computation.
|
||||
|
||||
Args:
|
||||
numerator: An arbitrary `Tensor`.
|
||||
denominator: A `Tensor` whose shape matches `numerator` and whose values are
|
||||
assumed to be non-negative.
|
||||
|
||||
Returns:
|
||||
The element-wise value of the numerator divided by the denominator.
|
||||
"""
|
||||
if compat.forward_compatible(2018, 11, 1):
|
||||
return math_ops.div_no_nan(numerator, denominator)
|
||||
return array_ops.where(
|
||||
math_ops.greater(denominator, 0),
|
||||
math_ops.div(numerator,
|
||||
array_ops.where(
|
||||
math_ops.equal(denominator, 0),
|
||||
array_ops.ones_like(denominator), denominator)),
|
||||
array_ops.zeros_like(numerator))
|
||||
|
||||
|
||||
@deprecated(None, 'Please switch to tf.metrics.true_positives. Note that the '
|
||||
'order of the labels and predictions arguments has been switched.')
|
||||
def streaming_true_positives(predictions,
|
||||
@ -3247,24 +3220,20 @@ def streaming_covariance(predictions,
|
||||
|
||||
# We update the means by Delta=Error*BatchCount/(BatchCount+PrevCount)
|
||||
# batch_mean_prediction is E[x_B] in the update equation
|
||||
batch_mean_prediction = _safe_div(
|
||||
math_ops.reduce_sum(weighted_predictions),
|
||||
batch_count)
|
||||
delta_mean_prediction = _safe_div(
|
||||
(batch_mean_prediction - mean_prediction) * batch_count,
|
||||
update_count)
|
||||
batch_mean_prediction = math_ops.div_no_nan(
|
||||
math_ops.reduce_sum(weighted_predictions), batch_count)
|
||||
delta_mean_prediction = math_ops.div_no_nan(
|
||||
(batch_mean_prediction - mean_prediction) * batch_count, update_count)
|
||||
update_mean_prediction = state_ops.assign_add(mean_prediction,
|
||||
delta_mean_prediction)
|
||||
# prev_mean_prediction is E[x_A] in the update equation
|
||||
prev_mean_prediction = update_mean_prediction - delta_mean_prediction
|
||||
|
||||
# batch_mean_label is E[y_B] in the update equation
|
||||
batch_mean_label = _safe_div(
|
||||
math_ops.reduce_sum(weighted_labels),
|
||||
batch_count)
|
||||
delta_mean_label = _safe_div(
|
||||
(batch_mean_label - mean_label) * batch_count,
|
||||
update_count)
|
||||
batch_mean_label = math_ops.div_no_nan(
|
||||
math_ops.reduce_sum(weighted_labels), batch_count)
|
||||
delta_mean_label = math_ops.div_no_nan(
|
||||
(batch_mean_label - mean_label) * batch_count, update_count)
|
||||
update_mean_label = state_ops.assign_add(mean_label, delta_mean_label)
|
||||
# prev_mean_label is E[y_A] in the update equation
|
||||
prev_mean_label = update_mean_label - delta_mean_label
|
||||
@ -3926,9 +3895,8 @@ def cohen_kappa(labels,
|
||||
po_sum = math_ops.reduce_sum(po)
|
||||
total = math_ops.reduce_sum(pe_row)
|
||||
pe_sum = math_ops.reduce_sum(
|
||||
_safe_div(
|
||||
math_ops.to_double(pe_row * pe_col),
|
||||
math_ops.to_double(total)))
|
||||
math_ops.div_no_nan(
|
||||
math_ops.to_double(pe_row * pe_col), math_ops.to_double(total)))
|
||||
po_sum, pe_sum, total = (math_ops.to_double(po_sum),
|
||||
math_ops.to_double(pe_sum),
|
||||
math_ops.to_double(total))
|
||||
|
@ -654,7 +654,7 @@ def weighted_masked_objective(fn):
|
||||
score_array = math_ops.multiply(score_array, weights)
|
||||
score_array = math_ops.reduce_sum(score_array)
|
||||
weights = math_ops.reduce_sum(weights)
|
||||
score_array = metrics_module.safe_div(score_array, weights)
|
||||
score_array = math_ops.div_no_nan(score_array, weights)
|
||||
return K.mean(score_array)
|
||||
|
||||
return weighted
|
||||
|
@ -27,7 +27,6 @@ import weakref
|
||||
from enum import Enum
|
||||
import six
|
||||
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import function
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -173,32 +172,6 @@ def weakmethod(method):
|
||||
return inner
|
||||
|
||||
|
||||
def safe_div(numerator, denominator):
|
||||
"""Computes a safe divide which returns 0 if the denominator is zero.
|
||||
|
||||
Note that the function contains an additional conditional check that is
|
||||
necessary for avoiding situations where the loss is zero causing NaNs to
|
||||
creep into the gradient computation.
|
||||
|
||||
Args:
|
||||
numerator: An arbitrary `Tensor`.
|
||||
denominator: A `Tensor` whose shape matches `numerator` and whose values are
|
||||
assumed to be non-negative.
|
||||
|
||||
Returns:
|
||||
The element-wise value of the numerator divided by the denominator.
|
||||
"""
|
||||
if compat.forward_compatible(2018, 11, 1):
|
||||
return math_ops.div_no_nan(numerator, denominator)
|
||||
return array_ops.where(
|
||||
math_ops.greater(denominator, 0),
|
||||
math_ops.div(numerator,
|
||||
array_ops.where(
|
||||
math_ops.equal(denominator, 0),
|
||||
array_ops.ones_like(denominator), denominator)),
|
||||
array_ops.zeros_like(numerator))
|
||||
|
||||
|
||||
def squeeze_or_expand_dimensions(y_pred, y_true, sample_weight):
|
||||
"""Squeeze or expand last dimension if needed.
|
||||
|
||||
@ -697,7 +670,7 @@ class Mean(Metric):
|
||||
return ops.convert_to_tensor(update_count_op)
|
||||
|
||||
def result(self):
|
||||
return safe_div(self.total, self.count)
|
||||
return math_ops.div_no_nan(self.total, self.count)
|
||||
|
||||
|
||||
class MeanMetricWrapper(Mean):
|
||||
|
@ -18,7 +18,6 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -92,34 +91,6 @@ class Reduction(ReductionV2):
|
||||
raise ValueError("Invalid Reduction Key %s." % key)
|
||||
|
||||
|
||||
def _safe_div(numerator, denominator, name="value"):
|
||||
"""Computes a safe divide which returns 0 if the denominator is zero.
|
||||
|
||||
Note that the function contains an additional conditional check that is
|
||||
necessary for avoiding situations where the loss is zero causing NaNs to
|
||||
creep into the gradient computation.
|
||||
|
||||
Args:
|
||||
numerator: An arbitrary `Tensor`.
|
||||
denominator: A `Tensor` whose shape matches `numerator` and whose values are
|
||||
assumed to be non-negative.
|
||||
name: An optional name for the returned op.
|
||||
|
||||
Returns:
|
||||
The element-wise value of the numerator divided by the denominator.
|
||||
"""
|
||||
if compat.forward_compatible(2018, 11, 1):
|
||||
return math_ops.div_no_nan(numerator, denominator, name=name)
|
||||
return array_ops.where(
|
||||
math_ops.greater(denominator, 0),
|
||||
math_ops.div(numerator,
|
||||
array_ops.where(
|
||||
math_ops.equal(denominator, 0),
|
||||
array_ops.ones_like(denominator), denominator)),
|
||||
array_ops.zeros_like(numerator),
|
||||
name=name)
|
||||
|
||||
|
||||
def _safe_mean(losses, num_present):
|
||||
"""Computes a safe mean of the losses.
|
||||
|
||||
@ -132,7 +103,7 @@ def _safe_mean(losses, num_present):
|
||||
then zero is returned.
|
||||
"""
|
||||
total_loss = math_ops.reduce_sum(losses)
|
||||
return _safe_div(total_loss, num_present)
|
||||
return math_ops.div_no_nan(total_loss, num_present, name="value")
|
||||
|
||||
|
||||
def _num_present(losses, weights, per_batch=False):
|
||||
@ -620,18 +591,19 @@ def mean_pairwise_squared_error(
|
||||
keepdims=True)
|
||||
num_present_per_batch = _num_present(diffs, weights, per_batch=True)
|
||||
|
||||
term1 = 2.0 * _safe_div(
|
||||
term1 = 2.0 * math_ops.div_no_nan(
|
||||
sum_squares_diff_per_batch,
|
||||
math_ops.maximum(num_present_per_batch - 1, 0))
|
||||
math_ops.maximum(num_present_per_batch - 1, 0),
|
||||
name="value")
|
||||
|
||||
sum_diff = math_ops.reduce_sum(
|
||||
diffs, reduction_indices=reduction_indices, keepdims=True)
|
||||
term2 = 2.0 * _safe_div(
|
||||
term2 = 2.0 * math_ops.div_no_nan(
|
||||
math_ops.square(sum_diff),
|
||||
math_ops.maximum(
|
||||
math_ops.multiply(num_present_per_batch,
|
||||
num_present_per_batch - 1),
|
||||
0))
|
||||
num_present_per_batch - 1), 0),
|
||||
name="value")
|
||||
|
||||
weighted_losses = math_ops.multiply(term1 - term2, weights)
|
||||
loss = math_ops.reduce_sum(weighted_losses)
|
||||
|
@ -18,7 +18,6 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
@ -213,26 +212,6 @@ def _maybe_expand_labels(labels, predictions):
|
||||
lambda: array_ops.expand_dims(labels, -1, name=scope), lambda: labels)
|
||||
|
||||
|
||||
def _safe_div(numerator, denominator, name):
|
||||
"""Divides two tensors element-wise, returning 0 if the denominator is <= 0.
|
||||
|
||||
Args:
|
||||
numerator: A real `Tensor`.
|
||||
denominator: A real `Tensor`, with dtype matching `numerator`.
|
||||
name: Name for the returned op.
|
||||
|
||||
Returns:
|
||||
0 if `denominator` <= 0, else `numerator` / `denominator`
|
||||
"""
|
||||
if compat.forward_compatible(2018, 11, 1):
|
||||
return math_ops.div_no_nan(numerator, denominator, name=name)
|
||||
t = math_ops.truediv(numerator, denominator)
|
||||
zero = array_ops.zeros_like(t, dtype=denominator.dtype)
|
||||
condition = math_ops.greater(denominator, zero)
|
||||
zero = math_ops.cast(zero, t.dtype)
|
||||
return array_ops.where(condition, t, zero, name=name)
|
||||
|
||||
|
||||
def _safe_scalar_div(numerator, denominator, name):
|
||||
"""Divides two values, returning 0 if the denominator is 0.
|
||||
|
||||
@ -246,7 +225,7 @@ def _safe_scalar_div(numerator, denominator, name):
|
||||
"""
|
||||
numerator.get_shape().with_rank_at_most(1)
|
||||
denominator.get_shape().with_rank_at_most(1)
|
||||
return _safe_div(numerator, denominator, name=name)
|
||||
return math_ops.div_no_nan(numerator, denominator, name=name)
|
||||
|
||||
|
||||
def _streaming_confusion_matrix(labels, predictions, num_classes, weights=None):
|
||||
@ -401,13 +380,12 @@ def mean(values,
|
||||
update_count_op = state_ops.assign_add(count, num_values)
|
||||
|
||||
def compute_mean(_, t, c):
|
||||
return _safe_div(t, math_ops.maximum(c, 0), name='value')
|
||||
return math_ops.div_no_nan(t, math_ops.maximum(c, 0), name='value')
|
||||
|
||||
mean_t = _aggregate_across_replicas(
|
||||
metrics_collections, compute_mean, total, count)
|
||||
update_op = _safe_div(update_total_op,
|
||||
math_ops.maximum(update_count_op, 0),
|
||||
name='update_op')
|
||||
update_op = math_ops.div_no_nan(
|
||||
update_total_op, math_ops.maximum(update_count_op, 0), name='update_op')
|
||||
|
||||
if updates_collections:
|
||||
ops.add_to_collections(updates_collections, update_op)
|
||||
@ -779,19 +757,19 @@ def auc(labels,
|
||||
"""
|
||||
dtp = tp[:num_thresholds - 1] - tp[1:]
|
||||
p = tp + fp
|
||||
prec_slope = _safe_div(
|
||||
prec_slope = math_ops.div_no_nan(
|
||||
dtp,
|
||||
math_ops.maximum(p[:num_thresholds - 1] - p[1:], 0),
|
||||
name='prec_slope')
|
||||
intercept = tp[1:] - math_ops.multiply(prec_slope, p[1:])
|
||||
safe_p_ratio = array_ops.where(
|
||||
math_ops.logical_and(p[:num_thresholds - 1] > 0, p[1:] > 0),
|
||||
_safe_div(p[:num_thresholds - 1],
|
||||
math_ops.maximum(p[1:], 0),
|
||||
name='recall_relative_ratio'),
|
||||
array_ops.ones_like(p[1:]))
|
||||
math_ops.div_no_nan(
|
||||
p[:num_thresholds - 1],
|
||||
math_ops.maximum(p[1:], 0),
|
||||
name='recall_relative_ratio'), array_ops.ones_like(p[1:]))
|
||||
return math_ops.reduce_sum(
|
||||
_safe_div(
|
||||
math_ops.div_no_nan(
|
||||
prec_slope * (dtp + intercept * math_ops.log(safe_p_ratio)),
|
||||
math_ops.maximum(tp[1:] + fn[1:], 0),
|
||||
name='pr_auc_increment'),
|
||||
@ -1074,7 +1052,7 @@ def mean_per_class_accuracy(labels,
|
||||
update_count_op = state_ops.scatter_add(count, labels, is_correct)
|
||||
|
||||
def compute_mean_accuracy(_, count, total):
|
||||
per_class_accuracy = _safe_div(
|
||||
per_class_accuracy = math_ops.div_no_nan(
|
||||
count, math_ops.maximum(total, 0), name=None)
|
||||
mean_accuracy_v = math_ops.reduce_mean(
|
||||
per_class_accuracy, name='mean_accuracy')
|
||||
@ -1083,9 +1061,8 @@ def mean_per_class_accuracy(labels,
|
||||
mean_accuracy_v = _aggregate_across_replicas(
|
||||
metrics_collections, compute_mean_accuracy, count, total)
|
||||
|
||||
update_op = _safe_div(update_count_op,
|
||||
math_ops.maximum(update_total_op, 0),
|
||||
name='update_op')
|
||||
update_op = math_ops.div_no_nan(
|
||||
update_count_op, math_ops.maximum(update_total_op, 0), name='update_op')
|
||||
if updates_collections:
|
||||
ops.add_to_collections(updates_collections, update_op)
|
||||
|
||||
@ -1394,15 +1371,14 @@ def mean_tensor(values,
|
||||
with ops.control_dependencies([values]):
|
||||
update_count_op = state_ops.assign_add(count, num_values)
|
||||
|
||||
compute_mean = lambda _, t, c: _safe_div(
|
||||
compute_mean = lambda _, t, c: math_ops.div_no_nan( # pylint: disable=g-long-lambda
|
||||
t, math_ops.maximum(c, 0), name='value')
|
||||
|
||||
mean_t = _aggregate_across_replicas(
|
||||
metrics_collections, compute_mean, total, count)
|
||||
|
||||
update_op = _safe_div(update_total_op,
|
||||
math_ops.maximum(update_count_op, 0),
|
||||
name='update_op')
|
||||
update_op = math_ops.div_no_nan(
|
||||
update_total_op, math_ops.maximum(update_count_op, 0), name='update_op')
|
||||
if updates_collections:
|
||||
ops.add_to_collections(updates_collections, update_op)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user