Remove conditional loss/graph tracing based on inputs in v2 add_update/get_updates_for API.

PiperOrigin-RevId: 307688443
Change-Id: Ia2059a0582b869b425d518eadabccdfd2c85ced7
This commit is contained in:
Pavithra Vijay 2020-04-21 15:05:07 -07:00 committed by TensorFlower Gardener
parent 983a8df746
commit ec3f58bc3c
8 changed files with 12 additions and 59 deletions

View File

@@ -1484,9 +1484,6 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
dependent on `a` and some on `b`. This method automatically keeps track
of dependencies.
The `get_updates_for` method allows to retrieve the updates relevant to a
specific set of inputs.
This call is ignored when eager execution is enabled (in that case, variable
updates are run on the fly and thus do not need to be tracked for later
execution).
@@ -1518,12 +1515,6 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
update()
return
if call_context.in_call:
relevant_inputs = call_context.inputs
else:
inbound_nodes = getattr(self, '_inbound_nodes', [])
relevant_inputs = [node.input_tensors for node in inbound_nodes]
def process_update(x):
"""Standardize update ops.
@@ -1545,9 +1536,6 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
update = x.op
else:
update = ops.convert_to_tensor_v2(x)
reachable = tf_utils.get_reachable_from_inputs(relevant_inputs, [update])
update._unconditional_update = update not in reachable
return update
updates = [process_update(x) for x in updates]
@@ -1691,15 +1679,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
Returns:
List of update ops of the layer that depend on `inputs`.
"""
if inputs is None:
# Requesting unconditional updates.
return [u for u in self.updates if u._unconditional_update]
# Requesting input-conditional updates.
updates = [u for u in self.updates if not u._unconditional_update]
inputs = nest.flatten(inputs)
reachable = tf_utils.get_reachable_from_inputs(inputs, updates)
return [u for u in updates if u in reachable]
return self.updates
@doc_controls.do_not_doc_inheritable
def get_losses_for(self, inputs):

View File

@@ -83,21 +83,14 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
_ = layer(x1)
self.assertEqual(len(layer.updates), 2)
self.assertEqual(len(layer.get_updates_for(x1)), 1)
self.assertEqual(len(layer.get_updates_for(None)), 1)
x2 = input_layer_lib.Input(shape=(1,))
y2 = layer(x2)
self.assertEqual(len(layer.updates), 3)
self.assertEqual(len(layer.get_updates_for(x1)), 1)
self.assertEqual(len(layer.get_updates_for(x2)), 1)
self.assertEqual(len(layer.get_updates_for(None)), 1)
network = network_lib.Network(x2, y2)
self.assertEqual(len(network.updates), 3)
self.assertEqual(len(network.get_updates_for(x2)), 1)
self.assertEqual(len(network.get_updates_for(None)), 1)
x3 = input_layer_lib.Input(shape=(1,))
_ = layer(x3)
@@ -106,17 +99,12 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
x4 = input_layer_lib.Input(shape=(1,))
_ = network(x4)
self.assertEqual(len(network.updates), 5)
self.assertEqual(len(network.get_updates_for(x2)), 1)
self.assertEqual(len(network.get_updates_for(x4)), 1)
self.assertEqual(len(network.get_updates_for(None)), 1)
network.add_update(state_ops.assign_add(layer.a, [[1]]))
self.assertEqual(len(network.updates), 6)
self.assertEqual(len(network.get_updates_for(None)), 2)
network.add_update(state_ops.assign_add(layer.b, x4), inputs=True)
self.assertEqual(len(network.updates), 7)
self.assertEqual(len(network.get_updates_for(x4)), 2)
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_get_updates_bn(self):
@@ -125,8 +113,6 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
_ = layer(x1)
self.assertEqual(len(layer.updates), 2)
self.assertEqual(len(layer.get_updates_for(x1)), 2)
self.assertEqual(len(layer.get_updates_for(None)), 0)
def test_get_layer(self):
# create a simple network
@@ -1572,7 +1558,6 @@ class NestedNetworkTest(keras_parameterized.TestCase):
output_shape = network.compute_output_shape([(None, 1), (None, 1)])
self.assertListEqual(output_shape.as_list(), [None, 1])
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_updates_with_direct_call(self):
inputs = input_layer_lib.Input(shape=(10,))
x = layers.BatchNormalization()(inputs)
@@ -1582,8 +1567,7 @@ class NestedNetworkTest(keras_parameterized.TestCase):
ph = backend.placeholder(shape=(10, 10))
model(ph)
self.assertLen(model.get_updates_for(ph), 2)
self.assertLen(model.get_updates_for(None), 0)
self.assertLen(model.updates, 4)
def test_dict_mapping_input(self):

View File

