Fix mask function serialization for lambda layer.
This was causing issue in model_to_estimator where model was serialize and reconstructed, the mask function was silently dropped in that case. PiperOrigin-RevId: 246888369
This commit is contained in:
parent
861d5471c6
commit
2707117ab8
@ -803,87 +803,66 @@ class Lambda(Layer):
|
|||||||
return self.mask
|
return self.mask
|
||||||
|
|
||||||
def get_config(self):
|
def get_config(self):
|
||||||
module = self.function.__module__
|
function_config = self._serialize_function_to_config(self.function)
|
||||||
if isinstance(self.function, python_types.LambdaType):
|
output_shape_config = self._serialize_function_to_config(self._output_shape,
|
||||||
function = generic_utils.func_dump(self.function)
|
allow_raw=True)
|
||||||
function_type = 'lambda'
|
|
||||||
else:
|
|
||||||
function = self.function.__name__
|
|
||||||
function_type = 'function'
|
|
||||||
|
|
||||||
output_shape_module = None
|
|
||||||
if isinstance(self._output_shape, python_types.LambdaType):
|
|
||||||
output_shape = generic_utils.func_dump(self._output_shape)
|
|
||||||
output_shape_type = 'lambda'
|
|
||||||
output_shape_module = self._output_shape.__module__
|
|
||||||
elif callable(self._output_shape):
|
|
||||||
output_shape = self._output_shape.__name__
|
|
||||||
output_shape_type = 'function'
|
|
||||||
output_shape_module = self._output_shape.__module__
|
|
||||||
else:
|
|
||||||
output_shape = self._output_shape
|
|
||||||
output_shape_type = 'raw'
|
|
||||||
|
|
||||||
config = {
|
config = {
|
||||||
'function': function,
|
'function': function_config[0],
|
||||||
'module': module,
|
'function_type': function_config[1],
|
||||||
'function_type': function_type,
|
'module': function_config[2],
|
||||||
'output_shape': output_shape,
|
'output_shape': output_shape_config[0],
|
||||||
'output_shape_type': output_shape_type,
|
'output_shape_type': output_shape_config[1],
|
||||||
'output_shape_module': output_shape_module,
|
'output_shape_module': output_shape_config[2],
|
||||||
'arguments': self.arguments
|
|
||||||
}
|
}
|
||||||
|
if self.mask is not None:
|
||||||
|
mask_config = self._serialize_function_to_config(self.mask)
|
||||||
|
config.update({
|
||||||
|
'mask': mask_config[0],
|
||||||
|
'mask_type': mask_config[1],
|
||||||
|
'mask_module': mask_config[2]
|
||||||
|
})
|
||||||
|
config['arguments'] = self.arguments
|
||||||
|
|
||||||
base_config = super(Lambda, self).get_config()
|
base_config = super(Lambda, self).get_config()
|
||||||
return dict(list(base_config.items()) + list(config.items()))
|
return dict(list(base_config.items()) + list(config.items()))
|
||||||
|
|
||||||
|
def _serialize_function_to_config(self, inputs, allow_raw=False):
|
||||||
|
if isinstance(inputs, python_types.LambdaType):
|
||||||
|
output = generic_utils.func_dump(inputs)
|
||||||
|
output_type = 'lambda'
|
||||||
|
module = inputs.__module__
|
||||||
|
elif callable(inputs):
|
||||||
|
output = inputs.__name__
|
||||||
|
output_type = 'function'
|
||||||
|
module = inputs.__module__
|
||||||
|
elif allow_raw:
|
||||||
|
output = inputs
|
||||||
|
output_type = 'raw'
|
||||||
|
module = None
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
'Invalid input for serialization, type: %s ' % type(inputs))
|
||||||
|
|
||||||
|
return output, output_type, module
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(cls, config, custom_objects=None):
|
def from_config(cls, config, custom_objects=None):
|
||||||
config = config.copy()
|
config = config.copy()
|
||||||
globs = globals()
|
function = cls._parse_function_from_config(
|
||||||
module = config.pop('module', None)
|
config, custom_objects, 'function', 'module', 'function_type')
|
||||||
if module in sys.modules:
|
|
||||||
globs.update(sys.modules[module].__dict__)
|
|
||||||
elif module is not None:
|
|
||||||
# Note: we don't know the name of the function if it's a lambda.
|
|
||||||
warnings.warn('{} is not loaded, but a Lambda layer uses it. '
|
|
||||||
'It may cause errors.'.format(module)
|
|
||||||
, UserWarning)
|
|
||||||
if custom_objects:
|
|
||||||
globs.update(custom_objects)
|
|
||||||
function_type = config.pop('function_type')
|
|
||||||
if function_type == 'function':
|
|
||||||
# Simple lookup in custom objects
|
|
||||||
function = generic_utils.deserialize_keras_object(
|
|
||||||
config['function'],
|
|
||||||
custom_objects=custom_objects,
|
|
||||||
printable_module_name='function in Lambda layer')
|
|
||||||
elif function_type == 'lambda':
|
|
||||||
# Unsafe deserialization from bytecode
|
|
||||||
function = generic_utils.func_load(config['function'], globs=globs)
|
|
||||||
else:
|
|
||||||
raise TypeError('Unknown function type:', function_type)
|
|
||||||
|
|
||||||
output_shape_module = config.pop('output_shape_module', None)
|
output_shape = cls._parse_function_from_config(
|
||||||
if output_shape_module in sys.modules:
|
config, custom_objects, 'output_shape', 'output_shape_module',
|
||||||
globs.update(sys.modules[output_shape_module].__dict__)
|
'output_shape_type')
|
||||||
elif output_shape_module is not None:
|
if 'mask' in config:
|
||||||
# Note: we don't know the name of the function if it's a lambda.
|
mask = cls._parse_function_from_config(
|
||||||
warnings.warn('{} is not loaded, but a Lambda layer uses it. '
|
config, custom_objects, 'mask', 'mask_module', 'mask_type')
|
||||||
'It may cause errors.'.format(output_shape_module)
|
|
||||||
, UserWarning)
|
|
||||||
output_shape_type = config.pop('output_shape_type')
|
|
||||||
if output_shape_type == 'function':
|
|
||||||
# Simple lookup in custom objects
|
|
||||||
output_shape = generic_utils.deserialize_keras_object(
|
|
||||||
config['output_shape'],
|
|
||||||
custom_objects=custom_objects,
|
|
||||||
printable_module_name='output_shape function in Lambda layer')
|
|
||||||
elif output_shape_type == 'lambda':
|
|
||||||
# Unsafe deserialization from bytecode
|
|
||||||
output_shape = generic_utils.func_load(config['output_shape'],
|
|
||||||
globs=globs)
|
|
||||||
else:
|
else:
|
||||||
output_shape = config['output_shape']
|
mask = None
|
||||||
|
|
||||||
|
config['function'] = function
|
||||||
|
config['output_shape'] = output_shape
|
||||||
|
config['mask'] = mask
|
||||||
|
|
||||||
# If arguments were numpy array, they have been saved as
|
# If arguments were numpy array, they have been saved as
|
||||||
# list. We need to recover the ndarray
|
# list. We need to recover the ndarray
|
||||||
@ -895,10 +874,40 @@ class Lambda(Layer):
|
|||||||
# Overwrite the argument with its numpy translation
|
# Overwrite the argument with its numpy translation
|
||||||
config['arguments'][key] = np.array(arg_dict['value'])
|
config['arguments'][key] = np.array(arg_dict['value'])
|
||||||
|
|
||||||
config['function'] = function
|
|
||||||
config['output_shape'] = output_shape
|
|
||||||
return cls(**config)
|
return cls(**config)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _parse_function_from_config(
|
||||||
|
cls, config, custom_objects, func_attr_name, module_attr_name,
|
||||||
|
func_type_attr_name):
|
||||||
|
globs = globals()
|
||||||
|
module = config.pop(module_attr_name, None)
|
||||||
|
if module in sys.modules:
|
||||||
|
globs.update(sys.modules[module].__dict__)
|
||||||
|
elif module is not None:
|
||||||
|
# Note: we don't know the name of the function if it's a lambda.
|
||||||
|
warnings.warn('{} is not loaded, but a Lambda layer uses it. '
|
||||||
|
'It may cause errors.'.format(module)
|
||||||
|
, UserWarning)
|
||||||
|
if custom_objects:
|
||||||
|
globs.update(custom_objects)
|
||||||
|
function_type = config.pop(func_type_attr_name)
|
||||||
|
if function_type == 'function':
|
||||||
|
# Simple lookup in custom objects
|
||||||
|
function = generic_utils.deserialize_keras_object(
|
||||||
|
config[func_attr_name],
|
||||||
|
custom_objects=custom_objects,
|
||||||
|
printable_module_name='function in Lambda layer')
|
||||||
|
elif function_type == 'lambda':
|
||||||
|
# Unsafe deserialization from bytecode
|
||||||
|
function = generic_utils.func_load(
|
||||||
|
config[func_attr_name], globs=globs)
|
||||||
|
elif function_type == 'raw':
|
||||||
|
function = config[func_attr_name]
|
||||||
|
else:
|
||||||
|
raise TypeError('Unknown function type:', function_type)
|
||||||
|
return function
|
||||||
|
|
||||||
|
|
||||||
@keras_export('keras.layers.Dense')
|
@keras_export('keras.layers.Dense')
|
||||||
class Dense(Layer):
|
class Dense(Layer):
|
||||||
|
@ -107,12 +107,14 @@ class LambdaLayerTest(keras_parameterized.TestCase):
|
|||||||
'class_name': 'Lambda',
|
'class_name': 'Lambda',
|
||||||
'config': config
|
'config': config
|
||||||
})
|
})
|
||||||
|
self.assertEqual(ld.function(3), 4)
|
||||||
|
|
||||||
# test with lambda
|
# test with lambda
|
||||||
ld = keras.layers.Lambda(
|
ld = keras.layers.Lambda(
|
||||||
lambda x: keras.backend.concatenate([math_ops.square(x), x]))
|
lambda x: keras.backend.concatenate([math_ops.square(x), x]))
|
||||||
config = ld.get_config()
|
config = ld.get_config()
|
||||||
ld = keras.layers.Lambda.from_config(config)
|
ld = keras.layers.Lambda.from_config(config)
|
||||||
|
self.assertAllEqual(self.evaluate(ld.function([3])), [9, 3])
|
||||||
|
|
||||||
def test_lambda_multiple_inputs(self):
|
def test_lambda_multiple_inputs(self):
|
||||||
ld = keras.layers.Lambda(lambda x: x[0], output_shape=lambda x: x[0])
|
ld = keras.layers.Lambda(lambda x: x[0], output_shape=lambda x: x[0])
|
||||||
@ -184,14 +186,25 @@ class LambdaLayerTest(keras_parameterized.TestCase):
|
|||||||
|
|
||||||
def test_lambda_config_serialization(self):
|
def test_lambda_config_serialization(self):
|
||||||
# Test serialization with output_shape and output_shape_type
|
# Test serialization with output_shape and output_shape_type
|
||||||
layer = keras.layers.Lambda(lambda x: x + 1, output_shape=(1, 1))
|
layer = keras.layers.Lambda(
|
||||||
|
lambda x: x + 1,
|
||||||
|
output_shape=(1, 1),
|
||||||
|
mask=lambda i, m: m)
|
||||||
layer(keras.backend.variable(np.ones((1, 1))))
|
layer(keras.backend.variable(np.ones((1, 1))))
|
||||||
config = layer.get_config()
|
config = layer.get_config()
|
||||||
|
|
||||||
layer = keras.layers.deserialize({
|
layer = keras.layers.deserialize({
|
||||||
'class_name': 'Lambda',
|
'class_name': 'Lambda',
|
||||||
'config': config
|
'config': config
|
||||||
})
|
})
|
||||||
|
self.assertAllEqual(layer.function(1), 2)
|
||||||
|
self.assertAllEqual(layer._output_shape, (1, 1))
|
||||||
|
self.assertAllEqual(layer.mask(1, True), True)
|
||||||
|
|
||||||
layer = keras.layers.Lambda.from_config(config)
|
layer = keras.layers.Lambda.from_config(config)
|
||||||
|
self.assertAllEqual(layer.function(1), 2)
|
||||||
|
self.assertAllEqual(layer._output_shape, (1, 1))
|
||||||
|
self.assertAllEqual(layer.mask(1, True), True)
|
||||||
|
|
||||||
def test_lambda_with_variable(self):
|
def test_lambda_with_variable(self):
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user