diff --git a/tensorflow/python/keras/engine/functional_test.py b/tensorflow/python/keras/engine/functional_test.py index 8427517f235..fea3ee16da7 100644 --- a/tensorflow/python/keras/engine/functional_test.py +++ b/tensorflow/python/keras/engine/functional_test.py @@ -2473,5 +2473,75 @@ class InputsOutputsErrorTest(keras_parameterized.TestCase): model({'1': np.zeros((3, 10)), '2': np.zeros((3, 6))}) +class FunctionalSubclassModel(training_lib.Model): + + def __init__(self, *args, **kwargs): + my_input = input_layer_lib.Input(shape=(16,)) + dense = layers.Dense(32, activation='relu') + output = dense(my_input) + outputs = {'output': output} + super().__init__(inputs=[my_input], outputs=outputs, *args, **kwargs) + + +class MixinClass(object): + + def __init__(self, foo, **kwargs): + self._foo = foo + super().__init__(**kwargs) + + def get_foo(self): + return self._foo + + +class SubclassedModel(training_lib.Model): + + def __init__(self, bar, **kwargs): + self._bar = bar + super().__init__(**kwargs) + + def get_bar(self): + return self._bar + + +class MultipleInheritanceModelTest(keras_parameterized.TestCase): + + def testFunctionalSubclass(self): + m = FunctionalSubclassModel() + # Some smoke test for the weights and output shape of the model + self.assertLen(m.weights, 2) + self.assertEqual(m.outputs[0].shape.as_list(), [None, 32]) + + def testFunctionalSubclassPreMixin(self): + class MixedFunctionalSubclassModel(MixinClass, FunctionalSubclassModel): + pass + + m = MixedFunctionalSubclassModel(foo='123') + self.assertTrue(m._is_graph_network) + self.assertLen(m.weights, 2) + self.assertEqual(m.outputs[0].shape.as_list(), [None, 32]) + self.assertEqual(m.get_foo(), '123') + + def testFunctionalSubclassPostMixin(self): + # Make sure the the mixin class is also init correct when the order changed. + + class MixedFunctionalSubclassModel(FunctionalSubclassModel, MixinClass): + pass + + m = MixedFunctionalSubclassModel(foo='123') + self.assertTrue(m._is_graph_network) + self.assertLen(m.weights, 2) + self.assertEqual(m.outputs[0].shape.as_list(), [None, 32]) + self.assertEqual(m.get_foo(), '123') + + def testSubclassModelPreMixin(self): + class MixedSubclassModel(MixinClass, SubclassedModel): + pass + + m = MixedSubclassModel(foo='123', bar='456') + self.assertFalse(m._is_graph_network) + self.assertEqual(m.get_foo(), '123') + self.assertEqual(m.get_bar(), '456') + + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index 55f71e3a94c..3feb39172f4 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -115,6 +115,10 @@ def inject_functional_model_class(cls): from tensorflow.python.keras.engine import training_v1 # pylint: disable=g-import-not-at-top if cls == Model or cls == training_v1.Model: return functional.Functional + # In case there is any multiple inheritance, we stop injecting the + # class if keras model is not in its class hierarchy. + if cls == object: + return object cls.__bases__ = tuple(inject_functional_model_class(base) for base in cls.__bases__) @@ -230,8 +234,33 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector): from tensorflow.python.keras.engine import functional # pylint: disable=g-import-not-at-top if (is_functional_model_init_params(args, kwargs) and not isinstance(self, functional.Functional)): + # Filter the kwargs for multiple inheritance. + supported_kwargs = ['inputs', 'outputs', 'name', 'trainable', 'skip_init'] + model_kwargs = {k: kwargs[k] for k in kwargs if k in supported_kwargs} + other_kwargs = {k: kwargs[k] for k in kwargs if k not in supported_kwargs} inject_functional_model_class(self.__class__) - functional.Functional.__init__(self, *args, **kwargs) + functional.Functional.__init__(self, *args, **model_kwargs) + + # In case there is any multiple inheritance here, we need to call the + # __init__ for any class that appears after the Functional class. + clz_to_init = [] + found_functional_class = False + for clz in self.__class__.__bases__: + if issubclass(clz, functional.Functional): + found_functional_class = True + continue + if found_functional_class: + clz_to_init.append(clz) + + if clz_to_init: + for clz in clz_to_init: + clz.__init__(self, *args, **other_kwargs) + elif other_kwargs: + # In case there are unused kwargs, we should raise an error to user, in + # case they have a typo in the param name. + raise TypeError( + 'The following keyword arguments aren\'t supported: {}'.format( + other_kwargs)) return base_layer.keras_api_gauge.get_cell('Model subclass').set(True)