diff --git a/tensorflow/python/keras/engine/functional.py b/tensorflow/python/keras/engine/functional.py index 942f8035530..f633afaae5b 100644 --- a/tensorflow/python/keras/engine/functional.py +++ b/tensorflow/python/keras/engine/functional.py @@ -1397,3 +1397,48 @@ def shape_with_no_batch_size(x): if shape: shape[0] = None return shape + + +class ModuleWrapper(base_layer.Layer): + """Wrapper for `tf.Module`s to support the Functional and Sequential API.""" + + def __init__(self, module, method_name=None, **kwargs): + """Initializes the wrapper Layer for this module. + + Arguments: + module: The `tf.Module` instance to be wrapped. + method_name: (Optional) str. The name of the method to use as the forward + pass of the module. If not set, defaults to '__call__' if defined, or + 'call'. + **kwargs: Additional keywrod arguments. See `tf.keras.layers.Layer`. + + Raises: + ValueError: If `method` is not defined on `module`. + """ + super(ModuleWrapper, self).__init__(**kwargs) + if method_name is None: + if hasattr(module, '__call__'): + method_name = '__call__' + elif hasattr(module, 'call'): + method_name = 'call' + if method_name is None or not hasattr(module, method_name): + raise ValueError('{} is not defined on object {}'.format( + method_name, module)) + + self._module = module + self._method_name = method_name + + # Check if module.__call__ has a `training` arg or accepts `**kwargs`. + method = getattr(module, method_name) + method_arg_spec = tf_inspect.getfullargspec(method) + self._expects_training_arg = ('training' in method_arg_spec.args or + method_arg_spec.varkw is not None) + self._expects_mask_arg = ('mask' in method_arg_spec.args or + method_arg_spec.varkw is not None) + + def call(self, *args, **kwargs): + if 'training' in kwargs and not self._expects_training_arg: + kwargs.pop('training') + if 'mask' in kwargs and not self._expects_mask_arg: + kwargs.pop('mask') + return getattr(self._module, self._method_name)(*args, **kwargs) diff --git a/tensorflow/python/keras/engine/sequential.py b/tensorflow/python/keras/engine/sequential.py index a54f803cc96..eff921ee3c1 100644 --- a/tensorflow/python/keras/engine/sequential.py +++ b/tensorflow/python/keras/engine/sequential.py @@ -35,6 +35,7 @@ from tensorflow.python.keras.utils import generic_utils from tensorflow.python.keras.utils import layer_utils from tensorflow.python.keras.utils import tf_inspect from tensorflow.python.keras.utils import tf_utils +from tensorflow.python.module import module from tensorflow.python.ops.numpy_ops import np_arrays from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training.tracking import base as trackable @@ -178,7 +179,10 @@ class Sequential(functional.Functional): if isinstance(origin_layer, input_layer.InputLayer): layer = origin_layer - if not isinstance(layer, base_layer.Layer): + if isinstance(layer, module.Module): + if not isinstance(layer, base_layer.Layer): + layer = functional.ModuleWrapper(layer) + else: raise TypeError('The added layer must be ' 'an instance of class Layer. ' 'Found: ' + str(layer)) diff --git a/tensorflow/python/keras/engine/sequential_test.py b/tensorflow/python/keras/engine/sequential_test.py index 39fcb2ef5a3..23e97d06335 100644 --- a/tensorflow/python/keras/engine/sequential_test.py +++ b/tensorflow/python/keras/engine/sequential_test.py @@ -30,7 +30,9 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import testing_utils +from tensorflow.python.module import module from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -457,6 +459,56 @@ class TestSequential(keras_parameterized.TestCase): with self.assertRaisesRegex(ValueError, 'should have unique names'): model.add(keras.layers.Dense(3, name='specific_name')) + @keras_parameterized.run_all_keras_modes(always_skip_v1=True) + def test_tf_module_call(self): + + class MyModule(module.Module): + + def __init__(self): + self.v = variables.Variable(2.) + + def __call__(self, x): + return self.v * x + + model = keras.Sequential() + model.add(MyModule()) + model.compile('sgd', 'mse') + x, y = np.ones((10, 1)), np.ones((10, 1)) + model.fit(x, y, batch_size=2) + self.assertLen(model.trainable_variables, 1) + + @keras_parameterized.run_all_keras_modes(always_skip_v1=True) + def test_tf_module_training(self): + + class MyModule(module.Module): + + def __init__(self): + self.v = variables.Variable(2.) + + def call(self, x, training=None): + # training should be set by Sequential. + assert training is not None + return self.v * x + + model = keras.Sequential() + model.add(MyModule()) + model.compile('sgd', 'mse') + x, y = np.ones((10, 1)), np.ones((10, 1)) + model.fit(x, y, batch_size=2) + self.assertLen(model.trainable_variables, 1) + + @keras_parameterized.run_all_keras_modes(always_skip_v1=True) + def test_tf_module_error(self): + + class MyModule(module.Module): + + def __init__(self): + self.v = variables.Variable(2.) + + model = keras.Sequential() + with self.assertRaisesRegex(ValueError, 'is not defined'): + model.add(MyModule()) + class TestSequentialEagerIntegration(keras_parameterized.TestCase):