Set metric variable initializers as lambda.

PiperOrigin-RevId: 174100686
Author:    Mustafa Ispir (2017-10-31 15:20:08 -07:00), committed by TensorFlower Gardener
Parent:    453dd5848f
Commit:    b242a7988c
2 changed files with 32 additions and 62 deletions
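
In short: `metric_variable` now takes a shape and dtype instead of a pre-built initial-value tensor, and wraps the zeros initializer in a lambda, so the initial-value op is constructed by the variable itself rather than eagerly at every call site. A before/after sketch of the call-site pattern, using only names that appear in the diffs below:

    # Before: each call site built the zeros tensor and passed it as initial_value.
    count_ = metrics_impl.metric_variable(
        array_ops.zeros([], dtype=dtypes.float32), name='count')

    # After: call sites pass shape and dtype; metric_variable builds the zeros
    # inside a lambda initializer.
    count_ = metrics_impl.metric_variable([], dtypes.float32, name='count')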

File 1 of 2

@@ -92,8 +92,7 @@ def _count_condition(values,
       or tuple.
   """
   check_ops.assert_type(values, dtypes.bool)
-  count_ = metrics_impl.metric_variable(
-      array_ops.zeros([], dtype=dtypes.float32), name='count')
+  count_ = metrics_impl.metric_variable([], dtypes.float32, name='count')
   values = math_ops.to_float(values)
   if weights is not None:
@@ -916,8 +915,7 @@ def _streaming_confusion_matrix_at_thresholds(predictions,
   if 'tp' in includes:
     true_positives = metrics_impl.metric_variable(
-        array_ops.zeros([num_thresholds], dtype=dtypes.float32),
-        name='true_positives')
+        [num_thresholds], dtypes.float32, name='true_positives')
     is_true_positive = math_ops.to_float(
         math_ops.logical_and(label_is_pos, pred_is_pos))
     if weights_tiled is not None:
@@ -929,8 +927,7 @@ def _streaming_confusion_matrix_at_thresholds(predictions,
   if 'fn' in includes:
     false_negatives = metrics_impl.metric_variable(
-        array_ops.zeros([num_thresholds], dtype=dtypes.float32),
-        name='false_negatives')
+        [num_thresholds], dtypes.float32, name='false_negatives')
     is_false_negative = math_ops.to_float(
         math_ops.logical_and(label_is_pos, pred_is_neg))
     if weights_tiled is not None:
@@ -942,8 +939,7 @@ def _streaming_confusion_matrix_at_thresholds(predictions,
   if 'tn' in includes:
     true_negatives = metrics_impl.metric_variable(
-        array_ops.zeros([num_thresholds], dtype=dtypes.float32),
-        name='true_negatives')
+        [num_thresholds], dtypes.float32, name='true_negatives')
     is_true_negative = math_ops.to_float(
         math_ops.logical_and(label_is_neg, pred_is_neg))
     if weights_tiled is not None:
@@ -955,8 +951,7 @@ def _streaming_confusion_matrix_at_thresholds(predictions,
   if 'fp' in includes:
     false_positives = metrics_impl.metric_variable(
-        array_ops.zeros([num_thresholds], dtype=dtypes.float32),
-        name='false_positives')
+        [num_thresholds], dtypes.float32, name='false_positives')
     is_false_positive = math_ops.to_float(
         math_ops.logical_and(label_is_neg, pred_is_pos))
     if weights_tiled is not None:
@@ -1317,9 +1312,9 @@ def streaming_precision_recall_at_equal_thresholds(predictions,
     with ops.name_scope('variables'):
       tp_buckets_v = metrics_impl.metric_variable(
-          array_ops.zeros([num_thresholds], dtype=dtype), name='tp_buckets')
+          [num_thresholds], dtype, name='tp_buckets')
       fp_buckets_v = metrics_impl.metric_variable(
-          array_ops.zeros([num_thresholds], dtype=dtype), name='fp_buckets')
+          [num_thresholds], dtype, name='fp_buckets')
     with ops.name_scope('update_op'):
       update_tp = state_ops.scatter_add(
@@ -2582,15 +2577,13 @@ def streaming_covariance(predictions,
     predictions, labels, weights = metrics_impl._remove_squeezable_dimensions(  # pylint: disable=protected-access
         predictions, labels, weights)
     predictions.get_shape().assert_is_compatible_with(labels.get_shape())
-    count_ = metrics_impl.metric_variable(
-        array_ops.zeros([], dtype=dtypes.float32), name='count')
+    count_ = metrics_impl.metric_variable([], dtypes.float32, name='count')
     mean_prediction = metrics_impl.metric_variable(
-        array_ops.zeros([], dtype=dtypes.float32), name='mean_prediction')
+        [], dtypes.float32, name='mean_prediction')
     mean_label = metrics_impl.metric_variable(
-        array_ops.zeros([], dtype=dtypes.float32), name='mean_label')
+        [], dtypes.float32, name='mean_label')
     comoment = metrics_impl.metric_variable(  # C_A in update equation
-        array_ops.zeros([], dtype=dtypes.float32),
-        name='comoment')
+        [], dtypes.float32, name='comoment')
     if weights is None:
       batch_count = math_ops.to_float(array_ops.size(labels))  # n_B in eqn
@@ -3011,11 +3004,8 @@ def streaming_concat(values,
     init_size = 0 if max_size is None else max_size
     init_shape = [init_size] + fixed_shape
     array = metrics_impl.metric_variable(
-        array_ops.zeros(init_shape, dtype=values.dtype),
-        validate_shape=False,
-        name='array')
-    size = metrics_impl.metric_variable(
-        array_ops.zeros([], dtype=dtypes.int32), name='size')
+        init_shape, values.dtype, validate_shape=False, name='array')
+    size = metrics_impl.metric_variable([], dtypes.int32, name='size')
     perm = [0 if n == axis else n + 1 if n < axis else n for n in range(ndim)]
     valid_array = array[:size]
@@ -3149,8 +3139,7 @@ def count(values,
   """
   with variable_scope.variable_scope(name, 'count', (values, weights)):
-    count_ = metrics_impl.metric_variable(
-        array_ops.zeros([], dtype=dtypes.float32), name='count')
+    count_ = metrics_impl.metric_variable([], dtypes.float32, name='count')
     if weights is None:
       num_values = math_ops.to_float(array_ops.size(values))

File 2 of 2

@@ -35,18 +35,11 @@ from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import weights_broadcast_ops
-def metric_variable(initial_value, validate_shape=True, name=None):
-  """Create variable in `GraphKeys.(LOCAL|METRIC_VARIABLES`) collections.
-  Args:
-    initial_value: See variables.Variable.__init__.
-    validate_shape: See variables.Variable.__init__.
-    name: See variables.Variable.__init__.
-  Returns:
-    New variable.
-  """
+def metric_variable(shape, dtype, validate_shape=True, name=None):
+  """Create variable in `GraphKeys.(LOCAL|METRIC_VARIABLES`) collections."""
   return variable_scope.variable(
-      initial_value,
+      lambda: array_ops.zeros(shape, dtype),
       trainable=False,
       collections=[
           ops.GraphKeys.LOCAL_VARIABLES, ops.GraphKeys.METRIC_VARIABLES
@@ -244,8 +237,7 @@ def _streaming_confusion_matrix(labels, predictions, num_classes, weights=None):
   """
   # Local variable to accumulate the predictions in the confusion matrix.
   total_cm = metric_variable(
-      array_ops.zeros([num_classes, num_classes], dtype=dtypes.float64),
-      name='total_confusion_matrix')
+      [num_classes, num_classes], dtypes.float64, name='total_confusion_matrix')
   # Cast the type to int64 required by confusion_matrix_ops.
   predictions = math_ops.to_int64(predictions)
@@ -315,10 +307,8 @@ def mean(values, weights=None, metrics_collections=None,
   with variable_scope.variable_scope(name, 'mean', (values, weights)):
     values = math_ops.to_float(values)
-    total = metric_variable(
-        array_ops.zeros([], dtype=dtypes.float32), name='total')
-    count = metric_variable(
-        array_ops.zeros([], dtype=dtypes.float32), name='count')
+    total = metric_variable([], dtypes.float32, name='total')
+    count = metric_variable([], dtypes.float32, name='count')
     if weights is None:
       num_values = math_ops.to_float(array_ops.size(values))
@@ -516,8 +506,7 @@ def _confusion_matrix_at_thresholds(
   if 'tp' in includes:
     true_p = metric_variable(
-        array_ops.zeros([num_thresholds], dtype=dtypes.float32),
-        name='true_positives')
+        [num_thresholds], dtypes.float32, name='true_positives')
     is_true_positive = math_ops.to_float(
         math_ops.logical_and(label_is_pos, pred_is_pos))
     if weights_tiled is not None:
@@ -528,8 +517,7 @@ def _confusion_matrix_at_thresholds(
   if 'fn' in includes:
     false_n = metric_variable(
-        array_ops.zeros([num_thresholds], dtype=dtypes.float32),
-        name='false_negatives')
+        [num_thresholds], dtypes.float32, name='false_negatives')
     is_false_negative = math_ops.to_float(
         math_ops.logical_and(label_is_pos, pred_is_neg))
     if weights_tiled is not None:
@@ -540,8 +528,7 @@ def _confusion_matrix_at_thresholds(
   if 'tn' in includes:
     true_n = metric_variable(
-        array_ops.zeros([num_thresholds], dtype=dtypes.float32),
-        name='true_negatives')
+        [num_thresholds], dtypes.float32, name='true_negatives')
     is_true_negative = math_ops.to_float(
         math_ops.logical_and(label_is_neg, pred_is_neg))
     if weights_tiled is not None:
@@ -552,8 +539,7 @@ def _confusion_matrix_at_thresholds(
   if 'fp' in includes:
     false_p = metric_variable(
-        array_ops.zeros([num_thresholds], dtype=dtypes.float32),
-        name='false_positives')
+        [num_thresholds], dtypes.float32, name='false_positives')
     is_false_positive = math_ops.to_float(
         math_ops.logical_and(label_is_neg, pred_is_pos))
     if weights_tiled is not None:
@@ -1183,11 +1169,9 @@ def mean_tensor(values, weights=None, metrics_collections=None,
   with variable_scope.variable_scope(name, 'mean', (values, weights)):
     values = math_ops.to_float(values)
     total = metric_variable(
-        array_ops.zeros(values.get_shape(), dtype=dtypes.float32),
-        name='total_tensor')
+        values.get_shape(), dtypes.float32, name='total_tensor')
     count = metric_variable(
-        array_ops.zeros(values.get_shape(), dtype=dtypes.float32),
-        name='count_tensor')
+        values.get_shape(), dtypes.float32, name='count_tensor')
     num_values = array_ops.ones_like(values)
     if weights is not None:
@@ -1300,8 +1284,7 @@ def _count_condition(values, weights=None, metrics_collections=None,
       or tuple.
   """
   check_ops.assert_type(values, dtypes.bool)
-  count = metric_variable(
-      array_ops.zeros([], dtype=dtypes.float32), name='count')
+  count = metric_variable([], dtypes.float32, name='count')
   values = math_ops.to_float(values)
   if weights is not None:
@@ -2082,7 +2065,7 @@ def _streaming_sparse_true_positive_at_k(labels,
         weights=weights)
     batch_total_tp = math_ops.to_double(math_ops.reduce_sum(tp))
-    var = metric_variable(array_ops.zeros([], dtype=dtypes.float64), name=scope)
+    var = metric_variable([], dtypes.float64, name=scope)
     return var, state_ops.assign_add(var, batch_total_tp, name='update')
@@ -2178,7 +2161,7 @@ def _streaming_sparse_false_negative_at_k(labels,
         weights=weights)
     batch_total_fn = math_ops.to_double(math_ops.reduce_sum(fn))
-    var = metric_variable(array_ops.zeros([], dtype=dtypes.float64), name=scope)
+    var = metric_variable([], dtypes.float64, name=scope)
     return var, state_ops.assign_add(var, batch_total_fn, name='update')
@@ -2829,8 +2812,7 @@ def _streaming_sparse_average_precision_at_top_k(labels,
     #   - For the unweighted case, this is just the number of rows.
     #   - For the weighted case, it's the sum of the weights broadcast across
     #     `average_precision` rows.
-    max_var = metric_variable(
-        array_ops.zeros([], dtype=dtypes.float64), name=max_scope)
+    max_var = metric_variable([], dtypes.float64, name=max_scope)
     if weights is None:
       batch_max = math_ops.to_double(
           array_ops.size(average_precision, name='batch_max'))
@@ -2838,8 +2820,7 @@ def _streaming_sparse_average_precision_at_top_k(labels,
       batch_max = math_ops.reduce_sum(weights, name='batch_max')
     max_update = state_ops.assign_add(max_var, batch_max, name='update')
   with ops.name_scope(None, 'total', (average_precision,)) as total_scope:
-    total_var = metric_variable(
-        array_ops.zeros([], dtype=dtypes.float64), name=total_scope)
+    total_var = metric_variable([], dtypes.float64, name=total_scope)
     batch_total = math_ops.reduce_sum(average_precision, name='batch_total')
     total_update = state_ops.assign_add(total_var, batch_total, name='update')
@@ -3025,7 +3006,7 @@ def _streaming_sparse_false_positive_at_k(labels,
         weights=weights)
     batch_total_fp = math_ops.to_double(math_ops.reduce_sum(fp))
-    var = metric_variable(array_ops.zeros([], dtype=dtypes.float64), name=scope)
+    var = metric_variable([], dtypes.float64, name=scope)
    return var, state_ops.assign_add(var, batch_total_fp, name='update')
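
For context outside the TensorFlow source tree, a minimal, hypothetical sketch of the same lambda-initializer pattern using only the public TF 1.x API of this era. It assumes `tf.GraphKeys.METRIC_VARIABLES` is exported (the internal code references `ops.GraphKeys.METRIC_VARIABLES`) and substitutes `tf.Variable` for the internal `variable_scope.variable`; it is illustrative, not the library implementation:

    import tensorflow as tf  # TF 1.x graph-mode API, contemporary with this commit


    def metric_variable(shape, dtype, validate_shape=True, name=None):
      """Sketch: non-trainable local/metric variable with a lambda initializer."""
      return tf.Variable(
          lambda: tf.zeros(shape, dtype),  # zeros op is created by the Variable,
                                           # inside its own initializer scope
          dtype=dtype,  # documented as required when initial_value is a callable
          trainable=False,
          collections=[tf.GraphKeys.LOCAL_VARIABLES,
                       tf.GraphKeys.METRIC_VARIABLES],
          validate_shape=validate_shape,
          name=name)


    # Usage: a scalar streaming count, bumped once per batch.
    count = metric_variable([], tf.float32, name='count')
    update_count = tf.assign_add(count, 1.0)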