Add dtype argument to Mean and Accuracy object-oriented metrics.
PiperOrigin-RevId: 172957714
This commit is contained in:
parent
29c7b46585
commit
62df65c725
@ -198,13 +198,19 @@ class Mean(Metric):
|
||||
# TODO(josh11b): Maybe have a dtype argument that defaults to tf.float64?
|
||||
# Or defaults to type of the input if it is tf.float32, else tf.float64?
|
||||
|
||||
def build(self, values, weights=None):
|
||||
del values, weights # build() does not use call's arguments
|
||||
def __init__(self, name=None, dtype=dtypes.float64):
|
||||
super(Mean, self).__init__(name=name)
|
||||
self.dtype = dtype
|
||||
|
||||
def build(self, *args, **kwargs):
|
||||
# build() does not use call's arguments, by using *args, **kwargs
|
||||
# we make it easier to inherit from Mean().
|
||||
del args, kwargs
|
||||
self.numer = self.add_variable(name="numer", shape=(),
|
||||
dtype=dtypes.float64,
|
||||
dtype=self.dtype,
|
||||
initializer=init_ops.zeros_initializer)
|
||||
self.denom = self.add_variable(name="denom", shape=(),
|
||||
dtype=dtypes.float64,
|
||||
dtype=self.dtype,
|
||||
initializer=init_ops.zeros_initializer)
|
||||
|
||||
def call(self, values, weights=None):
|
||||
@ -219,13 +225,13 @@ class Mean(Metric):
|
||||
"""
|
||||
if weights is None:
|
||||
self.denom.assign_add(
|
||||
math_ops.cast(array_ops.size(values), dtypes.float64))
|
||||
math_ops.cast(array_ops.size(values), self.dtype))
|
||||
values = math_ops.reduce_sum(values)
|
||||
self.numer.assign_add(math_ops.cast(values, dtypes.float64))
|
||||
self.numer.assign_add(math_ops.cast(values, self.dtype))
|
||||
else:
|
||||
weights = math_ops.cast(weights, dtypes.float64)
|
||||
weights = math_ops.cast(weights, self.dtype)
|
||||
self.denom.assign_add(math_ops.reduce_sum(weights))
|
||||
values = math_ops.cast(values, dtypes.float64) * weights
|
||||
values = math_ops.cast(values, self.dtype) * weights
|
||||
self.numer.assign_add(math_ops.reduce_sum(values))
|
||||
|
||||
def result(self):
|
||||
@ -235,9 +241,8 @@ class Mean(Metric):
|
||||
class Accuracy(Mean):
|
||||
"""Calculates how often `predictions` matches `labels`."""
|
||||
|
||||
def build(self, labels, predictions, weights=None):
|
||||
del labels, predictions, weights
|
||||
super(Accuracy, self).build(None) # Arguments are unused
|
||||
def __init__(self, name=None, dtype=dtypes.float64):
|
||||
super(Accuracy, self).__init__(name=name, dtype=dtype)
|
||||
|
||||
def call(self, labels, predictions, weights=None):
|
||||
"""Accumulate accuracy statistics.
|
||||
|
@ -34,6 +34,8 @@ class MetricsTest(test.TestCase):
|
||||
m(1000)
|
||||
m([10000.0, 100000.0])
|
||||
self.assertEqual(111111.0/6, m.result().numpy())
|
||||
self.assertEqual(dtypes.float64, m.dtype)
|
||||
self.assertEqual(dtypes.float64, m.result().dtype)
|
||||
|
||||
def testWeightedMean(self):
|
||||
m = metrics.Mean()
|
||||
@ -41,6 +43,14 @@ class MetricsTest(test.TestCase):
|
||||
m([500000, 5000, 500]) # weights of 1 each
|
||||
self.assertNear(535521/4.5, m.result().numpy(), 0.001)
|
||||
|
||||
def testMeanDtype(self):
|
||||
# Can override default dtype of float64.
|
||||
m = metrics.Mean(dtype=dtypes.float32)
|
||||
m([0, 2])
|
||||
self.assertEqual(1, m.result().numpy())
|
||||
self.assertEqual(dtypes.float32, m.dtype)
|
||||
self.assertEqual(dtypes.float32, m.result().dtype)
|
||||
|
||||
def testAccuracy(self):
|
||||
m = metrics.Accuracy()
|
||||
m([0, 1, 2, 3], [0, 0, 0, 0]) # 1 correct
|
||||
@ -49,6 +59,8 @@ class MetricsTest(test.TestCase):
|
||||
m([6], [6]) # 1 correct
|
||||
m([7], [2]) # 0 correct
|
||||
self.assertEqual(3.0/8, m.result().numpy())
|
||||
self.assertEqual(dtypes.float64, m.dtype)
|
||||
self.assertEqual(dtypes.float64, m.result().dtype)
|
||||
|
||||
def testWeightedAccuracy(self):
|
||||
m = metrics.Accuracy()
|
||||
@ -60,6 +72,14 @@ class MetricsTest(test.TestCase):
|
||||
m([7], [2]) # 0 correct, weight 1
|
||||
self.assertEqual(2.5/5, m.result().numpy())
|
||||
|
||||
def testAccuracyDtype(self):
|
||||
# Can override default dtype of float64.
|
||||
m = metrics.Accuracy(dtype=dtypes.float32)
|
||||
m([0, 0], [0, 1])
|
||||
self.assertEqual(0.5, m.result().numpy())
|
||||
self.assertEqual(dtypes.float32, m.dtype)
|
||||
self.assertEqual(dtypes.float32, m.result().dtype)
|
||||
|
||||
def testTwoMeans(self):
|
||||
# Verify two metrics with the same class and name don't
|
||||
# accidentally share state.
|
||||
|
Loading…
Reference in New Issue
Block a user