RELNOTES: Add an ignore_unknown argument to parse_values which suppresses ValueError for unknown hyperparameter types. Such hyperparameter are ignored.
parse_values('a=1,b=foo', {a: int}) Raises a ValueError
parse_values('a=1,b=foo', {a: int}, ignore_unknown=True) does not raise a ValueError, and returns {'a': 1}
PiperOrigin-RevId: 225117666
			
			
This commit is contained in:
		
							parent
							
								
									8ac99aa0ec
								
							
						
					
					
						commit
						c6245fa0b4
					
				@ -187,7 +187,7 @@ def _cast_to_type_if_compatible(name, param_type, value):
 | 
			
		||||
  return param_type(value)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def parse_values(values, type_map):
 | 
			
		||||
def parse_values(values, type_map, ignore_unknown=False):
 | 
			
		||||
  """Parses hyperparameter values from a string into a python map.
 | 
			
		||||
 | 
			
		||||
  `values` is a string containing comma-separated `name=value` pairs.
 | 
			
		||||
@ -233,6 +233,9 @@ def parse_values(values, type_map):
 | 
			
		||||
      type T if either V has type T, or V is a list of elements of type T.
 | 
			
		||||
      Hence, for a multidimensional parameter 'x' taking float values,
 | 
			
		||||
      'x=[0.1,0.2]' will parse successfully if type_map['x'] = float.
 | 
			
		||||
    ignore_unknown: Bool. Whether values that are missing a type in type_map
 | 
			
		||||
      should be ignored. If set to True, a ValueError will not be raised for
 | 
			
		||||
      unknown hyperparameter type.
 | 
			
		||||
 | 
			
		||||
  Returns:
 | 
			
		||||
    A python map mapping each name to either:
 | 
			
		||||
@ -260,6 +263,8 @@ def parse_values(values, type_map):
 | 
			
		||||
    m_dict = m.groupdict()
 | 
			
		||||
    name = m_dict['name']
 | 
			
		||||
    if name not in type_map:
 | 
			
		||||
      if ignore_unknown:
 | 
			
		||||
        continue
 | 
			
		||||
      raise ValueError('Unknown hyperparameter type for %s' % name)
 | 
			
		||||
    type_ = type_map[name]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -216,6 +216,14 @@ class HParamsTest(test.TestCase):
 | 
			
		||||
    self.assertTrue(isinstance(parse_dict['arr'], dict))
 | 
			
		||||
    self.assertDictEqual(parse_dict['arr'], {1: 10})
 | 
			
		||||
 | 
			
		||||
  def testParseValuesWithIndexAssigment1_IgnoreUnknown(self):
 | 
			
		||||
    """Assignment to an index position."""
 | 
			
		||||
    parse_dict = hparam.parse_values(
 | 
			
		||||
        'arr[1]=10,b=5', {'arr': int}, ignore_unknown=True)
 | 
			
		||||
    self.assertEqual(len(parse_dict), 1)
 | 
			
		||||
    self.assertTrue(isinstance(parse_dict['arr'], dict))
 | 
			
		||||
    self.assertDictEqual(parse_dict['arr'], {1: 10})
 | 
			
		||||
 | 
			
		||||
  def testParseValuesWithIndexAssigment2(self):
 | 
			
		||||
    """Assignment to multiple index positions."""
 | 
			
		||||
    parse_dict = hparam.parse_values('arr[0]=10,arr[5]=20', {'arr': int})
 | 
			
		||||
@ -223,6 +231,14 @@ class HParamsTest(test.TestCase):
 | 
			
		||||
    self.assertTrue(isinstance(parse_dict['arr'], dict))
 | 
			
		||||
    self.assertDictEqual(parse_dict['arr'], {0: 10, 5: 20})
 | 
			
		||||
 | 
			
		||||
  def testParseValuesWithIndexAssigment2_IgnoreUnknown(self):
 | 
			
		||||
    """Assignment to multiple index positions."""
 | 
			
		||||
    parse_dict = hparam.parse_values(
 | 
			
		||||
        'arr[0]=10,arr[5]=20,foo=bar', {'arr': int}, ignore_unknown=True)
 | 
			
		||||
    self.assertEqual(len(parse_dict), 1)
 | 
			
		||||
    self.assertTrue(isinstance(parse_dict['arr'], dict))
 | 
			
		||||
    self.assertDictEqual(parse_dict['arr'], {0: 10, 5: 20})
 | 
			
		||||
 | 
			
		||||
  def testParseValuesWithIndexAssigment3(self):
 | 
			
		||||
    """Assignment to index positions in multiple names."""
 | 
			
		||||
    parse_dict = hparam.parse_values('arr[0]=10,arr[1]=20,L[5]=100,L[10]=200',
 | 
			
		||||
@ -234,6 +250,17 @@ class HParamsTest(test.TestCase):
 | 
			
		||||
    self.assertTrue(isinstance(parse_dict['L'], dict))
 | 
			
		||||
    self.assertDictEqual(parse_dict['L'], {5: 100, 10: 200})
 | 
			
		||||
 | 
			
		||||
  def testParseValuesWithIndexAssigment3_IgnoreUnknown(self):
 | 
			
		||||
    """Assignment to index positions in multiple names."""
 | 
			
		||||
    parse_dict = hparam.parse_values(
 | 
			
		||||
        'arr[0]=10,C=5,arr[1]=20,B[0]=kkk,L[5]=100,L[10]=200',
 | 
			
		||||
        {'arr': int, 'L': int}, ignore_unknown=True)
 | 
			
		||||
    self.assertEqual(len(parse_dict), 2)
 | 
			
		||||
    self.assertTrue(isinstance(parse_dict['arr'], dict))
 | 
			
		||||
    self.assertDictEqual(parse_dict['arr'], {0: 10, 1: 20})
 | 
			
		||||
    self.assertTrue(isinstance(parse_dict['L'], dict))
 | 
			
		||||
    self.assertDictEqual(parse_dict['L'], {5: 100, 10: 200})
 | 
			
		||||
 | 
			
		||||
  def testParseValuesWithIndexAssigment4(self):
 | 
			
		||||
    """Assignment of index positions and scalars."""
 | 
			
		||||
    parse_dict = hparam.parse_values('x=10,arr[1]=20,y=30',
 | 
			
		||||
@ -246,6 +273,17 @@ class HParamsTest(test.TestCase):
 | 
			
		||||
    self.assertEqual(parse_dict['x'], 10)
 | 
			
		||||
    self.assertEqual(parse_dict['y'], 30)
 | 
			
		||||
 | 
			
		||||
  def testParseValuesWithIndexAssigment4_IgnoreUnknown(self):
 | 
			
		||||
    """Assignment of index positions and scalars."""
 | 
			
		||||
    parse_dict = hparam.parse_values(
 | 
			
		||||
        'x=10,foo[0]=bar,arr[1]=20,zzz=78,y=30',
 | 
			
		||||
        {'x': int, 'y': int, 'arr': int}, ignore_unknown=True)
 | 
			
		||||
    self.assertEqual(len(parse_dict), 3)
 | 
			
		||||
    self.assertTrue(isinstance(parse_dict['arr'], dict))
 | 
			
		||||
    self.assertDictEqual(parse_dict['arr'], {1: 20})
 | 
			
		||||
    self.assertEqual(parse_dict['x'], 10)
 | 
			
		||||
    self.assertEqual(parse_dict['y'], 30)
 | 
			
		||||
 | 
			
		||||
  def testParseValuesWithIndexAssigment5(self):
 | 
			
		||||
    """Different variable types."""
 | 
			
		||||
    parse_dict = hparam.parse_values('a[0]=5,b[1]=true,c[2]=abc,d[3]=3.14', {
 | 
			
		||||
@ -264,24 +302,55 @@ class HParamsTest(test.TestCase):
 | 
			
		||||
    self.assertTrue(isinstance(parse_dict['d'], dict))
 | 
			
		||||
    self.assertDictEqual(parse_dict['d'], {3: 3.14})
 | 
			
		||||
 | 
			
		||||
  def testParseValuesWithIndexAssigment5_IgnoreUnknown(self):
 | 
			
		||||
    """Different variable types."""
 | 
			
		||||
    parse_dict = hparam.parse_values(
 | 
			
		||||
        'a[0]=5,cc=4,b[1]=true,c[2]=abc,mm=2,d[3]=3.14',
 | 
			
		||||
        {'a': int, 'b': bool, 'c': str, 'd': float},
 | 
			
		||||
        ignore_unknown=True)
 | 
			
		||||
    self.assertEqual(set(parse_dict.keys()), {'a', 'b', 'c', 'd'})
 | 
			
		||||
    self.assertTrue(isinstance(parse_dict['a'], dict))
 | 
			
		||||
    self.assertDictEqual(parse_dict['a'], {0: 5})
 | 
			
		||||
    self.assertTrue(isinstance(parse_dict['b'], dict))
 | 
			
		||||
    self.assertDictEqual(parse_dict['b'], {1: True})
 | 
			
		||||
    self.assertTrue(isinstance(parse_dict['c'], dict))
 | 
			
		||||
    self.assertDictEqual(parse_dict['c'], {2: 'abc'})
 | 
			
		||||
    self.assertTrue(isinstance(parse_dict['d'], dict))
 | 
			
		||||
    self.assertDictEqual(parse_dict['d'], {3: 3.14})
 | 
			
		||||
 | 
			
		||||
  def testParseValuesWithBadIndexAssigment1(self):
 | 
			
		||||
    """Reject assignment of list to variable type."""
 | 
			
		||||
    with self.assertRaisesRegexp(ValueError,
 | 
			
		||||
                                 r'Assignment of a list to a list index.'):
 | 
			
		||||
      hparam.parse_values('arr[1]=[1,2,3]', {'arr': int})
 | 
			
		||||
 | 
			
		||||
  def testParseValuesWithBadIndexAssigment1_IgnoreUnknown(self):
 | 
			
		||||
    """Reject assignment of list to variable type."""
 | 
			
		||||
    with self.assertRaisesRegexp(ValueError,
 | 
			
		||||
                                 r'Assignment of a list to a list index.'):
 | 
			
		||||
      hparam.parse_values(
 | 
			
		||||
          'arr[1]=[1,2,3],c=8', {'arr': int}, ignore_unknown=True)
 | 
			
		||||
 | 
			
		||||
  def testParseValuesWithBadIndexAssigment2(self):
 | 
			
		||||
    """Reject if type missing."""
 | 
			
		||||
    with self.assertRaisesRegexp(ValueError,
 | 
			
		||||
                                 r'Unknown hyperparameter type for arr'):
 | 
			
		||||
      hparam.parse_values('arr[1]=5', {})
 | 
			
		||||
 | 
			
		||||
  def testParseValuesWithBadIndexAssigment2_IgnoreUnknown(self):
 | 
			
		||||
    """Ignore missing type."""
 | 
			
		||||
    hparam.parse_values('arr[1]=5', {}, ignore_unknown=True)
 | 
			
		||||
 | 
			
		||||
  def testParseValuesWithBadIndexAssigment3(self):
 | 
			
		||||
    """Reject type of the form name[index]."""
 | 
			
		||||
    with self.assertRaisesRegexp(ValueError,
 | 
			
		||||
                                 'Unknown hyperparameter type for arr'):
 | 
			
		||||
      hparam.parse_values('arr[1]=1', {'arr[1]': int})
 | 
			
		||||
 | 
			
		||||
  def testParseValuesWithBadIndexAssigment3_IgnoreUnknown(self):
 | 
			
		||||
    """Ignore type of the form name[index]."""
 | 
			
		||||
    hparam.parse_values('arr[1]=1', {'arr[1]': int}, ignore_unknown=True)
 | 
			
		||||
 | 
			
		||||
  def testWithReusedVariables(self):
 | 
			
		||||
    with self.assertRaisesRegexp(ValueError,
 | 
			
		||||
                                 'Multiple assignments to variable \'x\''):
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user