Fix functional subclass model used with multiple inheritance.

Fix https://github.com/tensorflow/tensorflow/issues/44646

PiperOrigin-RevId: 343390846
Change-Id: I6f621fb3c70efa1f4181fe08bccd7df2bd5ffdab
This commit is contained in:
Scott Zhu 2020-11-19 16:13:09 -08:00 committed by TensorFlower Gardener
parent e21825b8b1
commit 8f68aad110
2 changed files with 100 additions and 1 deletions

View File

@ -2473,5 +2473,75 @@ class InputsOutputsErrorTest(keras_parameterized.TestCase):
model({'1': np.zeros((3, 10)), '2': np.zeros((3, 6))})
class FunctionalSubclassModel(training_lib.Model):
def __init__(self, *args, **kwargs):
my_input = input_layer_lib.Input(shape=(16,))
dense = layers.Dense(32, activation='relu')
output = dense(my_input)
outputs = {'output': output}
super().__init__(inputs=[my_input], outputs=outputs, *args, **kwargs)
class MixinClass(object):
def __init__(self, foo, **kwargs):
self._foo = foo
super().__init__(**kwargs)
def get_foo(self):
return self._foo
class SubclassedModel(training_lib.Model):
def __init__(self, bar, **kwargs):
self._bar = bar
super().__init__(**kwargs)
def get_bar(self):
return self._bar
class MultipleInheritanceModelTest(keras_parameterized.TestCase):
def testFunctionalSubclass(self):
m = FunctionalSubclassModel()
# Some smoke test for the weights and output shape of the model
self.assertLen(m.weights, 2)
self.assertEqual(m.outputs[0].shape.as_list(), [None, 32])
def testFunctionalSubclassPreMixin(self):
class MixedFunctionalSubclassModel(MixinClass, FunctionalSubclassModel):
pass
m = MixedFunctionalSubclassModel(foo='123')
self.assertTrue(m._is_graph_network)
self.assertLen(m.weights, 2)
self.assertEqual(m.outputs[0].shape.as_list(), [None, 32])
self.assertEqual(m.get_foo(), '123')
def testFunctionalSubclassPostMixin(self):
# Make sure the the mixin class is also init correct when the order changed.
class MixedFunctionalSubclassModel(FunctionalSubclassModel, MixinClass):
pass
m = MixedFunctionalSubclassModel(foo='123')
self.assertTrue(m._is_graph_network)
self.assertLen(m.weights, 2)
self.assertEqual(m.outputs[0].shape.as_list(), [None, 32])
self.assertEqual(m.get_foo(), '123')
def testSubclassModelPreMixin(self):
class MixedSubclassModel(MixinClass, SubclassedModel):
pass
m = MixedSubclassModel(foo='123', bar='456')
self.assertFalse(m._is_graph_network)
self.assertEqual(m.get_foo(), '123')
self.assertEqual(m.get_bar(), '456')
if __name__ == '__main__':
test.main()

View File

@ -115,6 +115,10 @@ def inject_functional_model_class(cls):
from tensorflow.python.keras.engine import training_v1 # pylint: disable=g-import-not-at-top
if cls == Model or cls == training_v1.Model:
return functional.Functional
# In case there is any multiple inheritance, we stop injecting the
# class if keras model is not in its class hierarchy.
if cls == object:
return object
cls.__bases__ = tuple(inject_functional_model_class(base)
for base in cls.__bases__)
@ -230,8 +234,33 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
from tensorflow.python.keras.engine import functional # pylint: disable=g-import-not-at-top
if (is_functional_model_init_params(args, kwargs) and
not isinstance(self, functional.Functional)):
# Filter the kwargs for multiple inheritance.
supported_kwargs = ['inputs', 'outputs', 'name', 'trainable', 'skip_init']
model_kwargs = {k: kwargs[k] for k in kwargs if k in supported_kwargs}
other_kwargs = {k: kwargs[k] for k in kwargs if k not in supported_kwargs}
inject_functional_model_class(self.__class__)
functional.Functional.__init__(self, *args, **kwargs)
functional.Functional.__init__(self, *args, **model_kwargs)
# In case there is any multiple inheritance here, we need to call the
# __init__ for any class that appears after the Functional class.
clz_to_init = []
found_functional_class = False
for clz in self.__class__.__bases__:
if issubclass(clz, functional.Functional):
found_functional_class = True
continue
if found_functional_class:
clz_to_init.append(clz)
if clz_to_init:
for clz in clz_to_init:
clz.__init__(self, *args, **other_kwargs)
elif other_kwargs:
# In case there are unused kwargs, we should raise an error to user, in
# case they have a typo in the param name.
raise TypeError(
'The following keyword arguments aren\'t supported: {}'.format(
other_kwargs))
return
base_layer.keras_api_gauge.get_cell('Model subclass').set(True)