Support tf.Modules in keras.Sequential.

Adds a private ModuleWrapper class that wraps the tf.Module that the user
passes to keras.Sequential. keras.Sequential models created in this way support
SavedModel saving, but do not support HDF5 saving format.

PiperOrigin-RevId: 342340478
Change-Id: I0f9d44b3d73fa00bed4f3c474750269b9ec10fb0
This commit is contained in:
Thomas O'Malley 2020-11-13 14:36:54 -08:00 committed by TensorFlower Gardener
parent f649375784
commit 3fd34b6ac3
3 changed files with 102 additions and 1 deletions

View File

@ -1397,3 +1397,48 @@ def shape_with_no_batch_size(x):
if shape:
shape[0] = None
return shape
class ModuleWrapper(base_layer.Layer):
"""Wrapper for `tf.Module`s to support the Functional and Sequential API."""
def __init__(self, module, method_name=None, **kwargs):
"""Initializes the wrapper Layer for this module.
Arguments:
module: The `tf.Module` instance to be wrapped.
method_name: (Optional) str. The name of the method to use as the forward
pass of the module. If not set, defaults to '__call__' if defined, or
'call'.
**kwargs: Additional keywrod arguments. See `tf.keras.layers.Layer`.
Raises:
ValueError: If `method` is not defined on `module`.
"""
super(ModuleWrapper, self).__init__(**kwargs)
if method_name is None:
if hasattr(module, '__call__'):
method_name = '__call__'
elif hasattr(module, 'call'):
method_name = 'call'
if method_name is None or not hasattr(module, method_name):
raise ValueError('{} is not defined on object {}'.format(
method_name, module))
self._module = module
self._method_name = method_name
# Check if module.__call__ has a `training` arg or accepts `**kwargs`.
method = getattr(module, method_name)
method_arg_spec = tf_inspect.getfullargspec(method)
self._expects_training_arg = ('training' in method_arg_spec.args or
method_arg_spec.varkw is not None)
self._expects_mask_arg = ('mask' in method_arg_spec.args or
method_arg_spec.varkw is not None)
def call(self, *args, **kwargs):
if 'training' in kwargs and not self._expects_training_arg:
kwargs.pop('training')
if 'mask' in kwargs and not self._expects_mask_arg:
kwargs.pop('mask')
return getattr(self._module, self._method_name)(*args, **kwargs)

View File

@ -35,6 +35,7 @@ from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.keras.utils import tf_inspect
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.module import module
from tensorflow.python.ops.numpy_ops import np_arrays
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.tracking import base as trackable
@ -178,7 +179,10 @@ class Sequential(functional.Functional):
if isinstance(origin_layer, input_layer.InputLayer):
layer = origin_layer
if not isinstance(layer, base_layer.Layer):
if isinstance(layer, module.Module):
if not isinstance(layer, base_layer.Layer):
layer = functional.ModuleWrapper(layer)
else:
raise TypeError('The added layer must be '
'an instance of class Layer. '
'Found: ' + str(layer))

View File

@ -30,7 +30,9 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import testing_utils
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@ -457,6 +459,56 @@ class TestSequential(keras_parameterized.TestCase):
with self.assertRaisesRegex(ValueError, 'should have unique names'):
model.add(keras.layers.Dense(3, name='specific_name'))
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
def test_tf_module_call(self):
class MyModule(module.Module):
def __init__(self):
self.v = variables.Variable(2.)
def __call__(self, x):
return self.v * x
model = keras.Sequential()
model.add(MyModule())
model.compile('sgd', 'mse')
x, y = np.ones((10, 1)), np.ones((10, 1))
model.fit(x, y, batch_size=2)
self.assertLen(model.trainable_variables, 1)
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
def test_tf_module_training(self):
class MyModule(module.Module):
def __init__(self):
self.v = variables.Variable(2.)
def call(self, x, training=None):
# training should be set by Sequential.
assert training is not None
return self.v * x
model = keras.Sequential()
model.add(MyModule())
model.compile('sgd', 'mse')
x, y = np.ones((10, 1)), np.ones((10, 1))
model.fit(x, y, batch_size=2)
self.assertLen(model.trainable_variables, 1)
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
def test_tf_module_error(self):
class MyModule(module.Module):
def __init__(self):
self.v = variables.Variable(2.)
model = keras.Sequential()
with self.assertRaisesRegex(ValueError, 'is not defined'):
model.add(MyModule())
class TestSequentialEagerIntegration(keras_parameterized.TestCase):