@@ -403,8 +403,6 @@ class NormalizationLayersGraphModeOnlyTest(
model.train_on_batch(x, x)
self.assertLen(bn.updates, 4)
self.assertLen(bn.get_updates_for(x1), 2)
self.assertLen(model.get_updates_for(x2), 2)
# Test model-level reuse
x3 = keras.layers.Input(shape=(10,))
@@ -413,7 +411,6 @@ class NormalizationLayersGraphModeOnlyTest(
self.assertLen(new_model.updates, 6)
self.assertLen(model.updates, 6)
self.assertLen(new_model.get_updates_for(x3), 2)
new_model.compile(gradient_descent.GradientDescentOptimizer(0.01), 'mse')
new_model.train_on_batch(x, x)

View File

@@ -602,7 +602,7 @@ class RNNTest(keras_parameterized.TestCase):
self.assertEqual(layer.get_losses_for(None), [loss_2])
self.assertEqual(layer.get_losses_for(x), [loss_1])
# Test `get_updates_for` and `updates`
# Test `updates`
cells = [keras.layers.LSTMCell(1),
keras.layers.LSTMCell(1)]
layer = keras.layers.RNN(cells)
@@ -618,8 +618,6 @@ class RNNTest(keras_parameterized.TestCase):
cells[0].add_update(update_1, inputs=x)
cells[0].add_update(update_2)
self.assertEqual(len(layer.updates), 2)
self.assertEqual(len(layer.get_updates_for(None)), 1)
self.assertEqual(len(layer.get_updates_for(x)), 1)
def test_rnn_dynamic_trainability(self):
layer_class = keras.layers.SimpleRNN

View File

@@ -787,8 +787,6 @@ class BidirectionalTest(test.TestCase, parameterized.TestCase):
layer = keras.layers.Bidirectional(keras.layers.SimpleRNN(3))
_ = layer(x)
assert not layer.updates
assert not layer.get_updates_for(None)
assert not layer.get_updates_for(x)
# TODO(b/128684069): Remove when Wrapper sublayers are __call__'d.
with base_layer_utils.call_context().enter(layer, x, True, None):
layer.forward_layer.add_update(x_reachable_update, inputs=x)
@@ -796,8 +794,6 @@ class BidirectionalTest(test.TestCase, parameterized.TestCase):
layer.backward_layer.add_update(x_reachable_update, inputs=x)
layer.backward_layer.add_update(1, inputs=None)
assert len(layer.updates) == 4
assert len(layer.get_updates_for(None)) == 2
assert len(layer.get_updates_for(x)) == 2
def test_Bidirectional_losses(self):
x = keras.layers.Input(shape=(3, 2))

View File

@@ -124,7 +124,7 @@ class TestModelCloning(keras_parameterized.TestCase):
self.assertEqual(new_model._is_graph_network, model._is_graph_network)
if input_shape:
# update ops from batch norm needs to be included
self.assertEqual(len(new_model.get_updates_for(new_model.inputs)), 2)
self.assertGreaterEqual(len(new_model.updates), 2)
# On top of new tensor -- clone model should always have an InputLayer.
input_a = keras.Input(shape=(4,))
@@ -173,7 +173,7 @@ class TestModelCloning(keras_parameterized.TestCase):
# With placeholder creation
new_model = clone_fn(model)
self.assertEqual(len(new_model.get_updates_for(new_model.inputs)), 2)
self.assertGreaterEqual(len(new_model.updates), 2)
new_model.compile(
testing_utils.get_v2_optimizer('rmsprop'),
'mse',
@@ -185,7 +185,7 @@ class TestModelCloning(keras_parameterized.TestCase):
input_b = keras.Input(shape=(4,), name='b')
new_model = keras.models.clone_model(
model, input_tensors=[input_a, input_b])
self.assertEqual(len(new_model.get_updates_for(new_model.inputs)), 2)
self.assertLen(new_model.updates, 2)
new_model.compile(
testing_utils.get_v2_optimizer('rmsprop'),
'mse',
@@ -199,7 +199,7 @@ class TestModelCloning(keras_parameterized.TestCase):
input_a = keras.backend.variable(val_a)
input_b = keras.backend.variable(val_b)
new_model = clone_fn(model, input_tensors=[input_a, input_b])
self.assertEqual(len(new_model.get_updates_for(new_model.inputs)), 2)
self.assertGreaterEqual(len(new_model.updates), 2)
new_model.compile(
testing_utils.get_v2_optimizer('rmsprop'),
'mse',

View File

@@ -112,9 +112,9 @@ class VectorClassificationIntegrationTest(keras_parameterized.TestCase):
optimizer=keras.optimizer_v2.adam.Adam(0.005),
metrics=['acc'],
run_eagerly=testing_utils.should_run_eagerly())
if not testing_utils.should_run_eagerly():
self.assertEqual(len(model.get_losses_for(None)), 2)
self.assertEqual(len(model.get_updates_for(x)), 2)
self.assertLen(model.losses, 2)
if not context.executing_eagerly():
self.assertLen(model.get_updates_for(x), 2)
history = model.fit(x_train, y_train, epochs=10, batch_size=10,
validation_data=(x_train, y_train),
verbose=2)

View File

@@ -477,8 +477,6 @@ class ModelSubclassingTest(keras_parameterized.TestCase):
self.assertEqual(0, len(model.updates))
else:
self.assertEqual(2, len(model.updates))
self.assertEqual(1, len(model.get_updates_for(None)))
self.assertEqual(1, len(model.get_updates_for(x)))
class GraphSpecificModelSubclassingTests(test.TestCase):
@@ -536,8 +534,8 @@ class GraphSpecificModelSubclassingTests(test.TestCase):
x = array_ops.ones(shape=[100, 784], dtype='float32')
model(x)
self.assertEqual(len(model.get_updates_for(x)), 2)
self.assertEqual(len(model.get_losses_for(x)), 1)
self.assertLen(model.updates, 2)
self.assertLen(model.losses, 1)
# Case 2: placeholder-sequential nested in subclass.
class TestModel2(keras.Model):