Fix mask function serialization for lambda layer.

This was causing an issue in model_to_estimator, where the model is serialized and reconstructed; the mask function was silently dropped in that case.
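
A minimal sketch of the round trip that previously dropped the mask (illustrative only; it uses the public tf.keras API rather than the internal test imports shown in the diff below):

    import tensorflow as tf

    # Lambda layer with an explicit mask function.
    layer = tf.keras.layers.Lambda(lambda x: x + 1,
                                   mask=lambda inputs, mask: mask)
    config = layer.get_config()

    # Rebuilding the layer from its config is what happens when the model is
    # serialized and reconstructed. Before this change get_config() wrote no
    # 'mask' entry, so the rebuilt layer's mask was silently None; after this
    # change the mask survives the round trip.
    restored = tf.keras.layers.Lambda.from_config(config)
    assert restored.mask is not None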

PiperOrigin-RevId: 246888369
Scott Zhu 2019-05-06 13:42:09 -07:00 committed by TensorFlower Gardener
parent 861d5471c6
commit 2707117ab8
2 changed files with 96 additions and 74 deletions


@@ -803,87 +803,66 @@ class Lambda(Layer):
     return self.mask
 
   def get_config(self):
-    module = self.function.__module__
-    if isinstance(self.function, python_types.LambdaType):
-      function = generic_utils.func_dump(self.function)
-      function_type = 'lambda'
-    else:
-      function = self.function.__name__
-      function_type = 'function'
-    output_shape_module = None
-    if isinstance(self._output_shape, python_types.LambdaType):
-      output_shape = generic_utils.func_dump(self._output_shape)
-      output_shape_type = 'lambda'
-      output_shape_module = self._output_shape.__module__
-    elif callable(self._output_shape):
-      output_shape = self._output_shape.__name__
-      output_shape_type = 'function'
-      output_shape_module = self._output_shape.__module__
-    else:
-      output_shape = self._output_shape
-      output_shape_type = 'raw'
+    function_config = self._serialize_function_to_config(self.function)
+    output_shape_config = self._serialize_function_to_config(self._output_shape,
+                                                             allow_raw=True)
     config = {
-        'function': function,
-        'module': module,
-        'function_type': function_type,
-        'output_shape': output_shape,
-        'output_shape_type': output_shape_type,
-        'output_shape_module': output_shape_module,
-        'arguments': self.arguments
+        'function': function_config[0],
+        'function_type': function_config[1],
+        'module': function_config[2],
+        'output_shape': output_shape_config[0],
+        'output_shape_type': output_shape_config[1],
+        'output_shape_module': output_shape_config[2],
     }
+    if self.mask is not None:
+      mask_config = self._serialize_function_to_config(self.mask)
+      config.update({
+          'mask': mask_config[0],
+          'mask_type': mask_config[1],
+          'mask_module': mask_config[2]
+      })
+    config['arguments'] = self.arguments
 
     base_config = super(Lambda, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))
 
+  def _serialize_function_to_config(self, inputs, allow_raw=False):
+    if isinstance(inputs, python_types.LambdaType):
+      output = generic_utils.func_dump(inputs)
+      output_type = 'lambda'
+      module = inputs.__module__
+    elif callable(inputs):
+      output = inputs.__name__
+      output_type = 'function'
+      module = inputs.__module__
+    elif allow_raw:
+      output = inputs
+      output_type = 'raw'
+      module = None
+    else:
+      raise ValueError(
+          'Invalid input for serialization, type: %s ' % type(inputs))
+    return output, output_type, module
+
   @classmethod
   def from_config(cls, config, custom_objects=None):
     config = config.copy()
-    globs = globals()
-    module = config.pop('module', None)
-    if module in sys.modules:
-      globs.update(sys.modules[module].__dict__)
-    elif module is not None:
-      # Note: we don't know the name of the function if it's a lambda.
-      warnings.warn('{} is not loaded, but a Lambda layer uses it. '
-                    'It may cause errors.'.format(module)
-                    , UserWarning)
-    if custom_objects:
-      globs.update(custom_objects)
-    function_type = config.pop('function_type')
-    if function_type == 'function':
-      # Simple lookup in custom objects
-      function = generic_utils.deserialize_keras_object(
-          config['function'],
-          custom_objects=custom_objects,
-          printable_module_name='function in Lambda layer')
-    elif function_type == 'lambda':
-      # Unsafe deserialization from bytecode
-      function = generic_utils.func_load(config['function'], globs=globs)
-    else:
-      raise TypeError('Unknown function type:', function_type)
+    function = cls._parse_function_from_config(
+        config, custom_objects, 'function', 'module', 'function_type')
 
-    output_shape_module = config.pop('output_shape_module', None)
-    if output_shape_module in sys.modules:
-      globs.update(sys.modules[output_shape_module].__dict__)
-    elif output_shape_module is not None:
-      # Note: we don't know the name of the function if it's a lambda.
-      warnings.warn('{} is not loaded, but a Lambda layer uses it. '
-                    'It may cause errors.'.format(output_shape_module)
-                    , UserWarning)
-    output_shape_type = config.pop('output_shape_type')
-    if output_shape_type == 'function':
-      # Simple lookup in custom objects
-      output_shape = generic_utils.deserialize_keras_object(
-          config['output_shape'],
-          custom_objects=custom_objects,
-          printable_module_name='output_shape function in Lambda layer')
-    elif output_shape_type == 'lambda':
-      # Unsafe deserialization from bytecode
-      output_shape = generic_utils.func_load(config['output_shape'],
-                                             globs=globs)
+    output_shape = cls._parse_function_from_config(
+        config, custom_objects, 'output_shape', 'output_shape_module',
+        'output_shape_type')
+
+    if 'mask' in config:
+      mask = cls._parse_function_from_config(
+          config, custom_objects, 'mask', 'mask_module', 'mask_type')
     else:
-      output_shape = config['output_shape']
+      mask = None
+
+    config['function'] = function
+    config['output_shape'] = output_shape
+    config['mask'] = mask
 
     # If arguments were numpy array, they have been saved as
     # list. We need to recover the ndarray
@@ -895,10 +874,40 @@ class Lambda(Layer):
             # Overwrite the argument with its numpy translation
             config['arguments'][key] = np.array(arg_dict['value'])
 
-    config['function'] = function
-    config['output_shape'] = output_shape
     return cls(**config)
 
+  @classmethod
+  def _parse_function_from_config(
+      cls, config, custom_objects, func_attr_name, module_attr_name,
+      func_type_attr_name):
+    globs = globals()
+    module = config.pop(module_attr_name, None)
+    if module in sys.modules:
+      globs.update(sys.modules[module].__dict__)
+    elif module is not None:
+      # Note: we don't know the name of the function if it's a lambda.
+      warnings.warn('{} is not loaded, but a Lambda layer uses it. '
+                    'It may cause errors.'.format(module)
+                    , UserWarning)
+    if custom_objects:
+      globs.update(custom_objects)
+    function_type = config.pop(func_type_attr_name)
+    if function_type == 'function':
+      # Simple lookup in custom objects
+      function = generic_utils.deserialize_keras_object(
+          config[func_attr_name],
+          custom_objects=custom_objects,
+          printable_module_name='function in Lambda layer')
+    elif function_type == 'lambda':
+      # Unsafe deserialization from bytecode
+      function = generic_utils.func_load(
+          config[func_attr_name], globs=globs)
+    elif function_type == 'raw':
+      function = config[func_attr_name]
+    else:
+      raise TypeError('Unknown function type:', function_type)
+    return function
 
 @keras_export('keras.layers.Dense')
 class Dense(Layer):


@@ -107,12 +107,14 @@ class LambdaLayerTest(keras_parameterized.TestCase):
         'class_name': 'Lambda',
         'config': config
     })
+    self.assertEqual(ld.function(3), 4)
 
     # test with lambda
     ld = keras.layers.Lambda(
         lambda x: keras.backend.concatenate([math_ops.square(x), x]))
     config = ld.get_config()
     ld = keras.layers.Lambda.from_config(config)
+    self.assertAllEqual(self.evaluate(ld.function([3])), [9, 3])
 
   def test_lambda_multiple_inputs(self):
     ld = keras.layers.Lambda(lambda x: x[0], output_shape=lambda x: x[0])
@@ -184,14 +186,25 @@ class LambdaLayerTest(keras_parameterized.TestCase):
   def test_lambda_config_serialization(self):
     # Test serialization with output_shape and output_shape_type
-    layer = keras.layers.Lambda(lambda x: x + 1, output_shape=(1, 1))
+    layer = keras.layers.Lambda(
+        lambda x: x + 1,
+        output_shape=(1, 1),
+        mask=lambda i, m: m)
     layer(keras.backend.variable(np.ones((1, 1))))
     config = layer.get_config()
     layer = keras.layers.deserialize({
         'class_name': 'Lambda',
         'config': config
     })
+    self.assertAllEqual(layer.function(1), 2)
+    self.assertAllEqual(layer._output_shape, (1, 1))
+    self.assertAllEqual(layer.mask(1, True), True)
 
     layer = keras.layers.Lambda.from_config(config)
+    self.assertAllEqual(layer.function(1), 2)
+    self.assertAllEqual(layer._output_shape, (1, 1))
+    self.assertAllEqual(layer.mask(1, True), True)
 
   def test_lambda_with_variable(self):