Fix a number of deserialization error messages. While print(string, obj) is valid syntax, ValueError(string, obj) is not.

PiperOrigin-RevId: 309877228
Change-Id: Id6f7f8b0207a0c432c232c3bf4c80eb6c1ae5471
This commit is contained in:
Francois Chollet 2020-05-04 21:23:05 -07:00 committed by TensorFlower Gardener
parent 7705ee85ef
commit 10fdfcf50e
8 changed files with 20 additions and 8 deletions

View File

@ -487,4 +487,4 @@ def get(identifier):
else:
raise TypeError(
'Could not interpret activation function identifier: {}'.format(
repr(identifier)))
identifier))

View File

@ -1880,8 +1880,8 @@ def get(identifier):
elif callable(identifier):
return identifier
else:
raise ValueError('Could not interpret '
'loss function identifier:', identifier)
raise ValueError(
'Could not interpret loss function identifier: {}'.format(identifier))
LABEL_DTYPES_FOR_LOSSES = {

View File

@ -234,6 +234,10 @@ class KerasLossesTest(test.TestCase, parameterized.TestCase):
with self.assertRaisesRegexp(ValueError, 'Invalid Reduction Key Bar.'):
mse_obj(y, y)
def test_deserialization_error(self):
with self.assertRaisesRegex(ValueError, 'Could not interpret loss'):
losses.get(0)
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class MeanSquaredErrorTest(test.TestCase):

View File

@ -3479,9 +3479,8 @@ def get(identifier):
elif callable(identifier):
return identifier
else:
error_msg = 'Could not interpret metric function identifier: {}'.format(
identifier)
raise ValueError(error_msg)
raise ValueError(
'Could not interpret metric function identifier: {}'.format(identifier))
def is_built_in(cls):

View File

@ -899,4 +899,5 @@ def get(identifier):
config = {'class_name': str(identifier), 'config': {}}
return deserialize(config)
else:
raise ValueError('Could not interpret optimizer identifier:', identifier)
raise ValueError(
'Could not interpret optimizer identifier: {}'.format(identifier))

View File

@ -253,6 +253,9 @@ class KerasOptimizersTest(keras_parameterized.TestCase):
batch_size=5,
verbose=0)
def test_deserialization_error(self):
with self.assertRaisesRegex(ValueError, 'Could not interpret optimizer'):
keras.optimizers.get(0)
if __name__ == '__main__':
test.main()

View File

@ -312,4 +312,5 @@ def get(identifier):
elif callable(identifier):
return identifier
else:
raise ValueError('Could not interpret regularizer identifier:', identifier)
raise ValueError(
'Could not interpret regularizer identifier: {}'.format(identifier))

View File

@ -199,6 +199,10 @@ class KerasRegularizersTest(keras_parameterized.TestCase,
# - 4 from activity regularizers on the shared_dense layer.
self.assertLen(model.losses, 9)
def test_deserialization_error(self):
with self.assertRaisesRegex(ValueError, 'Could not interpret regularizer'):
keras.regularizers.get(0)
if __name__ == '__main__':
test.main()