Switch Model._set_inputs to compute outputs via the full __call__ instead of invoking call() / _symbolic_call() directly.

PiperOrigin-RevId: 238688012
This commit is contained in:
Thomas O'Malley 2019-03-15 12:18:44 -07:00 committed by TensorFlower Gardener
parent d1a9e46b4a
commit b57c7d71ef
6 changed files with 72 additions and 54 deletions

View File

@ -1164,6 +1164,7 @@ class TestTensorBoardV2(keras_parameterized.TestCase):
model = self._get_model()
x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
tb_cbk = keras.callbacks.TensorBoard(self.logdir, histogram_freq=1)
model_type = testing_utils.get_model_type()
model.fit(
x,
@ -1182,7 +1183,7 @@ class TestTensorBoardV2(keras_parameterized.TestCase):
},
)
self.assertEqual(
self._strip_layer_names(summary_file.histograms),
self._strip_layer_names(summary_file.histograms, model_type),
{
_ObservedSummary(logdir=self.train_dir, tag='bias_0'),
_ObservedSummary(logdir=self.train_dir, tag='kernel_0'),
@ -1194,6 +1195,7 @@ class TestTensorBoardV2(keras_parameterized.TestCase):
x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
tb_cbk = keras.callbacks.TensorBoard(
self.logdir, histogram_freq=1, write_images=True)
model_type = testing_utils.get_model_type()
model.fit(
x,
@ -1212,14 +1214,14 @@ class TestTensorBoardV2(keras_parameterized.TestCase):
},
)
self.assertEqual(
self._strip_layer_names(summary_file.histograms),
self._strip_layer_names(summary_file.histograms, model_type),
{
_ObservedSummary(logdir=self.train_dir, tag='bias_0'),
_ObservedSummary(logdir=self.train_dir, tag='kernel_0'),
},
)
self.assertEqual(
self._strip_layer_names(summary_file.images),
self._strip_layer_names(summary_file.images, model_type),
{
_ObservedSummary(logdir=self.train_dir, tag='bias_0/image/0'),
_ObservedSummary(logdir=self.train_dir, tag='kernel_0/image/0'),
@ -1228,7 +1230,7 @@ class TestTensorBoardV2(keras_parameterized.TestCase):
},
)
def _strip_layer_names(self, summaries):
def _strip_layer_names(self, summaries, model_type):
"""Deduplicate summary names modulo layer prefix.
This removes the first slash-component of each tag name: for
@ -1236,6 +1238,7 @@ class TestTensorBoardV2(keras_parameterized.TestCase):
Args:
summaries: A `set` of `_ObservedSummary` values.
model_type: The model type currently being tested.
Returns:
A new `set` of `_ObservedSummary` values with layer prefixes
@ -1245,7 +1248,8 @@ class TestTensorBoardV2(keras_parameterized.TestCase):
for summary in summaries:
if '/' not in summary.tag:
raise ValueError('tag has no layer name: %r' % summary.tag)
new_tag = summary.tag.split('/', 1)[1]
start_from = 2 if 'subclass' in model_type else 1
new_tag = '/'.join(summary.tag.split('/')[start_from:])
result.add(summary._replace(tag=new_tag))
return result

View File

@ -714,9 +714,7 @@ class Layer(trackable.Trackable):
@property
def updates(self):
if not self.trainable and not self.stateful:
return []
return self._updates + self._gather_children_attribute('updates')
return self._get_unfiltered_updates(check_trainable=True)
@property
def losses(self):
@ -967,13 +965,15 @@ class Layer(trackable.Trackable):
if inputs is None:
# Requesting unconditional updates.
return [x for x in self._unfiltered_updates if x._unconditional_update] # pylint: disable=protected-access
return [
x for x in self._get_unfiltered_updates() if x._unconditional_update # pylint: disable=protected-access
]
# Requesting input-conditional updates.
inputs = nest.flatten(inputs)
reachable = tf_utils.get_reachable_from_inputs(inputs,
self._unfiltered_updates)
return [u for u in self._unfiltered_updates if u in reachable] # pylint: disable=protected-access
reachable = tf_utils.get_reachable_from_inputs(
inputs, self._get_unfiltered_updates())
return [u for u in self._get_unfiltered_updates() if u in reachable] # pylint: disable=protected-access
def get_losses_for(self, inputs):
"""Retrieves losses relevant to a specific set of inputs.
@ -1847,10 +1847,10 @@ class Layer(trackable.Trackable):
def _is_layer(self):
return True
@property
def _unfiltered_updates(self):
# Overridden in `Network`.
return self.updates
def _get_unfiltered_updates(self, check_trainable=True):
if check_trainable and not self.trainable and not self.stateful:
return []
return self._updates + self._gather_children_attribute('updates')
class Node(object):

View File

@ -482,8 +482,10 @@ class AutoAddUpdates(object):
if is_stateful_op and op.type != 'ReadVariableOp':
new_stateful_ops.add(op)
explicit_updates = set(
[u for u in self.layer._unfiltered_updates if not isinstance(u, tuple)])
explicit_updates = set([
u for u in self.layer._get_unfiltered_updates(check_trainable=False)
if not isinstance(u, tuple)
])
# pylint: enable=protected-access
# Don't add updates that will already be run by virtue of being consumed by
@ -542,4 +544,3 @@ def autocast_context_manager(input_list, should_cast):
var_read_dtype = _get_var_read_dtype(input_list, should_cast)
return ops.get_default_graph()._enable_auto_casting_variables( # pylint: disable=protected-access
var_read_dtype)

View File

@ -521,11 +521,12 @@ class Network(base_layer.Layer):
return layer
raise ValueError('No such layer: ' + name)
@property
def _unfiltered_updates(self):
def _get_unfiltered_updates(self, check_trainable=True):
if check_trainable and not self.trainable and not self.stateful:
return []
updates = []
for layer in self.layers:
updates += layer._unfiltered_updates
updates += layer._get_unfiltered_updates(check_trainable=check_trainable)
updates += list(self._updates)
return updates
@ -605,10 +606,8 @@ class Network(base_layer.Layer):
Returns:
A list of update ops.
"""
if not self.trainable and not self.stateful:
return []
updates = self._unfiltered_updates
updates = self._get_unfiltered_updates(check_trainable=True)
# `updates` might contain irrelevant updates, so it needs to be filtered
# with respect to inputs the model has been called on.

View File

@ -2612,7 +2612,7 @@ class Model(network.Network):
'However we received `validation_data=%s`' % validation_data)
return val_x, val_y, val_sample_weight
@trackable.no_automatic_dependency_tracking
# TODO(omalleyt): Consider changing to a more descriptive function name.
def _set_inputs(self, inputs, outputs=None, training=None):
"""Set model's input and output specs based on the input data received.
@ -2639,6 +2639,22 @@ class Model(network.Network):
ValueError: If dict inputs are passed to a Sequential Model where the
first layer isn't FeatureLayer.
"""
inputs = self._set_input_attrs(inputs)
if outputs is None:
kwargs = {'training': training} if self._expects_training_arg else {}
try:
outputs = self(inputs, **kwargs)
except NotImplementedError:
# This Model or a submodel is dynamic and hasn't overridden
# `compute_output_shape`.
outputs = None
self._set_output_attrs(outputs)
@trackable.no_automatic_dependency_tracking
def _set_input_attrs(self, inputs):
"""Sets attributes related to the inputs of the Model."""
if self.inputs:
raise ValueError('Model inputs are already set.')
@ -2675,33 +2691,11 @@ class Model(network.Network):
self._feed_inputs.append(v)
self._feed_input_shapes.append(K.int_shape(v))
# TODO(fchollet): consider calling `_maybe_build` before calling the model.
if outputs is None:
if not self._dynamic:
# The network may include dynamic layers but its `call`
# itself isn't dynamic.
# Obtain symbolic outputs by calling the model.
with K.get_graph().as_default():
contains_symbolic_tensors = getattr(
self, '_contains_symbolic_tensors', False)
if self._expects_training_arg:
outputs = self.call(inputs, training=training)
else:
outputs = self.call(inputs)
# Reset to the previously saved value. If `call()` had `add_metric`
# or `add_loss`, then `_contains_symbolic_tensors` will have been set
# to True since we are not in `__call__` context. Hence we are
# resetting to the old value here.
self._contains_symbolic_tensors = contains_symbolic_tensors
else:
# Case: network's `call` is dynamic.
try:
outputs = self._symbolic_call(inputs)
except NotImplementedError:
# Static shape inference was not implemented for this dynamic net.
# Do not specify symbolic outputs.
outputs = None
return inputs
@trackable.no_automatic_dependency_tracking
def _set_output_attrs(self, outputs):
"""Sets attributes related to the outputs of the Model."""
outputs = nest.flatten(outputs)
self.outputs = outputs
self.output_names = training_utils.generic_output_names(outputs)

View File

@ -187,8 +187,8 @@ def get_nested_model_3(input_dim, num_classes):
return keras.Model(inputs, outputs, name='nested_model_3')
@test_util.run_all_in_graph_and_eager_modes
class ModelSubclassingTest(test.TestCase):
@keras_parameterized.run_all_keras_modes
class ModelSubclassingTest(keras_parameterized.TestCase):
def test_custom_build(self):
class DummyModel(keras.Model):
@ -210,6 +210,26 @@ class ModelSubclassingTest(test.TestCase):
self.assertTrue(test_model.uses_custom_build, 'Model should use user '
'defined build when called.')
def test_custom_build_with_fit(self):
class DummyModel(keras.Model):
def __init__(self):
super(DummyModel, self).__init__()
self.layer1 = keras.layers.Dense(10, activation='relu')
def build(self, input_shape):
self.layer2 = keras.layers.Dense(1, activation='relu')
def call(self, inputs):
return self.layer2(self.layer1(inputs))
model = DummyModel()
model.compile('sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly())
model.fit(np.ones((10, 10)), np.ones((10, 1)), batch_size=2, epochs=2)
self.assertLen(model.layers, 2)
self.assertLen(model.trainable_variables, 4)
def test_invalid_input_shape_build(self):
num_classes = 2
input_dim = 50