From 8f68aad1107df679843da96a990773e9fc30201c Mon Sep 17 00:00:00 2001
From: Scott Zhu <scottzhu@google.com>
Date: Thu, 19 Nov 2020 16:13:09 -0800
Subject: [PATCH] Fix functional subclass model used with multiple inheritance.

Fix https://github.com/tensorflow/tensorflow/issues/44646

PiperOrigin-RevId: 343390846
Change-Id: I6f621fb3c70efa1f4181fe08bccd7df2bd5ffdab
---
 .../python/keras/engine/functional_test.py    | 70 +++++++++++++++++++
 tensorflow/python/keras/engine/training.py    | 31 +++++++-
 2 files changed, 100 insertions(+), 1 deletion(-)

diff --git a/tensorflow/python/keras/engine/functional_test.py b/tensorflow/python/keras/engine/functional_test.py
index 8427517f235..fea3ee16da7 100644
--- a/tensorflow/python/keras/engine/functional_test.py
+++ b/tensorflow/python/keras/engine/functional_test.py
@@ -2473,5 +2473,75 @@ class InputsOutputsErrorTest(keras_parameterized.TestCase):
       model({'1': np.zeros((3, 10)), '2': np.zeros((3, 6))})
 
 
+class FunctionalSubclassModel(training_lib.Model):
+
+  def __init__(self, *args, **kwargs):
+    my_input = input_layer_lib.Input(shape=(16,))
+    dense = layers.Dense(32, activation='relu')
+    output = dense(my_input)
+    outputs = {'output': output}
+    super().__init__(inputs=[my_input], outputs=outputs, *args, **kwargs)
+
+
+class MixinClass(object):
+
+  def __init__(self, foo, **kwargs):
+    self._foo = foo
+    super().__init__(**kwargs)
+
+  def get_foo(self):
+    return self._foo
+
+
+class SubclassedModel(training_lib.Model):
+
+  def __init__(self, bar, **kwargs):
+    self._bar = bar
+    super().__init__(**kwargs)
+
+  def get_bar(self):
+    return self._bar
+
+
+class MultipleInheritanceModelTest(keras_parameterized.TestCase):
+
+  def testFunctionalSubclass(self):
+    m = FunctionalSubclassModel()
+    # Some smoke test for the weights and output shape of the model
+    self.assertLen(m.weights, 2)
+    self.assertEqual(m.outputs[0].shape.as_list(), [None, 32])
+
+  def testFunctionalSubclassPreMixin(self):
+    class MixedFunctionalSubclassModel(MixinClass, FunctionalSubclassModel):
+      pass
+
+    m = MixedFunctionalSubclassModel(foo='123')
+    self.assertTrue(m._is_graph_network)
+    self.assertLen(m.weights, 2)
+    self.assertEqual(m.outputs[0].shape.as_list(), [None, 32])
+    self.assertEqual(m.get_foo(), '123')
+
+  def testFunctionalSubclassPostMixin(self):
+    # Make sure the the mixin class is also init correct when the order changed.
+
+    class MixedFunctionalSubclassModel(FunctionalSubclassModel, MixinClass):
+      pass
+
+    m = MixedFunctionalSubclassModel(foo='123')
+    self.assertTrue(m._is_graph_network)
+    self.assertLen(m.weights, 2)
+    self.assertEqual(m.outputs[0].shape.as_list(), [None, 32])
+    self.assertEqual(m.get_foo(), '123')
+
+  def testSubclassModelPreMixin(self):
+    class MixedSubclassModel(MixinClass, SubclassedModel):
+      pass
+
+    m = MixedSubclassModel(foo='123', bar='456')
+    self.assertFalse(m._is_graph_network)
+    self.assertEqual(m.get_foo(), '123')
+    self.assertEqual(m.get_bar(), '456')
+
+
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index 55f71e3a94c..3feb39172f4 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -115,6 +115,10 @@ def inject_functional_model_class(cls):
   from tensorflow.python.keras.engine import training_v1  # pylint: disable=g-import-not-at-top
   if cls == Model or cls == training_v1.Model:
     return functional.Functional
+  # In case there is any multiple inheritance, we stop injecting the
+  # class if keras model is not in its class hierarchy.
+  if cls == object:
+    return object
 
   cls.__bases__ = tuple(inject_functional_model_class(base)
                         for base in cls.__bases__)
@@ -230,8 +234,33 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
     from tensorflow.python.keras.engine import functional  # pylint: disable=g-import-not-at-top
     if (is_functional_model_init_params(args, kwargs) and
         not isinstance(self, functional.Functional)):
+      # Filter the kwargs for multiple inheritance.
+      supported_kwargs = ['inputs', 'outputs', 'name', 'trainable', 'skip_init']
+      model_kwargs = {k: kwargs[k] for k in kwargs if k in supported_kwargs}
+      other_kwargs = {k: kwargs[k] for k in kwargs if k not in supported_kwargs}
       inject_functional_model_class(self.__class__)
-      functional.Functional.__init__(self, *args, **kwargs)
+      functional.Functional.__init__(self, *args, **model_kwargs)
+
+      # In case there is any multiple inheritance here, we need to call the
+      # __init__ for any class that appears after the Functional class.
+      clz_to_init = []
+      found_functional_class = False
+      for clz in self.__class__.__bases__:
+        if issubclass(clz, functional.Functional):
+          found_functional_class = True
+          continue
+        if found_functional_class:
+          clz_to_init.append(clz)
+
+      if clz_to_init:
+        for clz in clz_to_init:
+          clz.__init__(self, *args, **other_kwargs)
+      elif other_kwargs:
+        # In case there are unused kwargs, we should raise an error to user, in
+        # case they have a typo in the param name.
+        raise TypeError(
+            'The following keyword arguments aren\'t supported: {}'.format(
+                other_kwargs))
       return
 
     base_layer.keras_api_gauge.get_cell('Model subclass').set(True)