Add count metric, a helper function that computes the total number or total weight of examples.

PiperOrigin-RevId: 173731046
A. Unique TensorFlower 2017-10-27 16:11:56 -07:00, committed by TensorFlower Gardener
parent e1d7615ebc
commit 7cb7f88c5f
3 changed files with 231 additions and 9 deletions
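
For orientation, a minimal usage sketch of the new metric in TF 1.x graph mode (the placeholder shapes and feed values here are illustrative, not part of the change):

import tensorflow as tf

values = tf.placeholder(tf.float32, shape=(None,))
weights = tf.placeholder(tf.float32, shape=(None,))
# count() returns the running total and an op that folds in one batch.
total, update_op = tf.contrib.metrics.count(values, weights=weights)

with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())  # the accumulator is a local variable
  sess.run(update_op, {values: [1., 2.], weights: [1., 0.]})    # adds sum(weights) = 1.0
  sess.run(update_op, {values: [3., 4.], weights: [0.5, 0.5]})  # adds 1.0
  print(sess.run(total))  # 2.0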

View File

@@ -65,6 +65,7 @@ See the @{$python/contrib.metrics} guide.
@@set_intersection
@@set_size
@@set_union
+@@count
"""

from __future__ import absolute_import
@@ -78,6 +79,7 @@ from tensorflow.contrib.metrics.python.ops.confusion_matrix_ops import confusion
from tensorflow.contrib.metrics.python.ops.histogram_ops import auc_using_histogram
from tensorflow.contrib.metrics.python.ops.metric_ops import aggregate_metric_map
from tensorflow.contrib.metrics.python.ops.metric_ops import aggregate_metrics
+from tensorflow.contrib.metrics.python.ops.metric_ops import count
from tensorflow.contrib.metrics.python.ops.metric_ops import sparse_recall_at_top_k
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_accuracy
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_auc

View File

@@ -120,7 +120,7 @@ def _count_condition(values,
      or tuple.
  """
  check_ops.assert_type(values, dtypes.bool)
-  count = _create_local('count', shape=[])
+  count_ = _create_local('count', shape=[])
  values = math_ops.to_float(values)

  if weights is not None:
@@ -128,8 +128,8 @@
    with ops.control_dependencies((_assert_weights_rank(weights, values),)):
      values = math_ops.multiply(values, weights)

-  value_tensor = array_ops.identity(count)
-  update_op = state_ops.assign_add(count, math_ops.reduce_sum(values))
+  value_tensor = array_ops.identity(count_)
+  update_op = state_ops.assign_add(count_, math_ops.reduce_sum(values))

  if metrics_collections:
    ops.add_to_collections(metrics_collections, value_tensor)
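
(The `count` → `count_` renames in this file appear to be bookkeeping for the new module-level `count` function added below, so the local variables no longer shadow its name.)
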
@@ -2601,7 +2601,7 @@ def streaming_covariance(predictions,
    predictions, labels, weights = metrics_impl._remove_squeezable_dimensions(  # pylint: disable=protected-access
        predictions, labels, weights)
    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
-    count = _create_local('count', [])
+    count_ = _create_local('count', [])
    mean_prediction = _create_local('mean_prediction', [])
    mean_label = _create_local('mean_label', [])
    comoment = _create_local('comoment', [])  # C_A in update equation
@@ -2616,7 +2616,7 @@
      weighted_predictions = math_ops.multiply(predictions, weights)
      weighted_labels = math_ops.multiply(labels, weights)

-    update_count = state_ops.assign_add(count, batch_count)  # n_AB in eqn
+    update_count = state_ops.assign_add(count_, batch_count)  # n_AB in eqn
    prev_count = update_count - batch_count  # n_A in update equation

    # We update the means by Delta=Error*BatchCount/(BatchCount+PrevCount)
@@ -2660,15 +2660,15 @@
      update_comoment = state_ops.assign_add(comoment, delta_comoment)

    covariance = array_ops.where(
-        math_ops.less_equal(count, 1.),
+        math_ops.less_equal(count_, 1.),
        float('nan'),
-        math_ops.truediv(comoment, count - 1),
+        math_ops.truediv(comoment, count_ - 1),
        name='covariance')
    with ops.control_dependencies([update_comoment]):
      update_op = array_ops.where(
-          math_ops.less_equal(count, 1.),
+          math_ops.less_equal(count_, 1.),
          float('nan'),
-          math_ops.truediv(comoment, count - 1),
+          math_ops.truediv(comoment, count_ - 1),
          name='update_op')

    if metrics_collections:
@@ -3124,9 +3124,71 @@ def aggregate_metric_map(names_to_tuples):
  return dict(zip(metric_names, value_ops)), dict(zip(metric_names, update_ops))


def count(values,
          weights=None,
          metrics_collections=None,
          updates_collections=None,
          name=None):
  """Computes the number of examples, or sum of `weights`.

  When evaluating some metric (e.g. mean) on one or more subsets of the data,
  this auxiliary metric is useful for keeping track of how many examples there
  are in each subset.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    values: A `Tensor` of arbitrary dimensions. Only its shape is used.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `values`, and must be broadcastable to `values` (i.e., all dimensions
      must be either `1`, or the same as the corresponding `values`
      dimension).
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    count: A `Tensor` representing the current value of the metric.
    update_op: An operation that accumulates the metric from a batch of data.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
      or if either `metrics_collections` or `updates_collections` is not a list
      or tuple.
  """
  with variable_scope.variable_scope(name, 'count', (values, weights)):
    count_ = _create_local('count', shape=[])

    if weights is None:
      num_values = math_ops.to_float(array_ops.size(values))
    else:
      _, _, weights = metrics_impl._remove_squeezable_dimensions(  # pylint: disable=protected-access
          predictions=values,
          labels=None,
          weights=weights)
      weights = weights_broadcast_ops.broadcast_weights(
          math_ops.to_float(weights), values)
      num_values = math_ops.reduce_sum(weights)

    with ops.control_dependencies([values]):
      update_op = state_ops.assign_add(count_, num_values)

    if metrics_collections:
      ops.add_to_collections(metrics_collections, count_)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return count_, update_op
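
In plain terms: when `weights` is `None` each update adds the number of elements in `values`; otherwise it adds the sum of `weights` after broadcasting to the shape of `values`. A hypothetical NumPy rendering of a single update step (the helper name is made up for illustration):

import numpy as np

def count_update(values, weights=None):
  # What one run of update_op adds to the accumulator (illustrative only).
  values = np.asarray(values, dtype=np.float32)
  if weights is None:
    return float(values.size)  # unweighted: number of elements
  weights = np.broadcast_to(np.asarray(weights, dtype=np.float32), values.shape)
  return float(weights.sum())  # weighted: total weight
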
__all__ = [
    'aggregate_metric_map',
    'aggregate_metrics',
+    'count',
    'sparse_recall_at_top_k',
    'streaming_accuracy',
    'streaming_auc',

View File

@@ -6170,5 +6170,163 @@ class AggregateMetricMapTest(test.TestCase):
    self.assertEqual(4, names_to_values['m2'].eval())

class CountTest(test.TestCase):

  def setUp(self):
    ops.reset_default_graph()

  def testVars(self):
    metrics.count(array_ops.ones([4, 3]))
    _assert_local_variables(self, ['count/count:0'])

  def testMetricsCollection(self):
    my_collection_name = '__metrics__'
    mean, _ = metrics.count(
        array_ops.ones([4, 3]), metrics_collections=[my_collection_name])
    self.assertListEqual(ops.get_collection(my_collection_name), [mean])

  def testUpdatesCollection(self):
    my_collection_name = '__updates__'
    _, update_op = metrics.count(
        array_ops.ones([4, 3]), updates_collections=[my_collection_name])
    self.assertListEqual(ops.get_collection(my_collection_name), [update_op])

  def testBasic(self):
    with self.test_session() as sess:
      values_queue = data_flow_ops.FIFOQueue(
          4, dtypes=dtypes_lib.float32, shapes=(1, 2))
      _enqueue_vector(sess, values_queue, [0, 1])
      _enqueue_vector(sess, values_queue, [-4.2, 9.1])
      _enqueue_vector(sess, values_queue, [6.5, 0])
      _enqueue_vector(sess, values_queue, [-3.2, 4.0])
      values = values_queue.dequeue()

      result, update_op = metrics.count(values)

      sess.run(variables.local_variables_initializer())
      for _ in range(4):
        sess.run(update_op)
      self.assertAlmostEqual(8.0, sess.run(result), 5)
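
Each dequeued batch has shape (1, 2), so every unweighted update adds 2 to the running count; four updates give 4 × 2 = 8.0.
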
  def testUpdateOpsReturnsCurrentValue(self):
    with self.test_session() as sess:
      values_queue = data_flow_ops.FIFOQueue(
          4, dtypes=dtypes_lib.float32, shapes=(1, 2))
      _enqueue_vector(sess, values_queue, [0, 1])
      _enqueue_vector(sess, values_queue, [-4.2, 9.1])
      _enqueue_vector(sess, values_queue, [6.5, 0])
      _enqueue_vector(sess, values_queue, [-3.2, 4.0])
      values = values_queue.dequeue()

      result, update_op = metrics.count(values)

      sess.run(variables.local_variables_initializer())

      self.assertAlmostEqual(2.0, sess.run(update_op), 5)
      self.assertAlmostEqual(4.0, sess.run(update_op), 5)
      self.assertAlmostEqual(6.0, sess.run(update_op), 5)
      self.assertAlmostEqual(8.0, sess.run(update_op), 5)

      self.assertAlmostEqual(8.0, sess.run(result), 5)
  def test1dWeightedValues(self):
    with self.test_session() as sess:
      # Create the queue that populates the values.
      values_queue = data_flow_ops.FIFOQueue(
          4, dtypes=dtypes_lib.float32, shapes=(1, 2))
      _enqueue_vector(sess, values_queue, [0, 1])
      _enqueue_vector(sess, values_queue, [-4.2, 9.1])
      _enqueue_vector(sess, values_queue, [6.5, 0])
      _enqueue_vector(sess, values_queue, [-3.2, 4.0])
      values = values_queue.dequeue()

      # Create the queue that populates the weights.
      weights_queue = data_flow_ops.FIFOQueue(
          4, dtypes=dtypes_lib.float32, shapes=(1, 1))
      _enqueue_vector(sess, weights_queue, [0.5])
      _enqueue_vector(sess, weights_queue, [0])
      _enqueue_vector(sess, weights_queue, [0])
      _enqueue_vector(sess, weights_queue, [1.2])
      weights = weights_queue.dequeue()

      result, update_op = metrics.count(values, weights)

      variables.local_variables_initializer().run()
      for _ in range(4):
        update_op.eval()
      self.assertAlmostEqual(3.4, result.eval(), 5)
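
Each (1, 1) weight broadcasts across the two values in its row, so the expected total is 2 × (0.5 + 0 + 0 + 1.2) = 3.4.
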
  def test1dWeightedValues_placeholders(self):
    with self.test_session() as sess:
      # Feed the values through a placeholder.
      feed_values = ((0, 1), (-4.2, 9.1), (6.5, 0), (-3.2, 4.0))
      values = array_ops.placeholder(dtype=dtypes_lib.float32)

      # Create the queue that populates the weights.
      weights_queue = data_flow_ops.FIFOQueue(
          4, dtypes=dtypes_lib.float32, shapes=(1,))
      _enqueue_vector(sess, weights_queue, 0.5, shape=(1,))
      _enqueue_vector(sess, weights_queue, 0, shape=(1,))
      _enqueue_vector(sess, weights_queue, 0, shape=(1,))
      _enqueue_vector(sess, weights_queue, 1.2, shape=(1,))
      weights = weights_queue.dequeue()

      result, update_op = metrics.count(values, weights)

      variables.local_variables_initializer().run()
      for i in range(4):
        update_op.eval(feed_dict={values: feed_values[i]})
      self.assertAlmostEqual(3.4, result.eval(), 5)
  def test2dWeightedValues(self):
    with self.test_session() as sess:
      # Create the queue that populates the values.
      values_queue = data_flow_ops.FIFOQueue(
          4, dtypes=dtypes_lib.float32, shapes=(1, 2))
      _enqueue_vector(sess, values_queue, [0, 1])
      _enqueue_vector(sess, values_queue, [-4.2, 9.1])
      _enqueue_vector(sess, values_queue, [6.5, 0])
      _enqueue_vector(sess, values_queue, [-3.2, 4.0])
      values = values_queue.dequeue()

      # Create the queue that populates the weights.
      weights_queue = data_flow_ops.FIFOQueue(
          4, dtypes=dtypes_lib.float32, shapes=(1, 2))
      _enqueue_vector(sess, weights_queue, [1.1, 1])
      _enqueue_vector(sess, weights_queue, [1, 0])
      _enqueue_vector(sess, weights_queue, [0, 1])
      _enqueue_vector(sess, weights_queue, [0, 0])
      weights = weights_queue.dequeue()

      result, update_op = metrics.count(values, weights)

      variables.local_variables_initializer().run()
      for _ in range(4):
        update_op.eval()
      self.assertAlmostEqual(4.1, result.eval(), 5)
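
With per-element (1, 2) weights there is no broadcasting; the expected total is simply their sum: 1.1 + 1 + 1 + 0 + 0 + 1 + 0 + 0 = 4.1.
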
  def test2dWeightedValues_placeholders(self):
    with self.test_session() as sess:
      # Feed the values through a placeholder.
      feed_values = ((0, 1), (-4.2, 9.1), (6.5, 0), (-3.2, 4.0))
      values = array_ops.placeholder(dtype=dtypes_lib.float32)

      # Create the queue that populates the weights.
      weights_queue = data_flow_ops.FIFOQueue(
          4, dtypes=dtypes_lib.float32, shapes=(2,))
      _enqueue_vector(sess, weights_queue, [1.1, 1], shape=(2,))
      _enqueue_vector(sess, weights_queue, [1, 0], shape=(2,))
      _enqueue_vector(sess, weights_queue, [0, 1], shape=(2,))
      _enqueue_vector(sess, weights_queue, [0, 0], shape=(2,))
      weights = weights_queue.dequeue()

      result, update_op = metrics.count(values, weights)

      variables.local_variables_initializer().run()
      for i in range(4):
        update_op.eval(feed_dict={values: feed_values[i]})
      self.assertAlmostEqual(4.1, result.eval(), 5)

if __name__ == '__main__':
  test.main()
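
Finally, the subset-bookkeeping use case from the docstring in sketch form: weights of 0 mask values, so a 0/1 indicator counts only the examples in a subset (names here are hypothetical; assumes the same TF 1.x API as above):

# Count only the positively-labeled examples seen so far.
labels = tf.placeholder(tf.int64, shape=(None,))
is_positive = tf.to_float(tf.equal(labels, 1))
num_positives, update_num_positives = tf.contrib.metrics.count(
    labels, weights=is_positive)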