Make it possible for a subclassed model to be compiled and trained with no external loss.

PiperOrigin-RevId: 248035588
Francois Chollet 2019-05-13 16:35:45 -07:00 committed by TensorFlower Gardener
parent 20359d70e4
commit 09ca187c09
4 changed files with 72 additions and 41 deletions
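In practice, the change lets a subclassed model whose entire loss is registered through `add_loss` be compiled with only an optimizer and then fit without targets. A minimal sketch of the enabled workflow against the public `tf.keras` API (the model mirrors `InternalLossModel` from the new test in this commit; the standalone framing is my own, not part of the diff):

import numpy as np
import tensorflow as tf

class InternalLossModel(tf.keras.Model):
  """Subclassed model whose only loss comes from add_loss()."""

  def __init__(self):
    super(InternalLossModel, self).__init__()
    self.dense = tf.keras.layers.Dense(1)

  def call(self, inputs):
    out = self.dense(inputs)
    # The training signal is created inside the model, not passed to compile().
    self.add_loss(tf.reduce_sum(out))
    return out

model = InternalLossModel()
x = np.ones((10, 10))
model.compile(optimizer='rmsprop')  # no `loss` argument
model.fit(x)                        # no targets `y`
model.evaluate(x)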


@@ -592,6 +592,7 @@ class Model(network.Network):
       epochs = kwargs.pop('nb_epoch')
     if kwargs:
       raise TypeError('Unrecognized keyword arguments: ' + str(kwargs))
+    self._assert_compile_was_called()
     # Case 1: distribution strategy.
     if self._distribution_strategy:
@@ -862,6 +863,8 @@ class Model(network.Network):
         ValueError: in case of invalid arguments.
     """
     _keras_api_gauge.get_cell('evaluate').set(True)
+    self._assert_compile_was_called()
     # Case 1: distribution strategy.
     if self._distribution_strategy:
       if K.in_multi_worker_mode():
@@ -1133,6 +1136,7 @@ class Model(network.Network):
     Raises:
       ValueError: In case of invalid user-provided arguments.
     """
+    self._assert_compile_was_called()
     # If at this point we are in the replica context, then it is okay to execute
     # the Eager code path. The expected way to get here is to call `fit` that
     # calls `train_on_batch` on each replica.
@@ -1213,6 +1217,7 @@ class Model(network.Network):
     Raises:
       ValueError: In case of invalid user-provided arguments.
     """
+    self._assert_compile_was_called()
     if (self._distribution_strategy and
         distribution_strategy_context.in_cross_replica_context()):
       raise NotImplementedError('`test_on_batch` is not supported for models '
@@ -2191,8 +2196,6 @@ class Model(network.Network):
       metrics_tensors = [
           self._all_metrics_tensors[m] for m in self.metrics_names[1:]
       ]
-      if not self._is_compiled:
-        raise RuntimeError('You must compile your model before using it.')
       self._check_trainable_weights_consistency()
       # If we have re-compiled the loss/weighted metric sub-graphs then create
       # train function even if one exists already. This is because
@@ -2229,8 +2232,6 @@ class Model(network.Network):
      metrics_tensors = [
          self._all_metrics_tensors[m] for m in self.metrics_names[1:]
      ]
-      if not self._is_compiled:
-        raise RuntimeError('You must compile your model before using it.')
      # If we have re-compiled the loss/weighted metric sub-graphs then create
      # test function even if one exists already. This is because
      # `_feed_sample_weights` list has been updated on re-compile.
@@ -2555,13 +2556,9 @@ class Model(network.Network):
         y_input = y
     dict_inputs = isinstance(self.inputs, dict)
-    if y_input is not None:
-      if not self.optimizer:
-        raise RuntimeError('You must compile a model before '
-                           'training/testing. '
-                           'Use `model.compile(optimizer, loss)`.')
-      if not self._is_compiled:
-        # On-the-fly compilation of the model.
+    if not self._is_compiled and self.optimizer:
+      # On-the-fly compilation of the model.
+      if y_input is not None:
         # We need to use `y` to set the model targets.
         if training_utils.has_tensors(y_input):
           y_input = training_utils.cast_if_floating_dtype(y_input)
@@ -2582,31 +2579,34 @@ class Model(network.Network):
                            'You passed: y=' + str(y))
         all_inputs.append(y_input)
 
-        # Typecheck that all inputs are *either* value *or* symbolic.
-        # TODO(fchollet): this check could be removed in Eager mode?
-        if any(tensor_util.is_tensor(v) for v in all_inputs):
-          if not all(tensor_util.is_tensor(v) for v in all_inputs):
-            raise ValueError('Do not pass inputs that mix Numpy arrays and '
-                             'TensorFlow tensors. '
-                             'You passed: x=' + str(x) + '; y=' + str(y))
+      # Typecheck that all inputs are *either* value *or* symbolic.
+      # TODO(fchollet): this check could be removed in Eager mode?
+      if any(tensor_util.is_tensor(v) for v in all_inputs):
+        if not all(tensor_util.is_tensor(v) for v in all_inputs):
+          raise ValueError('Do not pass inputs that mix Numpy arrays and '
+                           'TensorFlow tensors. '
+                           'You passed: x=' + str(x) + '; y=' + str(y))
 
-        if is_dataset or context.executing_eagerly():
-          target_tensors = None
-        else:
-          # Handle target tensors if any passed.
+      if is_dataset or context.executing_eagerly():
+        target_tensors = None
+      else:
+        # Handle target tensors if any passed.
+        if y_input is not None:
           if not isinstance(y_input, (list, tuple)):
             y_input = [y_input]
           target_tensors = [v for v in y_input if _is_symbolic_tensor(v)]
-        is_compile_called = True
-        self.compile(
-            optimizer=self.optimizer,
-            loss=self.loss,
-            metrics=self._compile_metrics,
-            weighted_metrics=self._compile_weighted_metrics,
-            loss_weights=self.loss_weights,
-            target_tensors=target_tensors,
-            run_eagerly=self.run_eagerly,
-            cloning=self._cloning)
+        else:
+          target_tensors = None
+      is_compile_called = True
+      self.compile(
+          optimizer=self.optimizer,
+          loss=self.loss,
+          metrics=self._compile_metrics,
+          weighted_metrics=self._compile_weighted_metrics,
+          loss_weights=self.loss_weights,
+          target_tensors=target_tensors,
+          run_eagerly=self.run_eagerly,
+          cloning=self._cloning)
 
     # In graph mode, if we had just set inputs and targets as symbolic tensors
     # by invoking build and compile on the model respectively, we do not have to
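Taken together, the two hunks above separate the user-facing check ("was `compile` called?", i.e. is an optimizer set) from graph-level compilation (`_is_compiled`): `_standardize_user_data` now attempts an on-the-fly `compile` whenever an optimizer exists but the model graph is not yet compiled, and simply passes `target_tensors=None` when no `y` was supplied. A hedged walk-through of the resulting sequence, using the attribute names as they appear in the diff:

model = InternalLossModel()         # subclassed; inputs not yet known
model.compile(optimizer='rmsprop')  # sets model.optimizer; model._is_compiled is still False
model.fit(np.ones((10, 10)))        # 1) _assert_compile_was_called() passes (optimizer is set)
                                    # 2) _standardize_user_data() re-invokes compile() on the
                                    #    fly with target_tensors=None, since no y was passed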
@@ -2921,6 +2921,16 @@ class Model(network.Network):
       return int(self._ckpt_saved_epoch) + 1
     return initial_epoch
 
+  def _assert_compile_was_called(self):
+    # Checks whether `compile` has been called. If it has been called,
+    # then the optimizer is set. This is different from whether the
+    # model is compiled (i.e. whether the model is built and its
+    # inputs/outputs are set).
+    if not self.optimizer:
+      raise RuntimeError('You must compile your model before '
+                         'training/testing. '
+                         'Use `model.compile(optimizer, loss)`.')
+
 
 class DistributedCallbackModel(Model):
   """Model that is used for callbacks with tf.distribute.Strategy."""


@@ -339,12 +339,14 @@ def test_on_batch(model,
   if len(inputs) and tensor_util.is_tensor(inputs[0]):
     inputs = training_utils.cast_if_floating_to_model_input_dtypes(inputs,
                                                                    model)
-    targets = training_utils.cast_if_floating_dtype(targets)
+    if targets:
+      targets = training_utils.cast_if_floating_dtype(targets)
   else:
     inputs = training_utils.cast_if_floating_to_model_input_dtypes(
         [ops.convert_to_tensor(val) for val in inputs], model)
-    targets = training_utils.cast_if_floating_dtype(
-        [ops.convert_to_tensor(val) for val in targets])
+    if targets:
+      targets = training_utils.cast_if_floating_dtype(
+          [ops.convert_to_tensor(val) for val in targets])
   if sample_weights:
     sample_weights = [
         training_utils.cast_if_floating_dtype(ops.convert_to_tensor(val))
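This guard makes `targets` optional in the eager batch-evaluation path, so a model that relies only on `add_loss` can presumably be scored batch-by-batch without `y`. A small sketch under that assumption, reusing the `InternalLossModel` defined in the first example above:

model = InternalLossModel()
model.compile(optimizer='rmsprop', run_eagerly=True)
loss = model.test_on_batch(np.ones((10, 10)))  # no y: targets is empty, the cast is skipped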


@@ -1201,8 +1201,7 @@ class TrainingTest(keras_parameterized.TestCase):
     inputs = keras.Input(shape=(1,))
     outputs = keras.layers.Dense(1)(inputs)
     model = keras.Model(inputs, outputs)
-    with self.assertRaisesRegex(
-        RuntimeError, 'You must compile a model before training/testing.'):
+    with self.assertRaisesRegex(RuntimeError, 'must compile your model'):
       model.fit(np.ones((1, 1)), np.ones((1, 1)))
 
     class MyModel(keras.Model):
@@ -1212,10 +1211,7 @@ class TrainingTest(keras_parameterized.TestCase):
         return x
 
     model = MyModel()
-    model.compile(keras.optimizers.Adam(1e-3))
-    with self.assertRaisesRegex(RuntimeError,
-                                'You must compile your model before using it.'):
+    with self.assertRaisesRegex(RuntimeError, 'must compile your model'):
       model.fit(np.random.random((32, 1)), epochs=2)
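Both test updates pin the same substring, 'must compile your model', so the functional and subclassed paths now fail uniformly. For reference, a minimal standalone repro of the new message (a sketch using the public API, not part of this commit):

import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
try:
  model.fit(np.ones((1, 1)), np.ones((1, 1)))  # fit() before compile()
except RuntimeError as e:
  print(e)  # 'You must compile your model before training/testing. ...'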


@@ -31,6 +31,7 @@ from tensorflow.python.framework import test_util
 from tensorflow.python.keras import keras_parameterized
 from tensorflow.python.keras import testing_utils
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import embedding_ops
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import resource_variable_ops
@@ -992,6 +993,28 @@ class ModelSubclassCompiledTest(keras_parameterized.TestCase):
     loss = model.train_on_batch(x, y)
     self.assertGreater(loss, 0.1)
 
+  def test_no_loss_in_compile(self):
+
+    class InternalLossModel(keras.Model):
+
+      def __init__(self):
+        super(InternalLossModel, self).__init__()
+        self.dense = keras.layers.Dense(1)
+
+      def call(self, inputs):
+        out = self.dense(inputs)
+        self.add_loss(math_ops.reduce_sum(out))
+        return out
+
+    model = InternalLossModel()
+    x = np.ones((10, 10))
+    model.predict(x)
+    model.compile(
+        optimizer='rmsprop',
+        run_eagerly=testing_utils.should_run_eagerly())
+    model.fit(x)
+    model.evaluate(x)
+
 
 class GraphSpecificModelSubclassingTests(test.TestCase):