Set metric variable initializers as lambda.
PiperOrigin-RevId: 174100686
This commit is contained in:
parent
453dd5848f
commit
b242a7988c
@ -92,8 +92,7 @@ def _count_condition(values,
|
||||
or tuple.
|
||||
"""
|
||||
check_ops.assert_type(values, dtypes.bool)
|
||||
count_ = metrics_impl.metric_variable(
|
||||
array_ops.zeros([], dtype=dtypes.float32), name='count')
|
||||
count_ = metrics_impl.metric_variable([], dtypes.float32, name='count')
|
||||
|
||||
values = math_ops.to_float(values)
|
||||
if weights is not None:
|
||||
@ -916,8 +915,7 @@ def _streaming_confusion_matrix_at_thresholds(predictions,
|
||||
|
||||
if 'tp' in includes:
|
||||
true_positives = metrics_impl.metric_variable(
|
||||
array_ops.zeros([num_thresholds], dtype=dtypes.float32),
|
||||
name='true_positives')
|
||||
[num_thresholds], dtypes.float32, name='true_positives')
|
||||
is_true_positive = math_ops.to_float(
|
||||
math_ops.logical_and(label_is_pos, pred_is_pos))
|
||||
if weights_tiled is not None:
|
||||
@ -929,8 +927,7 @@ def _streaming_confusion_matrix_at_thresholds(predictions,
|
||||
|
||||
if 'fn' in includes:
|
||||
false_negatives = metrics_impl.metric_variable(
|
||||
array_ops.zeros([num_thresholds], dtype=dtypes.float32),
|
||||
name='false_negatives')
|
||||
[num_thresholds], dtypes.float32, name='false_negatives')
|
||||
is_false_negative = math_ops.to_float(
|
||||
math_ops.logical_and(label_is_pos, pred_is_neg))
|
||||
if weights_tiled is not None:
|
||||
@ -942,8 +939,7 @@ def _streaming_confusion_matrix_at_thresholds(predictions,
|
||||
|
||||
if 'tn' in includes:
|
||||
true_negatives = metrics_impl.metric_variable(
|
||||
array_ops.zeros([num_thresholds], dtype=dtypes.float32),
|
||||
name='true_negatives')
|
||||
[num_thresholds], dtypes.float32, name='true_negatives')
|
||||
is_true_negative = math_ops.to_float(
|
||||
math_ops.logical_and(label_is_neg, pred_is_neg))
|
||||
if weights_tiled is not None:
|
||||
@ -955,8 +951,7 @@ def _streaming_confusion_matrix_at_thresholds(predictions,
|
||||
|
||||
if 'fp' in includes:
|
||||
false_positives = metrics_impl.metric_variable(
|
||||
array_ops.zeros([num_thresholds], dtype=dtypes.float32),
|
||||
name='false_positives')
|
||||
[num_thresholds], dtypes.float32, name='false_positives')
|
||||
is_false_positive = math_ops.to_float(
|
||||
math_ops.logical_and(label_is_neg, pred_is_pos))
|
||||
if weights_tiled is not None:
|
||||
@ -1317,9 +1312,9 @@ def streaming_precision_recall_at_equal_thresholds(predictions,
|
||||
|
||||
with ops.name_scope('variables'):
|
||||
tp_buckets_v = metrics_impl.metric_variable(
|
||||
array_ops.zeros([num_thresholds], dtype=dtype), name='tp_buckets')
|
||||
[num_thresholds], dtype, name='tp_buckets')
|
||||
fp_buckets_v = metrics_impl.metric_variable(
|
||||
array_ops.zeros([num_thresholds], dtype=dtype), name='fp_buckets')
|
||||
[num_thresholds], dtype, name='fp_buckets')
|
||||
|
||||
with ops.name_scope('update_op'):
|
||||
update_tp = state_ops.scatter_add(
|
||||
@ -2582,15 +2577,13 @@ def streaming_covariance(predictions,
|
||||
predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access
|
||||
predictions, labels, weights)
|
||||
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
|
||||
count_ = metrics_impl.metric_variable(
|
||||
array_ops.zeros([], dtype=dtypes.float32), name='count')
|
||||
count_ = metrics_impl.metric_variable([], dtypes.float32, name='count')
|
||||
mean_prediction = metrics_impl.metric_variable(
|
||||
array_ops.zeros([], dtype=dtypes.float32), name='mean_prediction')
|
||||
[], dtypes.float32, name='mean_prediction')
|
||||
mean_label = metrics_impl.metric_variable(
|
||||
array_ops.zeros([], dtype=dtypes.float32), name='mean_label')
|
||||
[], dtypes.float32, name='mean_label')
|
||||
comoment = metrics_impl.metric_variable( # C_A in update equation
|
||||
array_ops.zeros([], dtype=dtypes.float32),
|
||||
name='comoment')
|
||||
[], dtypes.float32, name='comoment')
|
||||
|
||||
if weights is None:
|
||||
batch_count = math_ops.to_float(array_ops.size(labels)) # n_B in eqn
|
||||
@ -3011,11 +3004,8 @@ def streaming_concat(values,
|
||||
init_size = 0 if max_size is None else max_size
|
||||
init_shape = [init_size] + fixed_shape
|
||||
array = metrics_impl.metric_variable(
|
||||
array_ops.zeros(init_shape, dtype=values.dtype),
|
||||
validate_shape=False,
|
||||
name='array')
|
||||
size = metrics_impl.metric_variable(
|
||||
array_ops.zeros([], dtype=dtypes.int32), name='size')
|
||||
init_shape, values.dtype, validate_shape=False, name='array')
|
||||
size = metrics_impl.metric_variable([], dtypes.int32, name='size')
|
||||
|
||||
perm = [0 if n == axis else n + 1 if n < axis else n for n in range(ndim)]
|
||||
valid_array = array[:size]
|
||||
@ -3149,8 +3139,7 @@ def count(values,
|
||||
"""
|
||||
|
||||
with variable_scope.variable_scope(name, 'count', (values, weights)):
|
||||
count_ = metrics_impl.metric_variable(
|
||||
array_ops.zeros([], dtype=dtypes.float32), name='count')
|
||||
count_ = metrics_impl.metric_variable([], dtypes.float32, name='count')
|
||||
|
||||
if weights is None:
|
||||
num_values = math_ops.to_float(array_ops.size(values))
|
||||
|
@ -35,18 +35,11 @@ from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import weights_broadcast_ops
|
||||
|
||||
|
||||
def metric_variable(initial_value, validate_shape=True, name=None):
|
||||
"""Create variable in `GraphKeys.(LOCAL|METRIC_VARIABLES`) collections.
|
||||
def metric_variable(shape, dtype, validate_shape=True, name=None):
|
||||
"""Create variable in `GraphKeys.(LOCAL|METRIC_VARIABLES`) collections."""
|
||||
|
||||
Args:
|
||||
initial_value: See variables.Variable.__init__.
|
||||
validate_shape: See variables.Variable.__init__.
|
||||
name: See variables.Variable.__init__.
|
||||
Returns:
|
||||
New variable.
|
||||
"""
|
||||
return variable_scope.variable(
|
||||
initial_value,
|
||||
lambda: array_ops.zeros(shape, dtype),
|
||||
trainable=False,
|
||||
collections=[
|
||||
ops.GraphKeys.LOCAL_VARIABLES, ops.GraphKeys.METRIC_VARIABLES
|
||||
@ -244,8 +237,7 @@ def _streaming_confusion_matrix(labels, predictions, num_classes, weights=None):
|
||||
"""
|
||||
# Local variable to accumulate the predictions in the confusion matrix.
|
||||
total_cm = metric_variable(
|
||||
array_ops.zeros([num_classes, num_classes], dtype=dtypes.float64),
|
||||
name='total_confusion_matrix')
|
||||
[num_classes, num_classes], dtypes.float64, name='total_confusion_matrix')
|
||||
|
||||
# Cast the type to int64 required by confusion_matrix_ops.
|
||||
predictions = math_ops.to_int64(predictions)
|
||||
@ -315,10 +307,8 @@ def mean(values, weights=None, metrics_collections=None,
|
||||
with variable_scope.variable_scope(name, 'mean', (values, weights)):
|
||||
values = math_ops.to_float(values)
|
||||
|
||||
total = metric_variable(
|
||||
array_ops.zeros([], dtype=dtypes.float32), name='total')
|
||||
count = metric_variable(
|
||||
array_ops.zeros([], dtype=dtypes.float32), name='count')
|
||||
total = metric_variable([], dtypes.float32, name='total')
|
||||
count = metric_variable([], dtypes.float32, name='count')
|
||||
|
||||
if weights is None:
|
||||
num_values = math_ops.to_float(array_ops.size(values))
|
||||
@ -516,8 +506,7 @@ def _confusion_matrix_at_thresholds(
|
||||
|
||||
if 'tp' in includes:
|
||||
true_p = metric_variable(
|
||||
array_ops.zeros([num_thresholds], dtype=dtypes.float32),
|
||||
name='true_positives')
|
||||
[num_thresholds], dtypes.float32, name='true_positives')
|
||||
is_true_positive = math_ops.to_float(
|
||||
math_ops.logical_and(label_is_pos, pred_is_pos))
|
||||
if weights_tiled is not None:
|
||||
@ -528,8 +517,7 @@ def _confusion_matrix_at_thresholds(
|
||||
|
||||
if 'fn' in includes:
|
||||
false_n = metric_variable(
|
||||
array_ops.zeros([num_thresholds], dtype=dtypes.float32),
|
||||
name='false_negatives')
|
||||
[num_thresholds], dtypes.float32, name='false_negatives')
|
||||
is_false_negative = math_ops.to_float(
|
||||
math_ops.logical_and(label_is_pos, pred_is_neg))
|
||||
if weights_tiled is not None:
|
||||
@ -540,8 +528,7 @@ def _confusion_matrix_at_thresholds(
|
||||
|
||||
if 'tn' in includes:
|
||||
true_n = metric_variable(
|
||||
array_ops.zeros([num_thresholds], dtype=dtypes.float32),
|
||||
name='true_negatives')
|
||||
[num_thresholds], dtypes.float32, name='true_negatives')
|
||||
is_true_negative = math_ops.to_float(
|
||||
math_ops.logical_and(label_is_neg, pred_is_neg))
|
||||
if weights_tiled is not None:
|
||||
@ -552,8 +539,7 @@ def _confusion_matrix_at_thresholds(
|
||||
|
||||
if 'fp' in includes:
|
||||
false_p = metric_variable(
|
||||
array_ops.zeros([num_thresholds], dtype=dtypes.float32),
|
||||
name='false_positives')
|
||||
[num_thresholds], dtypes.float32, name='false_positives')
|
||||
is_false_positive = math_ops.to_float(
|
||||
math_ops.logical_and(label_is_neg, pred_is_pos))
|
||||
if weights_tiled is not None:
|
||||
@ -1183,11 +1169,9 @@ def mean_tensor(values, weights=None, metrics_collections=None,
|
||||
with variable_scope.variable_scope(name, 'mean', (values, weights)):
|
||||
values = math_ops.to_float(values)
|
||||
total = metric_variable(
|
||||
array_ops.zeros(values.get_shape(), dtype=dtypes.float32),
|
||||
name='total_tensor')
|
||||
values.get_shape(), dtypes.float32, name='total_tensor')
|
||||
count = metric_variable(
|
||||
array_ops.zeros(values.get_shape(), dtype=dtypes.float32),
|
||||
name='count_tensor')
|
||||
values.get_shape(), dtypes.float32, name='count_tensor')
|
||||
|
||||
num_values = array_ops.ones_like(values)
|
||||
if weights is not None:
|
||||
@ -1300,8 +1284,7 @@ def _count_condition(values, weights=None, metrics_collections=None,
|
||||
or tuple.
|
||||
"""
|
||||
check_ops.assert_type(values, dtypes.bool)
|
||||
count = metric_variable(
|
||||
array_ops.zeros([], dtype=dtypes.float32), name='count')
|
||||
count = metric_variable([], dtypes.float32, name='count')
|
||||
|
||||
values = math_ops.to_float(values)
|
||||
if weights is not None:
|
||||
@ -2082,7 +2065,7 @@ def _streaming_sparse_true_positive_at_k(labels,
|
||||
weights=weights)
|
||||
batch_total_tp = math_ops.to_double(math_ops.reduce_sum(tp))
|
||||
|
||||
var = metric_variable(array_ops.zeros([], dtype=dtypes.float64), name=scope)
|
||||
var = metric_variable([], dtypes.float64, name=scope)
|
||||
return var, state_ops.assign_add(var, batch_total_tp, name='update')
|
||||
|
||||
|
||||
@ -2178,7 +2161,7 @@ def _streaming_sparse_false_negative_at_k(labels,
|
||||
weights=weights)
|
||||
batch_total_fn = math_ops.to_double(math_ops.reduce_sum(fn))
|
||||
|
||||
var = metric_variable(array_ops.zeros([], dtype=dtypes.float64), name=scope)
|
||||
var = metric_variable([], dtypes.float64, name=scope)
|
||||
return var, state_ops.assign_add(var, batch_total_fn, name='update')
|
||||
|
||||
|
||||
@ -2829,8 +2812,7 @@ def _streaming_sparse_average_precision_at_top_k(labels,
|
||||
# - For the unweighted case, this is just the number of rows.
|
||||
# - For the weighted case, it's the sum of the weights broadcast across
|
||||
# `average_precision` rows.
|
||||
max_var = metric_variable(
|
||||
array_ops.zeros([], dtype=dtypes.float64), name=max_scope)
|
||||
max_var = metric_variable([], dtypes.float64, name=max_scope)
|
||||
if weights is None:
|
||||
batch_max = math_ops.to_double(
|
||||
array_ops.size(average_precision, name='batch_max'))
|
||||
@ -2838,8 +2820,7 @@ def _streaming_sparse_average_precision_at_top_k(labels,
|
||||
batch_max = math_ops.reduce_sum(weights, name='batch_max')
|
||||
max_update = state_ops.assign_add(max_var, batch_max, name='update')
|
||||
with ops.name_scope(None, 'total', (average_precision,)) as total_scope:
|
||||
total_var = metric_variable(
|
||||
array_ops.zeros([], dtype=dtypes.float64), name=total_scope)
|
||||
total_var = metric_variable([], dtypes.float64, name=total_scope)
|
||||
batch_total = math_ops.reduce_sum(average_precision, name='batch_total')
|
||||
total_update = state_ops.assign_add(total_var, batch_total, name='update')
|
||||
|
||||
@ -3025,7 +3006,7 @@ def _streaming_sparse_false_positive_at_k(labels,
|
||||
weights=weights)
|
||||
batch_total_fp = math_ops.to_double(math_ops.reduce_sum(fp))
|
||||
|
||||
var = metric_variable(array_ops.zeros([], dtype=dtypes.float64), name=scope)
|
||||
var = metric_variable([], dtypes.float64, name=scope)
|
||||
return var, state_ops.assign_add(var, batch_total_fp, name='update')
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user