Enrich update ops from inputs

PiperOrigin-RevId: 204223077
This commit is contained in:
Zhenyu Tan 2018-07-11 17:55:57 -07:00 committed by TensorFlower Gardener
parent 4d04403a3d
commit 5574d6041a
2 changed files with 10 additions and 1 deletions

View File

@ -599,7 +599,7 @@ class Model(Network):
# Unconditional updates
updates += self.get_updates_for(None)
# Conditional updates relevant to this model
updates += self.get_updates_for(self._feed_inputs)
updates += self.get_updates_for(self.inputs)
# Stateful metrics updates
updates += self.metrics_updates
# Gets loss and metrics. Updates weights at each call.

View File

@ -37,6 +37,7 @@ class TestModelCloning(test.TestCase):
model = keras.models.Sequential()
model.add(keras.layers.Dense(4, input_shape=(4,)))
model.add(keras.layers.BatchNormalization())
model.add(keras.layers.Dropout(0.5))
model.add(keras.layers.Dense(4))
@ -46,6 +47,8 @@ class TestModelCloning(test.TestCase):
with self.test_session():
# With placeholder creation
new_model = keras.models.clone_model(model)
# update ops from batch norm needs to be included
self.assertEquals(len(new_model.get_updates_for(new_model.inputs)), 2)
new_model.compile('rmsprop', 'mse')
new_model.train_on_batch(val_a, val_out)
@ -53,6 +56,7 @@ class TestModelCloning(test.TestCase):
input_a = keras.Input(shape=(4,))
new_model = keras.models.clone_model(
model, input_tensors=input_a)
self.assertEquals(len(new_model.get_updates_for(new_model.inputs)), 2)
new_model.compile('rmsprop', 'mse')
new_model.train_on_batch(val_a, val_out)
@ -60,6 +64,7 @@ class TestModelCloning(test.TestCase):
input_a = keras.backend.variable(val_a)
new_model = keras.models.clone_model(
model, input_tensors=input_a)
self.assertEquals(len(new_model.get_updates_for(new_model.inputs)), 2)
new_model.compile('rmsprop', 'mse')
new_model.train_on_batch(None, val_out)
@ -76,6 +81,7 @@ class TestModelCloning(test.TestCase):
x_a = dense_1(input_a)
x_a = keras.layers.Dropout(0.5)(x_a)
x_a = keras.layers.BatchNormalization()(x_a)
x_b = dense_1(input_b)
x_a = dense_2(x_a)
outputs = keras.layers.add([x_a, x_b])
@ -87,6 +93,7 @@ class TestModelCloning(test.TestCase):
with self.test_session():
# With placeholder creation
new_model = keras.models.clone_model(model)
self.assertEquals(len(new_model.get_updates_for(new_model.inputs)), 2)
new_model.compile('rmsprop', 'mse')
new_model.train_on_batch([val_a, val_b], val_out)
@ -95,6 +102,7 @@ class TestModelCloning(test.TestCase):
input_b = keras.Input(shape=(4,), name='b')
new_model = keras.models.clone_model(
model, input_tensors=[input_a, input_b])
self.assertEquals(len(new_model.get_updates_for(new_model.inputs)), 2)
new_model.compile('rmsprop', 'mse')
new_model.train_on_batch([val_a, val_b], val_out)
@ -103,6 +111,7 @@ class TestModelCloning(test.TestCase):
input_b = keras.backend.variable(val_b)
new_model = keras.models.clone_model(
model, input_tensors=[input_a, input_b])
self.assertEquals(len(new_model.get_updates_for(new_model.inputs)), 2)
new_model.compile('rmsprop', 'mse')
new_model.train_on_batch(None, val_out)