From 62df65c7255e2a8878cd29f66fe80ff8952de157 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Fri, 20 Oct 2017 17:46:17 -0700
Subject: [PATCH] Add dtype argument to Mean and Accuracy object-oriented metrics.

PiperOrigin-RevId: 172957714
---
 .../contrib/eager/python/metrics_impl.py | 27 +++++++++++--------
 .../contrib/eager/python/metrics_test.py | 20 ++++++++++++++
 2 files changed, 36 insertions(+), 11 deletions(-)

diff --git a/tensorflow/contrib/eager/python/metrics_impl.py b/tensorflow/contrib/eager/python/metrics_impl.py
index 2a624b218cc..2139c2b4b98 100644
--- a/tensorflow/contrib/eager/python/metrics_impl.py
+++ b/tensorflow/contrib/eager/python/metrics_impl.py
@@ -198,13 +198,19 @@ class Mean(Metric):
   # TODO(josh11b): Maybe have a dtype argument that defaults to tf.float64?
   # Or defaults to type of the input if it is tf.float32, else tf.float64?
 
-  def build(self, values, weights=None):
-    del values, weights  # build() does not use call's arguments
+  def __init__(self, name=None, dtype=dtypes.float64):
+    super(Mean, self).__init__(name=name)
+    self.dtype = dtype
+
+  def build(self, *args, **kwargs):
+    # build() does not use call's arguments, by using *args, **kwargs
+    # we make it easier to inherit from Mean().
+    del args, kwargs
     self.numer = self.add_variable(name="numer", shape=(),
-                                   dtype=dtypes.float64,
+                                   dtype=self.dtype,
                                    initializer=init_ops.zeros_initializer)
     self.denom = self.add_variable(name="denom", shape=(),
-                                   dtype=dtypes.float64,
+                                   dtype=self.dtype,
                                    initializer=init_ops.zeros_initializer)
 
   def call(self, values, weights=None):
@@ -219,13 +225,13 @@ class Mean(Metric):
     """
     if weights is None:
       self.denom.assign_add(
-          math_ops.cast(array_ops.size(values), dtypes.float64))
+          math_ops.cast(array_ops.size(values), self.dtype))
       values = math_ops.reduce_sum(values)
-      self.numer.assign_add(math_ops.cast(values, dtypes.float64))
+      self.numer.assign_add(math_ops.cast(values, self.dtype))
     else:
-      weights = math_ops.cast(weights, dtypes.float64)
+      weights = math_ops.cast(weights, self.dtype)
       self.denom.assign_add(math_ops.reduce_sum(weights))
-      values = math_ops.cast(values, dtypes.float64) * weights
+      values = math_ops.cast(values, self.dtype) * weights
       self.numer.assign_add(math_ops.reduce_sum(values))
 
   def result(self):
@@ -235,9 +241,8 @@ class Mean(Metric):
 class Accuracy(Mean):
   """Calculates how often `predictions` matches `labels`."""
 
-  def build(self, labels, predictions, weights=None):
-    del labels, predictions, weights
-    super(Accuracy, self).build(None)  # Arguments are unused
+  def __init__(self, name=None, dtype=dtypes.float64):
+    super(Accuracy, self).__init__(name=name, dtype=dtype)
 
   def call(self, labels, predictions, weights=None):
     """Accumulate accuracy statistics.
diff --git a/tensorflow/contrib/eager/python/metrics_test.py b/tensorflow/contrib/eager/python/metrics_test.py
index bfb79cd72e0..9743666c892 100644
--- a/tensorflow/contrib/eager/python/metrics_test.py
+++ b/tensorflow/contrib/eager/python/metrics_test.py
@@ -34,6 +34,8 @@ class MetricsTest(test.TestCase):
     m(1000)
     m([10000.0, 100000.0])
     self.assertEqual(111111.0/6, m.result().numpy())
+    self.assertEqual(dtypes.float64, m.dtype)
+    self.assertEqual(dtypes.float64, m.result().dtype)
 
   def testWeightedMean(self):
     m = metrics.Mean()
@@ -41,6 +43,14 @@ class MetricsTest(test.TestCase):
     m([500000, 5000, 500])  # weights of 1 each
     self.assertNear(535521/4.5, m.result().numpy(), 0.001)
 
+  def testMeanDtype(self):
+    # Can override default dtype of float64.
+    m = metrics.Mean(dtype=dtypes.float32)
+    m([0, 2])
+    self.assertEqual(1, m.result().numpy())
+    self.assertEqual(dtypes.float32, m.dtype)
+    self.assertEqual(dtypes.float32, m.result().dtype)
+
   def testAccuracy(self):
     m = metrics.Accuracy()
     m([0, 1, 2, 3], [0, 0, 0, 0])  # 1 correct
@@ -49,6 +59,8 @@ class MetricsTest(test.TestCase):
     m([6], [6])  # 1 correct
     m([7], [2])  # 0 correct
     self.assertEqual(3.0/8, m.result().numpy())
+    self.assertEqual(dtypes.float64, m.dtype)
+    self.assertEqual(dtypes.float64, m.result().dtype)
 
   def testWeightedAccuracy(self):
     m = metrics.Accuracy()
@@ -60,6 +72,14 @@ class MetricsTest(test.TestCase):
     m([7], [2])  # 0 correct, weight 1
     self.assertEqual(2.5/5, m.result().numpy())
 
+  def testAccuracyDtype(self):
+    # Can override default dtype of float64.
+    m = metrics.Accuracy(dtype=dtypes.float32)
+    m([0, 0], [0, 1])
+    self.assertEqual(0.5, m.result().numpy())
+    self.assertEqual(dtypes.float32, m.dtype)
+    self.assertEqual(dtypes.float32, m.result().dtype)
+
   def testTwoMeans(self):
     # Verify two metrics with the same class and name don't
     # accidentally share state.
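
For reference, a minimal usage sketch of the new dtype argument, mirroring the tests above. It is illustrative only: the import path, the tf.contrib.eager.enable_eager_execution() call, and the variable names are assumptions based on the TF 1.4-era contrib eager layout and are not part of this patch.

    # Illustrative sketch -- assumes the contrib eager layout of this era.
    import tensorflow as tf
    from tensorflow.contrib.eager.python import metrics

    tf.contrib.eager.enable_eager_execution()

    # Default accumulator dtype stays float64.
    m = metrics.Mean()
    m([0, 2])
    print(m.result().dtype)    # tf.float64

    # The new dtype argument overrides the accumulator dtype.
    a = metrics.Accuracy(dtype=tf.float32)
    a([0, 0], [0, 1])          # labels, predictions -> 1 of 2 correct
    print(a.result().numpy())  # 0.5, accumulated in float32

Casting the accumulators instead of the inputs keeps call() dtype-agnostic: values and weights of any numeric type are cast to self.dtype before being added to numer and denom, so result() is always computed in the metric's configured precision.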