Make it possible for a subclassed model to be compiled and trained with no external loss.

PiperOrigin-RevId: 248035588
Francois Chollet 2019-05-13 16:35:45 -07:00 committed by TensorFlower Gardener
parent 20359d70e4
commit 09ca187c09
4 changed files with 72 additions and 41 deletions
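For orientation, here is a minimal sketch of what this change enables, mirroring the `test_no_loss_in_compile` test added below (it uses the public `tf.keras` API rather than the test's internal imports; the model and names are illustrative):

    import numpy as np
    import tensorflow as tf

    class InternalLossModel(tf.keras.Model):
      """Subclassed model whose loss comes entirely from add_loss()."""

      def __init__(self):
        super(InternalLossModel, self).__init__()
        self.dense = tf.keras.layers.Dense(1)

      def call(self, inputs):
        out = self.dense(inputs)
        # The model supplies its own loss, so compile() needs no `loss` fn.
        self.add_loss(tf.reduce_sum(out))
        return out

    model = InternalLossModel()
    x = np.ones((10, 10))
    model.compile(optimizer='rmsprop')  # no external loss
    model.fit(x)       # no targets: the internal loss drives training
    model.evaluate(x)  # evaluation likewise needs no targets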

View File

@@ -592,6 +592,7 @@ class Model(network.Network):
       epochs = kwargs.pop('nb_epoch')
     if kwargs:
       raise TypeError('Unrecognized keyword arguments: ' + str(kwargs))
+    self._assert_compile_was_called()

     # Case 1: distribution strategy.
     if self._distribution_strategy:
@@ -862,6 +863,8 @@ class Model(network.Network):
         ValueError: in case of invalid arguments.
     """
     _keras_api_gauge.get_cell('evaluate').set(True)
+    self._assert_compile_was_called()
+
     # Case 1: distribution strategy.
     if self._distribution_strategy:
       if K.in_multi_worker_mode():
@@ -1133,6 +1136,7 @@ class Model(network.Network):
     Raises:
         ValueError: In case of invalid user-provided arguments.
     """
+    self._assert_compile_was_called()
     # If at this point we are in the replica context, then it is okay to execute
     # the Eager code path.  The expected way to get here is to call `fit` that
     # calls `train_on_batch` on each replica.
@@ -1213,6 +1217,7 @@ class Model(network.Network):
     Raises:
         ValueError: In case of invalid user-provided arguments.
     """
+    self._assert_compile_was_called()
     if (self._distribution_strategy and
         distribution_strategy_context.in_cross_replica_context()):
       raise NotImplementedError('`test_on_batch` is not supported for models '
@@ -2191,8 +2196,6 @@ class Model(network.Network):
       metrics_tensors = [
           self._all_metrics_tensors[m] for m in self.metrics_names[1:]
       ]
-      if not self._is_compiled:
-        raise RuntimeError('You must compile your model before using it.')
       self._check_trainable_weights_consistency()
       # If we have re-compiled the loss/weighted metric sub-graphs then create
       # train function even if one exists already. This is because
@@ -2229,8 +2232,6 @@ class Model(network.Network):
       metrics_tensors = [
           self._all_metrics_tensors[m] for m in self.metrics_names[1:]
       ]
-      if not self._is_compiled:
-        raise RuntimeError('You must compile your model before using it.')
       # If we have re-compiled the loss/weighted metric sub-graphs then create
       # test function even if one exists already. This is because
       # `_feed_sample_weights` list has been updated on re-copmpile.
@@ -2555,13 +2556,9 @@ class Model(network.Network):
       y_input = y
     dict_inputs = isinstance(self.inputs, dict)
-    if y_input is not None:
-      if not self.optimizer:
-        raise RuntimeError('You must compile a model before '
-                           'training/testing. '
-                           'Use `model.compile(optimizer, loss)`.')
-      if not self._is_compiled:
-        # On-the-fly compilation of the model.
+    if not self._is_compiled and self.optimizer:
+      # On-the-fly compilation of the model.
+      if y_input is not None:
         # We need to use `y` to set the model targets.
         if training_utils.has_tensors(y_input):
           y_input = training_utils.cast_if_floating_dtype(y_input)
@@ -2582,31 +2579,34 @@ class Model(network.Network):
                            'You passed: y=' + str(y))
         all_inputs.append(y_input)

       # Typecheck that all inputs are *either* value *or* symbolic.
       # TODO(fchollet): this check could be removed in Eager mode?
       if any(tensor_util.is_tensor(v) for v in all_inputs):
         if not all(tensor_util.is_tensor(v) for v in all_inputs):
           raise ValueError('Do not pass inputs that mix Numpy arrays and '
                            'TensorFlow tensors. '
                            'You passed: x=' + str(x) + '; y=' + str(y))

       if is_dataset or context.executing_eagerly():
         target_tensors = None
       else:
         # Handle target tensors if any passed.
-        if not isinstance(y_input, (list, tuple)):
-          y_input = [y_input]
-        target_tensors = [v for v in y_input if _is_symbolic_tensor(v)]
+        if y_input is not None:
+          if not isinstance(y_input, (list, tuple)):
+            y_input = [y_input]
+          target_tensors = [v for v in y_input if _is_symbolic_tensor(v)]
+        else:
+          target_tensors = None
       is_compile_called = True
       self.compile(
           optimizer=self.optimizer,
           loss=self.loss,
           metrics=self._compile_metrics,
           weighted_metrics=self._compile_weighted_metrics,
           loss_weights=self.loss_weights,
           target_tensors=target_tensors,
           run_eagerly=self.run_eagerly,
           cloning=self._cloning)

     # In graph mode, if we had just set inputs and targets as symbolic tensors
     # by invoking build and compile on the model respectively, we do not have to
@@ -2921,6 +2921,16 @@ class Model(network.Network):
         return int(self._ckpt_saved_epoch) + 1
     return initial_epoch

+  def _assert_compile_was_called(self):
+    # Checks whether `compile` has been called. If it has been called,
+    # then the optimizer is set. This is different from whether the
+    # model is compiled
+    # (i.e. whether the model is built and its inputs/outputs are set).
+    if not self.optimizer:
+      raise RuntimeError('You must compile your model before '
+                         'training/testing. '
+                         'Use `model.compile(optimizer, loss)`.')
+

 class DistributedCallbackModel(Model):
   """Model that is used for callbacks with tf.distribute.Strategy."""

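The net effect of the changes above: the scattered `_is_compiled` checks are replaced by one up-front `_assert_compile_was_called()` in `fit`, `evaluate`, `train_on_batch`, and `test_on_batch`, keyed on the presence of an optimizer rather than on `_is_compiled`. A model compiled with only an optimizer now passes, while a never-compiled model still fails fast. A sketch of the resulting behavior (illustrative, via the public API):

    import numpy as np
    import tensorflow as tf

    inputs = tf.keras.Input(shape=(1,))
    outputs = tf.keras.layers.Dense(1)(inputs)
    model = tf.keras.Model(inputs, outputs)

    # With no optimizer set, fit() now raises the single consolidated
    # error from _assert_compile_was_called() before doing any work.
    try:
      model.fit(np.ones((1, 1)), np.ones((1, 1)))
    except RuntimeError as e:
      print(e)  # You must compile your model before training/testing. ...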
View File

@@ -339,12 +339,14 @@ def test_on_batch(model,
   if len(inputs) and tensor_util.is_tensor(inputs[0]):
     inputs = training_utils.cast_if_floating_to_model_input_dtypes(inputs,
                                                                    model)
-    targets = training_utils.cast_if_floating_dtype(targets)
+    if targets:
+      targets = training_utils.cast_if_floating_dtype(targets)
   else:
     inputs = training_utils.cast_if_floating_to_model_input_dtypes(
         [ops.convert_to_tensor(val) for val in inputs], model)
-    targets = training_utils.cast_if_floating_dtype(
-        [ops.convert_to_tensor(val) for val in targets])
+    if targets:
+      targets = training_utils.cast_if_floating_dtype(
+          [ops.convert_to_tensor(val) for val in targets])
   if sample_weights:
     sample_weights = [
         training_utils.cast_if_floating_dtype(ops.convert_to_tensor(val))

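This guard exists because, once a model can be trained and evaluated with no external loss, `targets` can legitimately arrive empty (or `None`) in the eager `test_on_batch` path, and the cast helpers expect actual arrays or tensors. With hypothetical values, the guard behaves as follows:

    # Hypothetical values for the guarded path above:
    targets = []  # compiled with no external loss -> nothing to cast
    if targets:   # an empty list (or None) is falsy, so the cast is skipped
      targets = training_utils.cast_if_floating_dtype(targets)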
View File

@@ -1201,8 +1201,7 @@ class TrainingTest(keras_parameterized.TestCase):
     inputs = keras.Input(shape=(1,))
     outputs = keras.layers.Dense(1)(inputs)
     model = keras.Model(inputs, outputs)
-    with self.assertRaisesRegex(
-        RuntimeError, 'You must compile a model before training/testing.'):
+    with self.assertRaisesRegex(RuntimeError, 'must compile your model'):
       model.fit(np.ones((1, 1)), np.ones((1, 1)))

     class MyModel(keras.Model):
@@ -1212,10 +1211,7 @@ class TrainingTest(keras_parameterized.TestCase):
       def call(self, x):
         return x

     model = MyModel()
-    model.compile(keras.optimizers.Adam(1e-3))
-    with self.assertRaisesRegex(RuntimeError,
-                                'You must compile your model before using it.'):
+    with self.assertRaisesRegex(RuntimeError, 'must compile your model'):
       model.fit(np.random.random((32, 1)), epochs=2)

View File

@@ -31,6 +31,7 @@ from tensorflow.python.framework import test_util
 from tensorflow.python.keras import keras_parameterized
 from tensorflow.python.keras import testing_utils
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import embedding_ops
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import resource_variable_ops
@@ -992,6 +993,28 @@ class ModelSubclassCompiledTest(keras_parameterized.TestCase):
       loss = model.train_on_batch(x, y)
       self.assertGreater(loss, 0.1)

+  def test_no_loss_in_compile(self):
+
+    class InternalLossModel(keras.Model):
+
+      def __init__(self):
+        super(InternalLossModel, self).__init__()
+        self.dense = keras.layers.Dense(1)
+
+      def call(self, inputs):
+        out = self.dense(inputs)
+        self.add_loss(math_ops.reduce_sum(out))
+        return out
+
+    model = InternalLossModel()
+    x = np.ones((10, 10))
+    model.predict(x)
+    model.compile(
+        optimizer='rmsprop',
+        run_eagerly=testing_utils.should_run_eagerly())
+    model.fit(x)
+    model.evaluate(x)
+

 class GraphSpecificModelSubclassingTests(test.TestCase):