Remove conditional loss/graph tracing based on inputs in v2 add_update/get_updates_for API.
PiperOrigin-RevId: 307688443 Change-Id: Ia2059a0582b869b425d518eadabccdfd2c85ced7
This commit is contained in:
parent
983a8df746
commit
ec3f58bc3c
@ -1484,9 +1484,6 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
||||
dependent on `a` and some on `b`. This method automatically keeps track
|
||||
of dependencies.
|
||||
|
||||
The `get_updates_for` method allows to retrieve the updates relevant to a
|
||||
specific set of inputs.
|
||||
|
||||
This call is ignored when eager execution is enabled (in that case, variable
|
||||
updates are run on the fly and thus do not need to be tracked for later
|
||||
execution).
|
||||
@ -1518,12 +1515,6 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
||||
update()
|
||||
return
|
||||
|
||||
if call_context.in_call:
|
||||
relevant_inputs = call_context.inputs
|
||||
else:
|
||||
inbound_nodes = getattr(self, '_inbound_nodes', [])
|
||||
relevant_inputs = [node.input_tensors for node in inbound_nodes]
|
||||
|
||||
def process_update(x):
|
||||
"""Standardize update ops.
|
||||
|
||||
@ -1545,9 +1536,6 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
||||
update = x.op
|
||||
else:
|
||||
update = ops.convert_to_tensor_v2(x)
|
||||
|
||||
reachable = tf_utils.get_reachable_from_inputs(relevant_inputs, [update])
|
||||
update._unconditional_update = update not in reachable
|
||||
return update
|
||||
|
||||
updates = [process_update(x) for x in updates]
|
||||
@ -1691,15 +1679,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
||||
Returns:
|
||||
List of update ops of the layer that depend on `inputs`.
|
||||
"""
|
||||
if inputs is None:
|
||||
# Requesting unconditional updates.
|
||||
return [u for u in self.updates if u._unconditional_update]
|
||||
|
||||
# Requesting input-conditional updates.
|
||||
updates = [u for u in self.updates if not u._unconditional_update]
|
||||
inputs = nest.flatten(inputs)
|
||||
reachable = tf_utils.get_reachable_from_inputs(inputs, updates)
|
||||
return [u for u in updates if u in reachable]
|
||||
return self.updates
|
||||
|
||||
@doc_controls.do_not_doc_inheritable
|
||||
def get_losses_for(self, inputs):
|
||||
|
@ -83,21 +83,14 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
|
||||
_ = layer(x1)
|
||||
|
||||
self.assertEqual(len(layer.updates), 2)
|
||||
self.assertEqual(len(layer.get_updates_for(x1)), 1)
|
||||
self.assertEqual(len(layer.get_updates_for(None)), 1)
|
||||
|
||||
x2 = input_layer_lib.Input(shape=(1,))
|
||||
y2 = layer(x2)
|
||||
|
||||
self.assertEqual(len(layer.updates), 3)
|
||||
self.assertEqual(len(layer.get_updates_for(x1)), 1)
|
||||
self.assertEqual(len(layer.get_updates_for(x2)), 1)
|
||||
self.assertEqual(len(layer.get_updates_for(None)), 1)
|
||||
|
||||
network = network_lib.Network(x2, y2)
|
||||
self.assertEqual(len(network.updates), 3)
|
||||
self.assertEqual(len(network.get_updates_for(x2)), 1)
|
||||
self.assertEqual(len(network.get_updates_for(None)), 1)
|
||||
|
||||
x3 = input_layer_lib.Input(shape=(1,))
|
||||
_ = layer(x3)
|
||||
@ -106,17 +99,12 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
|
||||
x4 = input_layer_lib.Input(shape=(1,))
|
||||
_ = network(x4)
|
||||
self.assertEqual(len(network.updates), 5)
|
||||
self.assertEqual(len(network.get_updates_for(x2)), 1)
|
||||
self.assertEqual(len(network.get_updates_for(x4)), 1)
|
||||
self.assertEqual(len(network.get_updates_for(None)), 1)
|
||||
|
||||
network.add_update(state_ops.assign_add(layer.a, [[1]]))
|
||||
self.assertEqual(len(network.updates), 6)
|
||||
self.assertEqual(len(network.get_updates_for(None)), 2)
|
||||
|
||||
network.add_update(state_ops.assign_add(layer.b, x4), inputs=True)
|
||||
self.assertEqual(len(network.updates), 7)
|
||||
self.assertEqual(len(network.get_updates_for(x4)), 2)
|
||||
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
def test_get_updates_bn(self):
|
||||
@ -125,8 +113,6 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
|
||||
_ = layer(x1)
|
||||
|
||||
self.assertEqual(len(layer.updates), 2)
|
||||
self.assertEqual(len(layer.get_updates_for(x1)), 2)
|
||||
self.assertEqual(len(layer.get_updates_for(None)), 0)
|
||||
|
||||
def test_get_layer(self):
|
||||
# create a simple network
|
||||
@ -1572,7 +1558,6 @@ class NestedNetworkTest(keras_parameterized.TestCase):
|
||||
output_shape = network.compute_output_shape([(None, 1), (None, 1)])
|
||||
self.assertListEqual(output_shape.as_list(), [None, 1])
|
||||
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
def test_updates_with_direct_call(self):
|
||||
inputs = input_layer_lib.Input(shape=(10,))
|
||||
x = layers.BatchNormalization()(inputs)
|
||||
@ -1582,8 +1567,7 @@ class NestedNetworkTest(keras_parameterized.TestCase):
|
||||
ph = backend.placeholder(shape=(10, 10))
|
||||
model(ph)
|
||||
|
||||
self.assertLen(model.get_updates_for(ph), 2)
|
||||
self.assertLen(model.get_updates_for(None), 0)
|
||||
self.assertLen(model.updates, 4)
|
||||
|
||||
def test_dict_mapping_input(self):
|
||||
|
||||
|
@ -403,8 +403,6 @@ class NormalizationLayersGraphModeOnlyTest(
|
||||
model.train_on_batch(x, x)
|
||||
|
||||
self.assertLen(bn.updates, 4)
|
||||
self.assertLen(bn.get_updates_for(x1), 2)
|
||||
self.assertLen(model.get_updates_for(x2), 2)
|
||||
|
||||
# Test model-level reuse
|
||||
x3 = keras.layers.Input(shape=(10,))
|
||||
@ -413,7 +411,6 @@ class NormalizationLayersGraphModeOnlyTest(
|
||||
|
||||
self.assertLen(new_model.updates, 6)
|
||||
self.assertLen(model.updates, 6)
|
||||
self.assertLen(new_model.get_updates_for(x3), 2)
|
||||
new_model.compile(gradient_descent.GradientDescentOptimizer(0.01), 'mse')
|
||||
new_model.train_on_batch(x, x)
|
||||
|
||||
|
@ -602,7 +602,7 @@ class RNNTest(keras_parameterized.TestCase):
|
||||
self.assertEqual(layer.get_losses_for(None), [loss_2])
|
||||
self.assertEqual(layer.get_losses_for(x), [loss_1])
|
||||
|
||||
# Test `get_updates_for` and `updates`
|
||||
# Test `updates`
|
||||
cells = [keras.layers.LSTMCell(1),
|
||||
keras.layers.LSTMCell(1)]
|
||||
layer = keras.layers.RNN(cells)
|
||||
@ -618,8 +618,6 @@ class RNNTest(keras_parameterized.TestCase):
|
||||
cells[0].add_update(update_1, inputs=x)
|
||||
cells[0].add_update(update_2)
|
||||
self.assertEqual(len(layer.updates), 2)
|
||||
self.assertEqual(len(layer.get_updates_for(None)), 1)
|
||||
self.assertEqual(len(layer.get_updates_for(x)), 1)
|
||||
|
||||
def test_rnn_dynamic_trainability(self):
|
||||
layer_class = keras.layers.SimpleRNN
|
||||
|
@ -787,8 +787,6 @@ class BidirectionalTest(test.TestCase, parameterized.TestCase):
|
||||
layer = keras.layers.Bidirectional(keras.layers.SimpleRNN(3))
|
||||
_ = layer(x)
|
||||
assert not layer.updates
|
||||
assert not layer.get_updates_for(None)
|
||||
assert not layer.get_updates_for(x)
|
||||
# TODO(b/128684069): Remove when Wrapper sublayers are __call__'d.
|
||||
with base_layer_utils.call_context().enter(layer, x, True, None):
|
||||
layer.forward_layer.add_update(x_reachable_update, inputs=x)
|
||||
@ -796,8 +794,6 @@ class BidirectionalTest(test.TestCase, parameterized.TestCase):
|
||||
layer.backward_layer.add_update(x_reachable_update, inputs=x)
|
||||
layer.backward_layer.add_update(1, inputs=None)
|
||||
assert len(layer.updates) == 4
|
||||
assert len(layer.get_updates_for(None)) == 2
|
||||
assert len(layer.get_updates_for(x)) == 2
|
||||
|
||||
def test_Bidirectional_losses(self):
|
||||
x = keras.layers.Input(shape=(3, 2))
|
||||
|
@ -124,7 +124,7 @@ class TestModelCloning(keras_parameterized.TestCase):
|
||||
self.assertEqual(new_model._is_graph_network, model._is_graph_network)
|
||||
if input_shape:
|
||||
# update ops from batch norm needs to be included
|
||||
self.assertEqual(len(new_model.get_updates_for(new_model.inputs)), 2)
|
||||
self.assertGreaterEqual(len(new_model.updates), 2)
|
||||
|
||||
# On top of new tensor -- clone model should always have an InputLayer.
|
||||
input_a = keras.Input(shape=(4,))
|
||||
@ -173,7 +173,7 @@ class TestModelCloning(keras_parameterized.TestCase):
|
||||
|
||||
# With placeholder creation
|
||||
new_model = clone_fn(model)
|
||||
self.assertEqual(len(new_model.get_updates_for(new_model.inputs)), 2)
|
||||
self.assertGreaterEqual(len(new_model.updates), 2)
|
||||
new_model.compile(
|
||||
testing_utils.get_v2_optimizer('rmsprop'),
|
||||
'mse',
|
||||
@ -185,7 +185,7 @@ class TestModelCloning(keras_parameterized.TestCase):
|
||||
input_b = keras.Input(shape=(4,), name='b')
|
||||
new_model = keras.models.clone_model(
|
||||
model, input_tensors=[input_a, input_b])
|
||||
self.assertEqual(len(new_model.get_updates_for(new_model.inputs)), 2)
|
||||
self.assertLen(new_model.updates, 2)
|
||||
new_model.compile(
|
||||
testing_utils.get_v2_optimizer('rmsprop'),
|
||||
'mse',
|
||||
@ -199,7 +199,7 @@ class TestModelCloning(keras_parameterized.TestCase):
|
||||
input_a = keras.backend.variable(val_a)
|
||||
input_b = keras.backend.variable(val_b)
|
||||
new_model = clone_fn(model, input_tensors=[input_a, input_b])
|
||||
self.assertEqual(len(new_model.get_updates_for(new_model.inputs)), 2)
|
||||
self.assertGreaterEqual(len(new_model.updates), 2)
|
||||
new_model.compile(
|
||||
testing_utils.get_v2_optimizer('rmsprop'),
|
||||
'mse',
|
||||
|
@ -112,9 +112,9 @@ class VectorClassificationIntegrationTest(keras_parameterized.TestCase):
|
||||
optimizer=keras.optimizer_v2.adam.Adam(0.005),
|
||||
metrics=['acc'],
|
||||
run_eagerly=testing_utils.should_run_eagerly())
|
||||
if not testing_utils.should_run_eagerly():
|
||||
self.assertEqual(len(model.get_losses_for(None)), 2)
|
||||
self.assertEqual(len(model.get_updates_for(x)), 2)
|
||||
self.assertLen(model.losses, 2)
|
||||
if not context.executing_eagerly():
|
||||
self.assertLen(model.get_updates_for(x), 2)
|
||||
history = model.fit(x_train, y_train, epochs=10, batch_size=10,
|
||||
validation_data=(x_train, y_train),
|
||||
verbose=2)
|
||||
|
@ -477,8 +477,6 @@ class ModelSubclassingTest(keras_parameterized.TestCase):
|
||||
self.assertEqual(0, len(model.updates))
|
||||
else:
|
||||
self.assertEqual(2, len(model.updates))
|
||||
self.assertEqual(1, len(model.get_updates_for(None)))
|
||||
self.assertEqual(1, len(model.get_updates_for(x)))
|
||||
|
||||
|
||||
class GraphSpecificModelSubclassingTests(test.TestCase):
|
||||
@ -536,8 +534,8 @@ class GraphSpecificModelSubclassingTests(test.TestCase):
|
||||
|
||||
x = array_ops.ones(shape=[100, 784], dtype='float32')
|
||||
model(x)
|
||||
self.assertEqual(len(model.get_updates_for(x)), 2)
|
||||
self.assertEqual(len(model.get_losses_for(x)), 1)
|
||||
self.assertLen(model.updates, 2)
|
||||
self.assertLen(model.losses, 1)
|
||||
|
||||
# Case 2: placeholder-sequential nested in subclass.
|
||||
class TestModel2(keras.Model):
|
||||
|
Loading…
Reference in New Issue
Block a user