Fix mask function serialization for lambda layer.
This was causing issue in model_to_estimator where model was serialize and reconstructed, the mask function was silently dropped in that case. PiperOrigin-RevId: 246888369
This commit is contained in:
parent
861d5471c6
commit
2707117ab8
@ -803,87 +803,66 @@ class Lambda(Layer):
|
||||
return self.mask
|
||||
|
||||
def get_config(self):
|
||||
module = self.function.__module__
|
||||
if isinstance(self.function, python_types.LambdaType):
|
||||
function = generic_utils.func_dump(self.function)
|
||||
function_type = 'lambda'
|
||||
else:
|
||||
function = self.function.__name__
|
||||
function_type = 'function'
|
||||
|
||||
output_shape_module = None
|
||||
if isinstance(self._output_shape, python_types.LambdaType):
|
||||
output_shape = generic_utils.func_dump(self._output_shape)
|
||||
output_shape_type = 'lambda'
|
||||
output_shape_module = self._output_shape.__module__
|
||||
elif callable(self._output_shape):
|
||||
output_shape = self._output_shape.__name__
|
||||
output_shape_type = 'function'
|
||||
output_shape_module = self._output_shape.__module__
|
||||
else:
|
||||
output_shape = self._output_shape
|
||||
output_shape_type = 'raw'
|
||||
|
||||
function_config = self._serialize_function_to_config(self.function)
|
||||
output_shape_config = self._serialize_function_to_config(self._output_shape,
|
||||
allow_raw=True)
|
||||
config = {
|
||||
'function': function,
|
||||
'module': module,
|
||||
'function_type': function_type,
|
||||
'output_shape': output_shape,
|
||||
'output_shape_type': output_shape_type,
|
||||
'output_shape_module': output_shape_module,
|
||||
'arguments': self.arguments
|
||||
'function': function_config[0],
|
||||
'function_type': function_config[1],
|
||||
'module': function_config[2],
|
||||
'output_shape': output_shape_config[0],
|
||||
'output_shape_type': output_shape_config[1],
|
||||
'output_shape_module': output_shape_config[2],
|
||||
}
|
||||
if self.mask is not None:
|
||||
mask_config = self._serialize_function_to_config(self.mask)
|
||||
config.update({
|
||||
'mask': mask_config[0],
|
||||
'mask_type': mask_config[1],
|
||||
'mask_module': mask_config[2]
|
||||
})
|
||||
config['arguments'] = self.arguments
|
||||
|
||||
base_config = super(Lambda, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
def _serialize_function_to_config(self, inputs, allow_raw=False):
|
||||
if isinstance(inputs, python_types.LambdaType):
|
||||
output = generic_utils.func_dump(inputs)
|
||||
output_type = 'lambda'
|
||||
module = inputs.__module__
|
||||
elif callable(inputs):
|
||||
output = inputs.__name__
|
||||
output_type = 'function'
|
||||
module = inputs.__module__
|
||||
elif allow_raw:
|
||||
output = inputs
|
||||
output_type = 'raw'
|
||||
module = None
|
||||
else:
|
||||
raise ValueError(
|
||||
'Invalid input for serialization, type: %s ' % type(inputs))
|
||||
|
||||
return output, output_type, module
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config, custom_objects=None):
|
||||
config = config.copy()
|
||||
globs = globals()
|
||||
module = config.pop('module', None)
|
||||
if module in sys.modules:
|
||||
globs.update(sys.modules[module].__dict__)
|
||||
elif module is not None:
|
||||
# Note: we don't know the name of the function if it's a lambda.
|
||||
warnings.warn('{} is not loaded, but a Lambda layer uses it. '
|
||||
'It may cause errors.'.format(module)
|
||||
, UserWarning)
|
||||
if custom_objects:
|
||||
globs.update(custom_objects)
|
||||
function_type = config.pop('function_type')
|
||||
if function_type == 'function':
|
||||
# Simple lookup in custom objects
|
||||
function = generic_utils.deserialize_keras_object(
|
||||
config['function'],
|
||||
custom_objects=custom_objects,
|
||||
printable_module_name='function in Lambda layer')
|
||||
elif function_type == 'lambda':
|
||||
# Unsafe deserialization from bytecode
|
||||
function = generic_utils.func_load(config['function'], globs=globs)
|
||||
else:
|
||||
raise TypeError('Unknown function type:', function_type)
|
||||
function = cls._parse_function_from_config(
|
||||
config, custom_objects, 'function', 'module', 'function_type')
|
||||
|
||||
output_shape_module = config.pop('output_shape_module', None)
|
||||
if output_shape_module in sys.modules:
|
||||
globs.update(sys.modules[output_shape_module].__dict__)
|
||||
elif output_shape_module is not None:
|
||||
# Note: we don't know the name of the function if it's a lambda.
|
||||
warnings.warn('{} is not loaded, but a Lambda layer uses it. '
|
||||
'It may cause errors.'.format(output_shape_module)
|
||||
, UserWarning)
|
||||
output_shape_type = config.pop('output_shape_type')
|
||||
if output_shape_type == 'function':
|
||||
# Simple lookup in custom objects
|
||||
output_shape = generic_utils.deserialize_keras_object(
|
||||
config['output_shape'],
|
||||
custom_objects=custom_objects,
|
||||
printable_module_name='output_shape function in Lambda layer')
|
||||
elif output_shape_type == 'lambda':
|
||||
# Unsafe deserialization from bytecode
|
||||
output_shape = generic_utils.func_load(config['output_shape'],
|
||||
globs=globs)
|
||||
output_shape = cls._parse_function_from_config(
|
||||
config, custom_objects, 'output_shape', 'output_shape_module',
|
||||
'output_shape_type')
|
||||
if 'mask' in config:
|
||||
mask = cls._parse_function_from_config(
|
||||
config, custom_objects, 'mask', 'mask_module', 'mask_type')
|
||||
else:
|
||||
output_shape = config['output_shape']
|
||||
mask = None
|
||||
|
||||
config['function'] = function
|
||||
config['output_shape'] = output_shape
|
||||
config['mask'] = mask
|
||||
|
||||
# If arguments were numpy array, they have been saved as
|
||||
# list. We need to recover the ndarray
|
||||
@ -895,10 +874,40 @@ class Lambda(Layer):
|
||||
# Overwrite the argument with its numpy translation
|
||||
config['arguments'][key] = np.array(arg_dict['value'])
|
||||
|
||||
config['function'] = function
|
||||
config['output_shape'] = output_shape
|
||||
return cls(**config)
|
||||
|
||||
@classmethod
|
||||
def _parse_function_from_config(
|
||||
cls, config, custom_objects, func_attr_name, module_attr_name,
|
||||
func_type_attr_name):
|
||||
globs = globals()
|
||||
module = config.pop(module_attr_name, None)
|
||||
if module in sys.modules:
|
||||
globs.update(sys.modules[module].__dict__)
|
||||
elif module is not None:
|
||||
# Note: we don't know the name of the function if it's a lambda.
|
||||
warnings.warn('{} is not loaded, but a Lambda layer uses it. '
|
||||
'It may cause errors.'.format(module)
|
||||
, UserWarning)
|
||||
if custom_objects:
|
||||
globs.update(custom_objects)
|
||||
function_type = config.pop(func_type_attr_name)
|
||||
if function_type == 'function':
|
||||
# Simple lookup in custom objects
|
||||
function = generic_utils.deserialize_keras_object(
|
||||
config[func_attr_name],
|
||||
custom_objects=custom_objects,
|
||||
printable_module_name='function in Lambda layer')
|
||||
elif function_type == 'lambda':
|
||||
# Unsafe deserialization from bytecode
|
||||
function = generic_utils.func_load(
|
||||
config[func_attr_name], globs=globs)
|
||||
elif function_type == 'raw':
|
||||
function = config[func_attr_name]
|
||||
else:
|
||||
raise TypeError('Unknown function type:', function_type)
|
||||
return function
|
||||
|
||||
|
||||
@keras_export('keras.layers.Dense')
|
||||
class Dense(Layer):
|
||||
|
@ -107,12 +107,14 @@ class LambdaLayerTest(keras_parameterized.TestCase):
|
||||
'class_name': 'Lambda',
|
||||
'config': config
|
||||
})
|
||||
self.assertEqual(ld.function(3), 4)
|
||||
|
||||
# test with lambda
|
||||
ld = keras.layers.Lambda(
|
||||
lambda x: keras.backend.concatenate([math_ops.square(x), x]))
|
||||
config = ld.get_config()
|
||||
ld = keras.layers.Lambda.from_config(config)
|
||||
self.assertAllEqual(self.evaluate(ld.function([3])), [9, 3])
|
||||
|
||||
def test_lambda_multiple_inputs(self):
|
||||
ld = keras.layers.Lambda(lambda x: x[0], output_shape=lambda x: x[0])
|
||||
@ -184,14 +186,25 @@ class LambdaLayerTest(keras_parameterized.TestCase):
|
||||
|
||||
def test_lambda_config_serialization(self):
|
||||
# Test serialization with output_shape and output_shape_type
|
||||
layer = keras.layers.Lambda(lambda x: x + 1, output_shape=(1, 1))
|
||||
layer = keras.layers.Lambda(
|
||||
lambda x: x + 1,
|
||||
output_shape=(1, 1),
|
||||
mask=lambda i, m: m)
|
||||
layer(keras.backend.variable(np.ones((1, 1))))
|
||||
config = layer.get_config()
|
||||
|
||||
layer = keras.layers.deserialize({
|
||||
'class_name': 'Lambda',
|
||||
'config': config
|
||||
})
|
||||
self.assertAllEqual(layer.function(1), 2)
|
||||
self.assertAllEqual(layer._output_shape, (1, 1))
|
||||
self.assertAllEqual(layer.mask(1, True), True)
|
||||
|
||||
layer = keras.layers.Lambda.from_config(config)
|
||||
self.assertAllEqual(layer.function(1), 2)
|
||||
self.assertAllEqual(layer._output_shape, (1, 1))
|
||||
self.assertAllEqual(layer.mask(1, True), True)
|
||||
|
||||
def test_lambda_with_variable(self):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user