Add test case for checking multiple metric instances with default parameters.
PiperOrigin-RevId: 246208421
This commit is contained in:
		
							parent
							
								
									96f5824e31
								
							
						
					
					
						commit
						34f34bf7ed
					
				@ -158,8 +158,8 @@ class KerasSumTest(test.TestCase):
 | 
			
		||||
    self.assertEqual(600., self.evaluate(restore_sum.result()))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@test_util.run_all_in_graph_and_eager_modes
 | 
			
		||||
class KerasMeanTest(test.TestCase):
 | 
			
		||||
@keras_parameterized.run_all_keras_modes
 | 
			
		||||
class KerasMeanTest(keras_parameterized.TestCase):
 | 
			
		||||
 | 
			
		||||
  # TODO(b/120949004): Re-enable garbage collection check
 | 
			
		||||
  # @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
 | 
			
		||||
@ -294,6 +294,43 @@ class KerasMeanTest(test.TestCase):
 | 
			
		||||
    self.assertEqual(200., self.evaluate(restore_mean.result()))
 | 
			
		||||
    self.assertEqual(3, self.evaluate(restore_mean.count))
 | 
			
		||||
 | 
			
		||||
  def test_multiple_instances(self):
 | 
			
		||||
    m = metrics.Mean()
 | 
			
		||||
    m2 = metrics.Mean()
 | 
			
		||||
 | 
			
		||||
    self.assertEqual(m.name, 'mean')
 | 
			
		||||
    self.assertEqual(m2.name, 'mean')
 | 
			
		||||
 | 
			
		||||
    self.assertEqual([v.name for v in m.variables],
 | 
			
		||||
                     testing_utils.get_expected_metric_variable_names(
 | 
			
		||||
                         ['total', 'count']))
 | 
			
		||||
    self.assertEqual([v.name for v in m2.variables],
 | 
			
		||||
                     testing_utils.get_expected_metric_variable_names(
 | 
			
		||||
                         ['total', 'count'], name_suffix='_1'))
 | 
			
		||||
 | 
			
		||||
    self.evaluate(variables.variables_initializer(m.variables))
 | 
			
		||||
    self.evaluate(variables.variables_initializer(m2.variables))
 | 
			
		||||
 | 
			
		||||
    # check initial state
 | 
			
		||||
    self.assertEqual(self.evaluate(m.total), 0)
 | 
			
		||||
    self.assertEqual(self.evaluate(m.count), 0)
 | 
			
		||||
    self.assertEqual(self.evaluate(m2.total), 0)
 | 
			
		||||
    self.assertEqual(self.evaluate(m2.count), 0)
 | 
			
		||||
 | 
			
		||||
    # check __call__()
 | 
			
		||||
    self.assertEqual(self.evaluate(m(100)), 100)
 | 
			
		||||
    self.assertEqual(self.evaluate(m.total), 100)
 | 
			
		||||
    self.assertEqual(self.evaluate(m.count), 1)
 | 
			
		||||
    self.assertEqual(self.evaluate(m2.total), 0)
 | 
			
		||||
    self.assertEqual(self.evaluate(m2.count), 0)
 | 
			
		||||
 | 
			
		||||
    self.assertEqual(self.evaluate(m2([63, 10])), 36.5)
 | 
			
		||||
    self.assertEqual(self.evaluate(m2.total), 73)
 | 
			
		||||
    self.assertEqual(self.evaluate(m2.count), 2)
 | 
			
		||||
    self.assertEqual(self.evaluate(m.result()), 100)
 | 
			
		||||
    self.assertEqual(self.evaluate(m.total), 100)
 | 
			
		||||
    self.assertEqual(self.evaluate(m.count), 1)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@test_util.run_all_in_graph_and_eager_modes
 | 
			
		||||
class KerasAccuracyTest(test.TestCase):
 | 
			
		||||
 | 
			
		||||
@ -23,6 +23,7 @@ import threading
 | 
			
		||||
import numpy as np
 | 
			
		||||
 | 
			
		||||
from tensorflow.python import keras
 | 
			
		||||
from tensorflow.python import tf2
 | 
			
		||||
from tensorflow.python.eager import context
 | 
			
		||||
from tensorflow.python.framework import tensor_shape
 | 
			
		||||
from tensorflow.python.framework import test_util
 | 
			
		||||
@ -682,3 +683,12 @@ def get_v2_optimizer(name, **kwargs):
 | 
			
		||||
    raise ValueError(
 | 
			
		||||
        'Could not find requested v2 optimizer: {}\nValid choices: {}'.format(
 | 
			
		||||
            name, list(_V2_OPTIMIZER_MAP.keys())))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_expected_metric_variable_names(var_names, name_suffix=''):
 | 
			
		||||
  """Returns expected metric variable names given names and prefix/suffix."""
 | 
			
		||||
  if tf2.enabled() or context.executing_eagerly():
 | 
			
		||||
    # In V1 eager mode and V2 variable names are not made unique.
 | 
			
		||||
    return [n + ':0' for n in var_names]
 | 
			
		||||
  # In V1 graph mode variable names are made unique using a suffix.
 | 
			
		||||
  return [n + name_suffix + ':0' for n in var_names]
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user