Enrich update ops from inputs
PiperOrigin-RevId: 204223077
This commit is contained in:
parent
4d04403a3d
commit
5574d6041a
@ -599,7 +599,7 @@ class Model(Network):
|
||||
# Unconditional updates
|
||||
updates += self.get_updates_for(None)
|
||||
# Conditional updates relevant to this model
|
||||
updates += self.get_updates_for(self._feed_inputs)
|
||||
updates += self.get_updates_for(self.inputs)
|
||||
# Stateful metrics updates
|
||||
updates += self.metrics_updates
|
||||
# Gets loss and metrics. Updates weights at each call.
|
||||
|
@ -37,6 +37,7 @@ class TestModelCloning(test.TestCase):
|
||||
|
||||
model = keras.models.Sequential()
|
||||
model.add(keras.layers.Dense(4, input_shape=(4,)))
|
||||
model.add(keras.layers.BatchNormalization())
|
||||
model.add(keras.layers.Dropout(0.5))
|
||||
model.add(keras.layers.Dense(4))
|
||||
|
||||
@ -46,6 +47,8 @@ class TestModelCloning(test.TestCase):
|
||||
with self.test_session():
|
||||
# With placeholder creation
|
||||
new_model = keras.models.clone_model(model)
|
||||
# update ops from batch norm needs to be included
|
||||
self.assertEquals(len(new_model.get_updates_for(new_model.inputs)), 2)
|
||||
new_model.compile('rmsprop', 'mse')
|
||||
new_model.train_on_batch(val_a, val_out)
|
||||
|
||||
@ -53,6 +56,7 @@ class TestModelCloning(test.TestCase):
|
||||
input_a = keras.Input(shape=(4,))
|
||||
new_model = keras.models.clone_model(
|
||||
model, input_tensors=input_a)
|
||||
self.assertEquals(len(new_model.get_updates_for(new_model.inputs)), 2)
|
||||
new_model.compile('rmsprop', 'mse')
|
||||
new_model.train_on_batch(val_a, val_out)
|
||||
|
||||
@ -60,6 +64,7 @@ class TestModelCloning(test.TestCase):
|
||||
input_a = keras.backend.variable(val_a)
|
||||
new_model = keras.models.clone_model(
|
||||
model, input_tensors=input_a)
|
||||
self.assertEquals(len(new_model.get_updates_for(new_model.inputs)), 2)
|
||||
new_model.compile('rmsprop', 'mse')
|
||||
new_model.train_on_batch(None, val_out)
|
||||
|
||||
@ -76,6 +81,7 @@ class TestModelCloning(test.TestCase):
|
||||
|
||||
x_a = dense_1(input_a)
|
||||
x_a = keras.layers.Dropout(0.5)(x_a)
|
||||
x_a = keras.layers.BatchNormalization()(x_a)
|
||||
x_b = dense_1(input_b)
|
||||
x_a = dense_2(x_a)
|
||||
outputs = keras.layers.add([x_a, x_b])
|
||||
@ -87,6 +93,7 @@ class TestModelCloning(test.TestCase):
|
||||
with self.test_session():
|
||||
# With placeholder creation
|
||||
new_model = keras.models.clone_model(model)
|
||||
self.assertEquals(len(new_model.get_updates_for(new_model.inputs)), 2)
|
||||
new_model.compile('rmsprop', 'mse')
|
||||
new_model.train_on_batch([val_a, val_b], val_out)
|
||||
|
||||
@ -95,6 +102,7 @@ class TestModelCloning(test.TestCase):
|
||||
input_b = keras.Input(shape=(4,), name='b')
|
||||
new_model = keras.models.clone_model(
|
||||
model, input_tensors=[input_a, input_b])
|
||||
self.assertEquals(len(new_model.get_updates_for(new_model.inputs)), 2)
|
||||
new_model.compile('rmsprop', 'mse')
|
||||
new_model.train_on_batch([val_a, val_b], val_out)
|
||||
|
||||
@ -103,6 +111,7 @@ class TestModelCloning(test.TestCase):
|
||||
input_b = keras.backend.variable(val_b)
|
||||
new_model = keras.models.clone_model(
|
||||
model, input_tensors=[input_a, input_b])
|
||||
self.assertEquals(len(new_model.get_updates_for(new_model.inputs)), 2)
|
||||
new_model.compile('rmsprop', 'mse')
|
||||
new_model.train_on_batch(None, val_out)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user