Support tf.Modules in keras.Sequential.
Adds a private ModuleWrapper class that wraps the tf.Module that the user passes to keras.Sequential. keras.Sequential models created in this way support SavedModel saving, but do not support HDF5 saving format. PiperOrigin-RevId: 342340478 Change-Id: I0f9d44b3d73fa00bed4f3c474750269b9ec10fb0
This commit is contained in:
parent
f649375784
commit
3fd34b6ac3
@ -1397,3 +1397,48 @@ def shape_with_no_batch_size(x):
|
||||
if shape:
|
||||
shape[0] = None
|
||||
return shape
|
||||
|
||||
|
||||
class ModuleWrapper(base_layer.Layer):
|
||||
"""Wrapper for `tf.Module`s to support the Functional and Sequential API."""
|
||||
|
||||
def __init__(self, module, method_name=None, **kwargs):
|
||||
"""Initializes the wrapper Layer for this module.
|
||||
|
||||
Arguments:
|
||||
module: The `tf.Module` instance to be wrapped.
|
||||
method_name: (Optional) str. The name of the method to use as the forward
|
||||
pass of the module. If not set, defaults to '__call__' if defined, or
|
||||
'call'.
|
||||
**kwargs: Additional keywrod arguments. See `tf.keras.layers.Layer`.
|
||||
|
||||
Raises:
|
||||
ValueError: If `method` is not defined on `module`.
|
||||
"""
|
||||
super(ModuleWrapper, self).__init__(**kwargs)
|
||||
if method_name is None:
|
||||
if hasattr(module, '__call__'):
|
||||
method_name = '__call__'
|
||||
elif hasattr(module, 'call'):
|
||||
method_name = 'call'
|
||||
if method_name is None or not hasattr(module, method_name):
|
||||
raise ValueError('{} is not defined on object {}'.format(
|
||||
method_name, module))
|
||||
|
||||
self._module = module
|
||||
self._method_name = method_name
|
||||
|
||||
# Check if module.__call__ has a `training` arg or accepts `**kwargs`.
|
||||
method = getattr(module, method_name)
|
||||
method_arg_spec = tf_inspect.getfullargspec(method)
|
||||
self._expects_training_arg = ('training' in method_arg_spec.args or
|
||||
method_arg_spec.varkw is not None)
|
||||
self._expects_mask_arg = ('mask' in method_arg_spec.args or
|
||||
method_arg_spec.varkw is not None)
|
||||
|
||||
def call(self, *args, **kwargs):
|
||||
if 'training' in kwargs and not self._expects_training_arg:
|
||||
kwargs.pop('training')
|
||||
if 'mask' in kwargs and not self._expects_mask_arg:
|
||||
kwargs.pop('mask')
|
||||
return getattr(self._module, self._method_name)(*args, **kwargs)
|
||||
|
@ -35,6 +35,7 @@ from tensorflow.python.keras.utils import generic_utils
|
||||
from tensorflow.python.keras.utils import layer_utils
|
||||
from tensorflow.python.keras.utils import tf_inspect
|
||||
from tensorflow.python.keras.utils import tf_utils
|
||||
from tensorflow.python.module import module
|
||||
from tensorflow.python.ops.numpy_ops import np_arrays
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training.tracking import base as trackable
|
||||
@ -178,7 +179,10 @@ class Sequential(functional.Functional):
|
||||
if isinstance(origin_layer, input_layer.InputLayer):
|
||||
layer = origin_layer
|
||||
|
||||
if not isinstance(layer, base_layer.Layer):
|
||||
if isinstance(layer, module.Module):
|
||||
if not isinstance(layer, base_layer.Layer):
|
||||
layer = functional.ModuleWrapper(layer)
|
||||
else:
|
||||
raise TypeError('The added layer must be '
|
||||
'an instance of class Layer. '
|
||||
'Found: ' + str(layer))
|
||||
|
@ -30,7 +30,9 @@ from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.keras import keras_parameterized
|
||||
from tensorflow.python.keras import testing_utils
|
||||
from tensorflow.python.module import module
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@ -457,6 +459,56 @@ class TestSequential(keras_parameterized.TestCase):
|
||||
with self.assertRaisesRegex(ValueError, 'should have unique names'):
|
||||
model.add(keras.layers.Dense(3, name='specific_name'))
|
||||
|
||||
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
|
||||
def test_tf_module_call(self):
|
||||
|
||||
class MyModule(module.Module):
|
||||
|
||||
def __init__(self):
|
||||
self.v = variables.Variable(2.)
|
||||
|
||||
def __call__(self, x):
|
||||
return self.v * x
|
||||
|
||||
model = keras.Sequential()
|
||||
model.add(MyModule())
|
||||
model.compile('sgd', 'mse')
|
||||
x, y = np.ones((10, 1)), np.ones((10, 1))
|
||||
model.fit(x, y, batch_size=2)
|
||||
self.assertLen(model.trainable_variables, 1)
|
||||
|
||||
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
|
||||
def test_tf_module_training(self):
|
||||
|
||||
class MyModule(module.Module):
|
||||
|
||||
def __init__(self):
|
||||
self.v = variables.Variable(2.)
|
||||
|
||||
def call(self, x, training=None):
|
||||
# training should be set by Sequential.
|
||||
assert training is not None
|
||||
return self.v * x
|
||||
|
||||
model = keras.Sequential()
|
||||
model.add(MyModule())
|
||||
model.compile('sgd', 'mse')
|
||||
x, y = np.ones((10, 1)), np.ones((10, 1))
|
||||
model.fit(x, y, batch_size=2)
|
||||
self.assertLen(model.trainable_variables, 1)
|
||||
|
||||
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
|
||||
def test_tf_module_error(self):
|
||||
|
||||
class MyModule(module.Module):
|
||||
|
||||
def __init__(self):
|
||||
self.v = variables.Variable(2.)
|
||||
|
||||
model = keras.Sequential()
|
||||
with self.assertRaisesRegex(ValueError, 'is not defined'):
|
||||
model.add(MyModule())
|
||||
|
||||
|
||||
class TestSequentialEagerIntegration(keras_parameterized.TestCase):
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user