Switch _set_inputs
to use full __call__
.
PiperOrigin-RevId: 238688012
This commit is contained in:
parent
d1a9e46b4a
commit
b57c7d71ef
@ -1164,6 +1164,7 @@ class TestTensorBoardV2(keras_parameterized.TestCase):
|
||||
model = self._get_model()
|
||||
x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
|
||||
tb_cbk = keras.callbacks.TensorBoard(self.logdir, histogram_freq=1)
|
||||
model_type = testing_utils.get_model_type()
|
||||
|
||||
model.fit(
|
||||
x,
|
||||
@ -1182,7 +1183,7 @@ class TestTensorBoardV2(keras_parameterized.TestCase):
|
||||
},
|
||||
)
|
||||
self.assertEqual(
|
||||
self._strip_layer_names(summary_file.histograms),
|
||||
self._strip_layer_names(summary_file.histograms, model_type),
|
||||
{
|
||||
_ObservedSummary(logdir=self.train_dir, tag='bias_0'),
|
||||
_ObservedSummary(logdir=self.train_dir, tag='kernel_0'),
|
||||
@ -1194,6 +1195,7 @@ class TestTensorBoardV2(keras_parameterized.TestCase):
|
||||
x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
|
||||
tb_cbk = keras.callbacks.TensorBoard(
|
||||
self.logdir, histogram_freq=1, write_images=True)
|
||||
model_type = testing_utils.get_model_type()
|
||||
|
||||
model.fit(
|
||||
x,
|
||||
@ -1212,14 +1214,14 @@ class TestTensorBoardV2(keras_parameterized.TestCase):
|
||||
},
|
||||
)
|
||||
self.assertEqual(
|
||||
self._strip_layer_names(summary_file.histograms),
|
||||
self._strip_layer_names(summary_file.histograms, model_type),
|
||||
{
|
||||
_ObservedSummary(logdir=self.train_dir, tag='bias_0'),
|
||||
_ObservedSummary(logdir=self.train_dir, tag='kernel_0'),
|
||||
},
|
||||
)
|
||||
self.assertEqual(
|
||||
self._strip_layer_names(summary_file.images),
|
||||
self._strip_layer_names(summary_file.images, model_type),
|
||||
{
|
||||
_ObservedSummary(logdir=self.train_dir, tag='bias_0/image/0'),
|
||||
_ObservedSummary(logdir=self.train_dir, tag='kernel_0/image/0'),
|
||||
@ -1228,7 +1230,7 @@ class TestTensorBoardV2(keras_parameterized.TestCase):
|
||||
},
|
||||
)
|
||||
|
||||
def _strip_layer_names(self, summaries):
|
||||
def _strip_layer_names(self, summaries, model_type):
|
||||
"""Deduplicate summary names modulo layer prefix.
|
||||
|
||||
This removes the first slash-component of each tag name: for
|
||||
@ -1236,6 +1238,7 @@ class TestTensorBoardV2(keras_parameterized.TestCase):
|
||||
|
||||
Args:
|
||||
summaries: A `set` of `_ObservedSummary` values.
|
||||
model_type: The model type currently being tested.
|
||||
|
||||
Returns:
|
||||
A new `set` of `_ObservedSummary` values with layer prefixes
|
||||
@ -1245,7 +1248,8 @@ class TestTensorBoardV2(keras_parameterized.TestCase):
|
||||
for summary in summaries:
|
||||
if '/' not in summary.tag:
|
||||
raise ValueError('tag has no layer name: %r' % summary.tag)
|
||||
new_tag = summary.tag.split('/', 1)[1]
|
||||
start_from = 2 if 'subclass' in model_type else 1
|
||||
new_tag = '/'.join(summary.tag.split('/')[start_from:])
|
||||
result.add(summary._replace(tag=new_tag))
|
||||
return result
|
||||
|
||||
|
@ -714,9 +714,7 @@ class Layer(trackable.Trackable):
|
||||
|
||||
@property
|
||||
def updates(self):
|
||||
if not self.trainable and not self.stateful:
|
||||
return []
|
||||
return self._updates + self._gather_children_attribute('updates')
|
||||
return self._get_unfiltered_updates(check_trainable=True)
|
||||
|
||||
@property
|
||||
def losses(self):
|
||||
@ -967,13 +965,15 @@ class Layer(trackable.Trackable):
|
||||
|
||||
if inputs is None:
|
||||
# Requesting unconditional updates.
|
||||
return [x for x in self._unfiltered_updates if x._unconditional_update] # pylint: disable=protected-access
|
||||
return [
|
||||
x for x in self._get_unfiltered_updates() if x._unconditional_update # pylint: disable=protected-access
|
||||
]
|
||||
|
||||
# Requesting input-conditional updates.
|
||||
inputs = nest.flatten(inputs)
|
||||
reachable = tf_utils.get_reachable_from_inputs(inputs,
|
||||
self._unfiltered_updates)
|
||||
return [u for u in self._unfiltered_updates if u in reachable] # pylint: disable=protected-access
|
||||
reachable = tf_utils.get_reachable_from_inputs(
|
||||
inputs, self._get_unfiltered_updates())
|
||||
return [u for u in self._get_unfiltered_updates() if u in reachable] # pylint: disable=protected-access
|
||||
|
||||
def get_losses_for(self, inputs):
|
||||
"""Retrieves losses relevant to a specific set of inputs.
|
||||
@ -1847,10 +1847,10 @@ class Layer(trackable.Trackable):
|
||||
def _is_layer(self):
|
||||
return True
|
||||
|
||||
@property
|
||||
def _unfiltered_updates(self):
|
||||
# Overridden in `Network`.
|
||||
return self.updates
|
||||
def _get_unfiltered_updates(self, check_trainable=True):
|
||||
if check_trainable and not self.trainable and not self.stateful:
|
||||
return []
|
||||
return self._updates + self._gather_children_attribute('updates')
|
||||
|
||||
|
||||
class Node(object):
|
||||
|
@ -482,8 +482,10 @@ class AutoAddUpdates(object):
|
||||
if is_stateful_op and op.type != 'ReadVariableOp':
|
||||
new_stateful_ops.add(op)
|
||||
|
||||
explicit_updates = set(
|
||||
[u for u in self.layer._unfiltered_updates if not isinstance(u, tuple)])
|
||||
explicit_updates = set([
|
||||
u for u in self.layer._get_unfiltered_updates(check_trainable=False)
|
||||
if not isinstance(u, tuple)
|
||||
])
|
||||
# pylint: enable=protected-access
|
||||
|
||||
# Don't add updates that will already be run by virtue of being consumed by
|
||||
@ -542,4 +544,3 @@ def autocast_context_manager(input_list, should_cast):
|
||||
var_read_dtype = _get_var_read_dtype(input_list, should_cast)
|
||||
return ops.get_default_graph()._enable_auto_casting_variables( # pylint: disable=protected-access
|
||||
var_read_dtype)
|
||||
|
||||
|
@ -521,11 +521,12 @@ class Network(base_layer.Layer):
|
||||
return layer
|
||||
raise ValueError('No such layer: ' + name)
|
||||
|
||||
@property
|
||||
def _unfiltered_updates(self):
|
||||
def _get_unfiltered_updates(self, check_trainable=True):
|
||||
if check_trainable and not self.trainable and not self.stateful:
|
||||
return []
|
||||
updates = []
|
||||
for layer in self.layers:
|
||||
updates += layer._unfiltered_updates
|
||||
updates += layer._get_unfiltered_updates(check_trainable=check_trainable)
|
||||
updates += list(self._updates)
|
||||
return updates
|
||||
|
||||
@ -605,10 +606,8 @@ class Network(base_layer.Layer):
|
||||
Returns:
|
||||
A list of update ops.
|
||||
"""
|
||||
if not self.trainable and not self.stateful:
|
||||
return []
|
||||
|
||||
updates = self._unfiltered_updates
|
||||
updates = self._get_unfiltered_updates(check_trainable=True)
|
||||
|
||||
# `updates` might contain irrelevant updates, so it needs to be filtered
|
||||
# with respect to inputs the model has been called on.
|
||||
|
@ -2612,7 +2612,7 @@ class Model(network.Network):
|
||||
'However we received `validation_data=%s`' % validation_data)
|
||||
return val_x, val_y, val_sample_weight
|
||||
|
||||
@trackable.no_automatic_dependency_tracking
|
||||
# TODO(omalleyt): Consider changing to a more descriptive function name.
|
||||
def _set_inputs(self, inputs, outputs=None, training=None):
|
||||
"""Set model's input and output specs based on the input data received.
|
||||
|
||||
@ -2639,6 +2639,22 @@ class Model(network.Network):
|
||||
ValueError: If dict inputs are passed to a Sequential Model where the
|
||||
first layer isn't FeatureLayer.
|
||||
"""
|
||||
inputs = self._set_input_attrs(inputs)
|
||||
|
||||
if outputs is None:
|
||||
kwargs = {'training': training} if self._expects_training_arg else {}
|
||||
try:
|
||||
outputs = self(inputs, **kwargs)
|
||||
except NotImplementedError:
|
||||
# This Model or a submodel is dynamic and hasn't overridden
|
||||
# `compute_output_shape`.
|
||||
outputs = None
|
||||
|
||||
self._set_output_attrs(outputs)
|
||||
|
||||
@trackable.no_automatic_dependency_tracking
|
||||
def _set_input_attrs(self, inputs):
|
||||
"""Sets attributes related to the inputs of the Model."""
|
||||
if self.inputs:
|
||||
raise ValueError('Model inputs are already set.')
|
||||
|
||||
@ -2675,33 +2691,11 @@ class Model(network.Network):
|
||||
self._feed_inputs.append(v)
|
||||
self._feed_input_shapes.append(K.int_shape(v))
|
||||
|
||||
# TODO(fchollet): consider calling `_maybe_build` before calling the model.
|
||||
if outputs is None:
|
||||
if not self._dynamic:
|
||||
# The network may include dynamic layers but its `call`
|
||||
# itself isn't dynamic.
|
||||
# Obtain symbolic outputs by calling the model.
|
||||
with K.get_graph().as_default():
|
||||
contains_symbolic_tensors = getattr(
|
||||
self, '_contains_symbolic_tensors', False)
|
||||
if self._expects_training_arg:
|
||||
outputs = self.call(inputs, training=training)
|
||||
else:
|
||||
outputs = self.call(inputs)
|
||||
# Reset to the previously saved value. If `call()` had `add_metric`
|
||||
# or `add_loss`, then `_contains_symbolic_tensors` will have been set
|
||||
# to True since we are not in `__call__` context. Hence we are
|
||||
# resetting to the old value here.
|
||||
self._contains_symbolic_tensors = contains_symbolic_tensors
|
||||
else:
|
||||
# Case: network's `call` is dynamic.
|
||||
try:
|
||||
outputs = self._symbolic_call(inputs)
|
||||
except NotImplementedError:
|
||||
# Static shape inference was not implemented for this dynamic net.
|
||||
# Do not specify symbolic outputs.
|
||||
outputs = None
|
||||
return inputs
|
||||
|
||||
@trackable.no_automatic_dependency_tracking
|
||||
def _set_output_attrs(self, outputs):
|
||||
"""Sets attributes related to the outputs of the Model."""
|
||||
outputs = nest.flatten(outputs)
|
||||
self.outputs = outputs
|
||||
self.output_names = training_utils.generic_output_names(outputs)
|
||||
|
@ -187,8 +187,8 @@ def get_nested_model_3(input_dim, num_classes):
|
||||
return keras.Model(inputs, outputs, name='nested_model_3')
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class ModelSubclassingTest(test.TestCase):
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
class ModelSubclassingTest(keras_parameterized.TestCase):
|
||||
|
||||
def test_custom_build(self):
|
||||
class DummyModel(keras.Model):
|
||||
@ -210,6 +210,26 @@ class ModelSubclassingTest(test.TestCase):
|
||||
self.assertTrue(test_model.uses_custom_build, 'Model should use user '
|
||||
'defined build when called.')
|
||||
|
||||
def test_custom_build_with_fit(self):
|
||||
|
||||
class DummyModel(keras.Model):
|
||||
|
||||
def __init__(self):
|
||||
super(DummyModel, self).__init__()
|
||||
self.layer1 = keras.layers.Dense(10, activation='relu')
|
||||
|
||||
def build(self, input_shape):
|
||||
self.layer2 = keras.layers.Dense(1, activation='relu')
|
||||
|
||||
def call(self, inputs):
|
||||
return self.layer2(self.layer1(inputs))
|
||||
|
||||
model = DummyModel()
|
||||
model.compile('sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly())
|
||||
model.fit(np.ones((10, 10)), np.ones((10, 1)), batch_size=2, epochs=2)
|
||||
self.assertLen(model.layers, 2)
|
||||
self.assertLen(model.trainable_variables, 4)
|
||||
|
||||
def test_invalid_input_shape_build(self):
|
||||
num_classes = 2
|
||||
input_dim = 50
|
||||
|
Loading…
x
Reference in New Issue
Block a user