Fix functional subclass model used with multiple inheritance.
Fix https://github.com/tensorflow/tensorflow/issues/44646 PiperOrigin-RevId: 343390846 Change-Id: I6f621fb3c70efa1f4181fe08bccd7df2bd5ffdab
This commit is contained in:
parent
e21825b8b1
commit
8f68aad110
@ -2473,5 +2473,75 @@ class InputsOutputsErrorTest(keras_parameterized.TestCase):
|
||||
model({'1': np.zeros((3, 10)), '2': np.zeros((3, 6))})
|
||||
|
||||
|
||||
class FunctionalSubclassModel(training_lib.Model):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
my_input = input_layer_lib.Input(shape=(16,))
|
||||
dense = layers.Dense(32, activation='relu')
|
||||
output = dense(my_input)
|
||||
outputs = {'output': output}
|
||||
super().__init__(inputs=[my_input], outputs=outputs, *args, **kwargs)
|
||||
|
||||
|
||||
class MixinClass(object):
|
||||
|
||||
def __init__(self, foo, **kwargs):
|
||||
self._foo = foo
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def get_foo(self):
|
||||
return self._foo
|
||||
|
||||
|
||||
class SubclassedModel(training_lib.Model):
|
||||
|
||||
def __init__(self, bar, **kwargs):
|
||||
self._bar = bar
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def get_bar(self):
|
||||
return self._bar
|
||||
|
||||
|
||||
class MultipleInheritanceModelTest(keras_parameterized.TestCase):
|
||||
|
||||
def testFunctionalSubclass(self):
|
||||
m = FunctionalSubclassModel()
|
||||
# Some smoke test for the weights and output shape of the model
|
||||
self.assertLen(m.weights, 2)
|
||||
self.assertEqual(m.outputs[0].shape.as_list(), [None, 32])
|
||||
|
||||
def testFunctionalSubclassPreMixin(self):
|
||||
class MixedFunctionalSubclassModel(MixinClass, FunctionalSubclassModel):
|
||||
pass
|
||||
|
||||
m = MixedFunctionalSubclassModel(foo='123')
|
||||
self.assertTrue(m._is_graph_network)
|
||||
self.assertLen(m.weights, 2)
|
||||
self.assertEqual(m.outputs[0].shape.as_list(), [None, 32])
|
||||
self.assertEqual(m.get_foo(), '123')
|
||||
|
||||
def testFunctionalSubclassPostMixin(self):
|
||||
# Make sure the the mixin class is also init correct when the order changed.
|
||||
|
||||
class MixedFunctionalSubclassModel(FunctionalSubclassModel, MixinClass):
|
||||
pass
|
||||
|
||||
m = MixedFunctionalSubclassModel(foo='123')
|
||||
self.assertTrue(m._is_graph_network)
|
||||
self.assertLen(m.weights, 2)
|
||||
self.assertEqual(m.outputs[0].shape.as_list(), [None, 32])
|
||||
self.assertEqual(m.get_foo(), '123')
|
||||
|
||||
def testSubclassModelPreMixin(self):
|
||||
class MixedSubclassModel(MixinClass, SubclassedModel):
|
||||
pass
|
||||
|
||||
m = MixedSubclassModel(foo='123', bar='456')
|
||||
self.assertFalse(m._is_graph_network)
|
||||
self.assertEqual(m.get_foo(), '123')
|
||||
self.assertEqual(m.get_bar(), '456')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -115,6 +115,10 @@ def inject_functional_model_class(cls):
|
||||
from tensorflow.python.keras.engine import training_v1 # pylint: disable=g-import-not-at-top
|
||||
if cls == Model or cls == training_v1.Model:
|
||||
return functional.Functional
|
||||
# In case there is any multiple inheritance, we stop injecting the
|
||||
# class if keras model is not in its class hierarchy.
|
||||
if cls == object:
|
||||
return object
|
||||
|
||||
cls.__bases__ = tuple(inject_functional_model_class(base)
|
||||
for base in cls.__bases__)
|
||||
@ -230,8 +234,33 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
|
||||
from tensorflow.python.keras.engine import functional # pylint: disable=g-import-not-at-top
|
||||
if (is_functional_model_init_params(args, kwargs) and
|
||||
not isinstance(self, functional.Functional)):
|
||||
# Filter the kwargs for multiple inheritance.
|
||||
supported_kwargs = ['inputs', 'outputs', 'name', 'trainable', 'skip_init']
|
||||
model_kwargs = {k: kwargs[k] for k in kwargs if k in supported_kwargs}
|
||||
other_kwargs = {k: kwargs[k] for k in kwargs if k not in supported_kwargs}
|
||||
inject_functional_model_class(self.__class__)
|
||||
functional.Functional.__init__(self, *args, **kwargs)
|
||||
functional.Functional.__init__(self, *args, **model_kwargs)
|
||||
|
||||
# In case there is any multiple inheritance here, we need to call the
|
||||
# __init__ for any class that appears after the Functional class.
|
||||
clz_to_init = []
|
||||
found_functional_class = False
|
||||
for clz in self.__class__.__bases__:
|
||||
if issubclass(clz, functional.Functional):
|
||||
found_functional_class = True
|
||||
continue
|
||||
if found_functional_class:
|
||||
clz_to_init.append(clz)
|
||||
|
||||
if clz_to_init:
|
||||
for clz in clz_to_init:
|
||||
clz.__init__(self, *args, **other_kwargs)
|
||||
elif other_kwargs:
|
||||
# In case there are unused kwargs, we should raise an error to user, in
|
||||
# case they have a typo in the param name.
|
||||
raise TypeError(
|
||||
'The following keyword arguments aren\'t supported: {}'.format(
|
||||
other_kwargs))
|
||||
return
|
||||
|
||||
base_layer.keras_api_gauge.get_cell('Model subclass').set(True)
|
||||
|
Loading…
Reference in New Issue
Block a user