Add count metric, a helper function that computes the total number or total weight of examples.
PiperOrigin-RevId: 173731046
This commit is contained in:
parent
e1d7615ebc
commit
7cb7f88c5f
@ -65,6 +65,7 @@ See the @{$python/contrib.metrics} guide.
|
||||
@@set_intersection
|
||||
@@set_size
|
||||
@@set_union
|
||||
@@count
|
||||
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
@ -78,6 +79,7 @@ from tensorflow.contrib.metrics.python.ops.confusion_matrix_ops import confusion
|
||||
from tensorflow.contrib.metrics.python.ops.histogram_ops import auc_using_histogram
|
||||
from tensorflow.contrib.metrics.python.ops.metric_ops import aggregate_metric_map
|
||||
from tensorflow.contrib.metrics.python.ops.metric_ops import aggregate_metrics
|
||||
from tensorflow.contrib.metrics.python.ops.metric_ops import count
|
||||
from tensorflow.contrib.metrics.python.ops.metric_ops import sparse_recall_at_top_k
|
||||
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_accuracy
|
||||
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_auc
|
||||
|
@ -120,7 +120,7 @@ def _count_condition(values,
|
||||
or tuple.
|
||||
"""
|
||||
check_ops.assert_type(values, dtypes.bool)
|
||||
count = _create_local('count', shape=[])
|
||||
count_ = _create_local('count', shape=[])
|
||||
|
||||
values = math_ops.to_float(values)
|
||||
if weights is not None:
|
||||
@ -128,8 +128,8 @@ def _count_condition(values,
|
||||
with ops.control_dependencies((_assert_weights_rank(weights, values),)):
|
||||
values = math_ops.multiply(values, weights)
|
||||
|
||||
value_tensor = array_ops.identity(count)
|
||||
update_op = state_ops.assign_add(count, math_ops.reduce_sum(values))
|
||||
value_tensor = array_ops.identity(count_)
|
||||
update_op = state_ops.assign_add(count_, math_ops.reduce_sum(values))
|
||||
|
||||
if metrics_collections:
|
||||
ops.add_to_collections(metrics_collections, value_tensor)
|
||||
@ -2601,7 +2601,7 @@ def streaming_covariance(predictions,
|
||||
predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access
|
||||
predictions, labels, weights)
|
||||
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
|
||||
count = _create_local('count', [])
|
||||
count_ = _create_local('count', [])
|
||||
mean_prediction = _create_local('mean_prediction', [])
|
||||
mean_label = _create_local('mean_label', [])
|
||||
comoment = _create_local('comoment', []) # C_A in update equation
|
||||
@ -2616,7 +2616,7 @@ def streaming_covariance(predictions,
|
||||
weighted_predictions = math_ops.multiply(predictions, weights)
|
||||
weighted_labels = math_ops.multiply(labels, weights)
|
||||
|
||||
update_count = state_ops.assign_add(count, batch_count) # n_AB in eqn
|
||||
update_count = state_ops.assign_add(count_, batch_count) # n_AB in eqn
|
||||
prev_count = update_count - batch_count # n_A in update equation
|
||||
|
||||
# We update the means by Delta=Error*BatchCount/(BatchCount+PrevCount)
|
||||
@ -2660,15 +2660,15 @@ def streaming_covariance(predictions,
|
||||
update_comoment = state_ops.assign_add(comoment, delta_comoment)
|
||||
|
||||
covariance = array_ops.where(
|
||||
math_ops.less_equal(count, 1.),
|
||||
math_ops.less_equal(count_, 1.),
|
||||
float('nan'),
|
||||
math_ops.truediv(comoment, count - 1),
|
||||
math_ops.truediv(comoment, count_ - 1),
|
||||
name='covariance')
|
||||
with ops.control_dependencies([update_comoment]):
|
||||
update_op = array_ops.where(
|
||||
math_ops.less_equal(count, 1.),
|
||||
math_ops.less_equal(count_, 1.),
|
||||
float('nan'),
|
||||
math_ops.truediv(comoment, count - 1),
|
||||
math_ops.truediv(comoment, count_ - 1),
|
||||
name='update_op')
|
||||
|
||||
if metrics_collections:
|
||||
@ -3124,9 +3124,71 @@ def aggregate_metric_map(names_to_tuples):
|
||||
return dict(zip(metric_names, value_ops)), dict(zip(metric_names, update_ops))
|
||||
|
||||
|
||||
def count(values,
|
||||
weights=None,
|
||||
metrics_collections=None,
|
||||
updates_collections=None,
|
||||
name=None):
|
||||
"""Computes the number of examples, or sum of `weights`.
|
||||
|
||||
When evaluating some metric (e.g. mean) on one or more subsets of the data,
|
||||
this auxiliary metric is useful for keeping track of how many examples there
|
||||
are in each subset.
|
||||
|
||||
If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
|
||||
|
||||
Args:
|
||||
values: A `Tensor` of arbitrary dimensions. Only it's shape is used.
|
||||
weights: Optional `Tensor` whose rank is either 0, or the same rank as
|
||||
`labels`, and must be broadcastable to `labels` (i.e., all dimensions
|
||||
must be either `1`, or the same as the corresponding `labels`
|
||||
dimension).
|
||||
metrics_collections: An optional list of collections that the metric
|
||||
value variable should be added to.
|
||||
updates_collections: An optional list of collections that the metric update
|
||||
ops should be added to.
|
||||
name: An optional variable_scope name.
|
||||
|
||||
Returns:
|
||||
count: A `Tensor` representing the current value of the metric.
|
||||
update_op: An operation that accumulates the metric from a batch of data.
|
||||
|
||||
Raises:
|
||||
ValueError: If `weights` is not `None` and its shape doesn't match `values`,
|
||||
or if either `metrics_collections` or `updates_collections` are not a list
|
||||
or tuple.
|
||||
"""
|
||||
|
||||
with variable_scope.variable_scope(name, 'count', (values, weights)):
|
||||
count_ = _create_local('count', shape=[])
|
||||
|
||||
if weights is None:
|
||||
num_values = math_ops.to_float(array_ops.size(values))
|
||||
else:
|
||||
_, _, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access
|
||||
predictions=values,
|
||||
labels=None,
|
||||
weights=weights)
|
||||
weights = weights_broadcast_ops.broadcast_weights(
|
||||
math_ops.to_float(weights), values)
|
||||
num_values = math_ops.reduce_sum(weights)
|
||||
|
||||
with ops.control_dependencies([values]):
|
||||
update_op = state_ops.assign_add(count_, num_values)
|
||||
|
||||
if metrics_collections:
|
||||
ops.add_to_collections(metrics_collections, count_)
|
||||
|
||||
if updates_collections:
|
||||
ops.add_to_collections(updates_collections, update_op)
|
||||
|
||||
return count_, update_op
|
||||
|
||||
|
||||
__all__ = [
|
||||
'aggregate_metric_map',
|
||||
'aggregate_metrics',
|
||||
'count',
|
||||
'sparse_recall_at_top_k',
|
||||
'streaming_accuracy',
|
||||
'streaming_auc',
|
||||
|
@ -6170,5 +6170,163 @@ class AggregateMetricMapTest(test.TestCase):
|
||||
self.assertEqual(4, names_to_values['m2'].eval())
|
||||
|
||||
|
||||
class CountTest(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
ops.reset_default_graph()
|
||||
|
||||
def testVars(self):
|
||||
metrics.count(array_ops.ones([4, 3]))
|
||||
_assert_local_variables(self, ['count/count:0'])
|
||||
|
||||
def testMetricsCollection(self):
|
||||
my_collection_name = '__metrics__'
|
||||
mean, _ = metrics.count(
|
||||
array_ops.ones([4, 3]), metrics_collections=[my_collection_name])
|
||||
self.assertListEqual(ops.get_collection(my_collection_name), [mean])
|
||||
|
||||
def testUpdatesCollection(self):
|
||||
my_collection_name = '__updates__'
|
||||
_, update_op = metrics.count(
|
||||
array_ops.ones([4, 3]), updates_collections=[my_collection_name])
|
||||
self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
|
||||
|
||||
def testBasic(self):
|
||||
with self.test_session() as sess:
|
||||
values_queue = data_flow_ops.FIFOQueue(
|
||||
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
|
||||
_enqueue_vector(sess, values_queue, [0, 1])
|
||||
_enqueue_vector(sess, values_queue, [-4.2, 9.1])
|
||||
_enqueue_vector(sess, values_queue, [6.5, 0])
|
||||
_enqueue_vector(sess, values_queue, [-3.2, 4.0])
|
||||
values = values_queue.dequeue()
|
||||
|
||||
result, update_op = metrics.count(values)
|
||||
|
||||
sess.run(variables.local_variables_initializer())
|
||||
for _ in range(4):
|
||||
sess.run(update_op)
|
||||
self.assertAlmostEqual(8.0, sess.run(result), 5)
|
||||
|
||||
def testUpdateOpsReturnsCurrentValue(self):
|
||||
with self.test_session() as sess:
|
||||
values_queue = data_flow_ops.FIFOQueue(
|
||||
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
|
||||
_enqueue_vector(sess, values_queue, [0, 1])
|
||||
_enqueue_vector(sess, values_queue, [-4.2, 9.1])
|
||||
_enqueue_vector(sess, values_queue, [6.5, 0])
|
||||
_enqueue_vector(sess, values_queue, [-3.2, 4.0])
|
||||
values = values_queue.dequeue()
|
||||
|
||||
result, update_op = metrics.count(values)
|
||||
|
||||
sess.run(variables.local_variables_initializer())
|
||||
|
||||
self.assertAlmostEqual(2.0, sess.run(update_op), 5)
|
||||
self.assertAlmostEqual(4.0, sess.run(update_op), 5)
|
||||
self.assertAlmostEqual(6.0, sess.run(update_op), 5)
|
||||
self.assertAlmostEqual(8.0, sess.run(update_op), 5)
|
||||
|
||||
self.assertAlmostEqual(8.0, sess.run(result), 5)
|
||||
|
||||
def test1dWeightedValues(self):
|
||||
with self.test_session() as sess:
|
||||
# Create the queue that populates the values.
|
||||
values_queue = data_flow_ops.FIFOQueue(
|
||||
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
|
||||
_enqueue_vector(sess, values_queue, [0, 1])
|
||||
_enqueue_vector(sess, values_queue, [-4.2, 9.1])
|
||||
_enqueue_vector(sess, values_queue, [6.5, 0])
|
||||
_enqueue_vector(sess, values_queue, [-3.2, 4.0])
|
||||
values = values_queue.dequeue()
|
||||
|
||||
# Create the queue that populates the weighted labels.
|
||||
weights_queue = data_flow_ops.FIFOQueue(
|
||||
4, dtypes=dtypes_lib.float32, shapes=(1, 1))
|
||||
_enqueue_vector(sess, weights_queue, [0.5])
|
||||
_enqueue_vector(sess, weights_queue, [0])
|
||||
_enqueue_vector(sess, weights_queue, [0])
|
||||
_enqueue_vector(sess, weights_queue, [1.2])
|
||||
weights = weights_queue.dequeue()
|
||||
|
||||
result, update_op = metrics.count(values, weights)
|
||||
|
||||
variables.local_variables_initializer().run()
|
||||
for _ in range(4):
|
||||
update_op.eval()
|
||||
self.assertAlmostEqual(3.4, result.eval(), 5)
|
||||
|
||||
def test1dWeightedValues_placeholders(self):
|
||||
with self.test_session() as sess:
|
||||
# Create the queue that populates the values.
|
||||
feed_values = ((0, 1), (-4.2, 9.1), (6.5, 0), (-3.2, 4.0))
|
||||
values = array_ops.placeholder(dtype=dtypes_lib.float32)
|
||||
|
||||
# Create the queue that populates the weighted labels.
|
||||
weights_queue = data_flow_ops.FIFOQueue(
|
||||
4, dtypes=dtypes_lib.float32, shapes=(1,))
|
||||
_enqueue_vector(sess, weights_queue, 0.5, shape=(1,))
|
||||
_enqueue_vector(sess, weights_queue, 0, shape=(1,))
|
||||
_enqueue_vector(sess, weights_queue, 0, shape=(1,))
|
||||
_enqueue_vector(sess, weights_queue, 1.2, shape=(1,))
|
||||
weights = weights_queue.dequeue()
|
||||
|
||||
result, update_op = metrics.count(values, weights)
|
||||
|
||||
variables.local_variables_initializer().run()
|
||||
for i in range(4):
|
||||
update_op.eval(feed_dict={values: feed_values[i]})
|
||||
self.assertAlmostEqual(3.4, result.eval(), 5)
|
||||
|
||||
def test2dWeightedValues(self):
|
||||
with self.test_session() as sess:
|
||||
# Create the queue that populates the values.
|
||||
values_queue = data_flow_ops.FIFOQueue(
|
||||
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
|
||||
_enqueue_vector(sess, values_queue, [0, 1])
|
||||
_enqueue_vector(sess, values_queue, [-4.2, 9.1])
|
||||
_enqueue_vector(sess, values_queue, [6.5, 0])
|
||||
_enqueue_vector(sess, values_queue, [-3.2, 4.0])
|
||||
values = values_queue.dequeue()
|
||||
|
||||
# Create the queue that populates the weighted labels.
|
||||
weights_queue = data_flow_ops.FIFOQueue(
|
||||
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
|
||||
_enqueue_vector(sess, weights_queue, [1.1, 1])
|
||||
_enqueue_vector(sess, weights_queue, [1, 0])
|
||||
_enqueue_vector(sess, weights_queue, [0, 1])
|
||||
_enqueue_vector(sess, weights_queue, [0, 0])
|
||||
weights = weights_queue.dequeue()
|
||||
|
||||
result, update_op = metrics.count(values, weights)
|
||||
|
||||
variables.local_variables_initializer().run()
|
||||
for _ in range(4):
|
||||
update_op.eval()
|
||||
self.assertAlmostEqual(4.1, result.eval(), 5)
|
||||
|
||||
def test2dWeightedValues_placeholders(self):
|
||||
with self.test_session() as sess:
|
||||
# Create the queue that populates the values.
|
||||
feed_values = ((0, 1), (-4.2, 9.1), (6.5, 0), (-3.2, 4.0))
|
||||
values = array_ops.placeholder(dtype=dtypes_lib.float32)
|
||||
|
||||
# Create the queue that populates the weighted labels.
|
||||
weights_queue = data_flow_ops.FIFOQueue(
|
||||
4, dtypes=dtypes_lib.float32, shapes=(2,))
|
||||
_enqueue_vector(sess, weights_queue, [1.1, 1], shape=(2,))
|
||||
_enqueue_vector(sess, weights_queue, [1, 0], shape=(2,))
|
||||
_enqueue_vector(sess, weights_queue, [0, 1], shape=(2,))
|
||||
_enqueue_vector(sess, weights_queue, [0, 0], shape=(2,))
|
||||
weights = weights_queue.dequeue()
|
||||
|
||||
result, update_op = metrics.count(values, weights)
|
||||
|
||||
variables.local_variables_initializer().run()
|
||||
for i in range(4):
|
||||
update_op.eval(feed_dict={values: feed_values[i]})
|
||||
self.assertAlmostEqual(4.1, result.eval(), 5)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user