Make it possible for a subclassed model to be compiled and trained with no
external loss. PiperOrigin-RevId: 248035588
This commit is contained in:
parent
20359d70e4
commit
09ca187c09
@ -592,6 +592,7 @@ class Model(network.Network):
|
||||
epochs = kwargs.pop('nb_epoch')
|
||||
if kwargs:
|
||||
raise TypeError('Unrecognized keyword arguments: ' + str(kwargs))
|
||||
self._assert_compile_was_called()
|
||||
|
||||
# Case 1: distribution strategy.
|
||||
if self._distribution_strategy:
|
||||
@ -862,6 +863,8 @@ class Model(network.Network):
|
||||
ValueError: in case of invalid arguments.
|
||||
"""
|
||||
_keras_api_gauge.get_cell('evaluate').set(True)
|
||||
self._assert_compile_was_called()
|
||||
|
||||
# Case 1: distribution strategy.
|
||||
if self._distribution_strategy:
|
||||
if K.in_multi_worker_mode():
|
||||
@ -1133,6 +1136,7 @@ class Model(network.Network):
|
||||
Raises:
|
||||
ValueError: In case of invalid user-provided arguments.
|
||||
"""
|
||||
self._assert_compile_was_called()
|
||||
# If at this point we are in the replica context, then it is okay to execute
|
||||
# the Eager code path. The expected way to get here is to call `fit` that
|
||||
# calls `train_on_batch` on each replica.
|
||||
@ -1213,6 +1217,7 @@ class Model(network.Network):
|
||||
Raises:
|
||||
ValueError: In case of invalid user-provided arguments.
|
||||
"""
|
||||
self._assert_compile_was_called()
|
||||
if (self._distribution_strategy and
|
||||
distribution_strategy_context.in_cross_replica_context()):
|
||||
raise NotImplementedError('`test_on_batch` is not supported for models '
|
||||
@ -2191,8 +2196,6 @@ class Model(network.Network):
|
||||
metrics_tensors = [
|
||||
self._all_metrics_tensors[m] for m in self.metrics_names[1:]
|
||||
]
|
||||
if not self._is_compiled:
|
||||
raise RuntimeError('You must compile your model before using it.')
|
||||
self._check_trainable_weights_consistency()
|
||||
# If we have re-compiled the loss/weighted metric sub-graphs then create
|
||||
# train function even if one exists already. This is because
|
||||
@ -2229,8 +2232,6 @@ class Model(network.Network):
|
||||
metrics_tensors = [
|
||||
self._all_metrics_tensors[m] for m in self.metrics_names[1:]
|
||||
]
|
||||
if not self._is_compiled:
|
||||
raise RuntimeError('You must compile your model before using it.')
|
||||
# If we have re-compiled the loss/weighted metric sub-graphs then create
|
||||
# test function even if one exists already. This is because
|
||||
# `_feed_sample_weights` list has been updated on re-copmpile.
|
||||
@ -2555,13 +2556,9 @@ class Model(network.Network):
|
||||
y_input = y
|
||||
dict_inputs = isinstance(self.inputs, dict)
|
||||
|
||||
if y_input is not None:
|
||||
if not self.optimizer:
|
||||
raise RuntimeError('You must compile a model before '
|
||||
'training/testing. '
|
||||
'Use `model.compile(optimizer, loss)`.')
|
||||
if not self._is_compiled:
|
||||
if not self._is_compiled and self.optimizer:
|
||||
# On-the-fly compilation of the model.
|
||||
if y_input is not None:
|
||||
# We need to use `y` to set the model targets.
|
||||
if training_utils.has_tensors(y_input):
|
||||
y_input = training_utils.cast_if_floating_dtype(y_input)
|
||||
@ -2594,9 +2591,12 @@ class Model(network.Network):
|
||||
target_tensors = None
|
||||
else:
|
||||
# Handle target tensors if any passed.
|
||||
if y_input is not None:
|
||||
if not isinstance(y_input, (list, tuple)):
|
||||
y_input = [y_input]
|
||||
target_tensors = [v for v in y_input if _is_symbolic_tensor(v)]
|
||||
else:
|
||||
target_tensors = None
|
||||
is_compile_called = True
|
||||
self.compile(
|
||||
optimizer=self.optimizer,
|
||||
@ -2921,6 +2921,16 @@ class Model(network.Network):
|
||||
return int(self._ckpt_saved_epoch) + 1
|
||||
return initial_epoch
|
||||
|
||||
def _assert_compile_was_called(self):
|
||||
# Checks whether `compile` has been called. If it has been called,
|
||||
# then the optimizer is set. This is different from whether the
|
||||
# model is compiled
|
||||
# (i.e. whether the model is built and its inputs/outputs are set).
|
||||
if not self.optimizer:
|
||||
raise RuntimeError('You must compile your model before '
|
||||
'training/testing. '
|
||||
'Use `model.compile(optimizer, loss)`.')
|
||||
|
||||
|
||||
class DistributedCallbackModel(Model):
|
||||
"""Model that is used for callbacks with tf.distribute.Strategy."""
|
||||
|
@ -339,10 +339,12 @@ def test_on_batch(model,
|
||||
if len(inputs) and tensor_util.is_tensor(inputs[0]):
|
||||
inputs = training_utils.cast_if_floating_to_model_input_dtypes(inputs,
|
||||
model)
|
||||
if targets:
|
||||
targets = training_utils.cast_if_floating_dtype(targets)
|
||||
else:
|
||||
inputs = training_utils.cast_if_floating_to_model_input_dtypes(
|
||||
[ops.convert_to_tensor(val) for val in inputs], model)
|
||||
if targets:
|
||||
targets = training_utils.cast_if_floating_dtype(
|
||||
[ops.convert_to_tensor(val) for val in targets])
|
||||
if sample_weights:
|
||||
|
@ -1201,8 +1201,7 @@ class TrainingTest(keras_parameterized.TestCase):
|
||||
inputs = keras.Input(shape=(1,))
|
||||
outputs = keras.layers.Dense(1)(inputs)
|
||||
model = keras.Model(inputs, outputs)
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, 'You must compile a model before training/testing.'):
|
||||
with self.assertRaisesRegex(RuntimeError, 'must compile your model'):
|
||||
model.fit(np.ones((1, 1)), np.ones((1, 1)))
|
||||
|
||||
class MyModel(keras.Model):
|
||||
@ -1212,10 +1211,7 @@ class TrainingTest(keras_parameterized.TestCase):
|
||||
return x
|
||||
|
||||
model = MyModel()
|
||||
model.compile(keras.optimizers.Adam(1e-3))
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError,
|
||||
'You must compile your model before using it.'):
|
||||
with self.assertRaisesRegex(RuntimeError, 'must compile your model'):
|
||||
model.fit(np.random.random((32, 1)), epochs=2)
|
||||
|
||||
|
||||
|
@ -31,6 +31,7 @@ from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.keras import keras_parameterized
|
||||
from tensorflow.python.keras import testing_utils
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import embedding_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
@ -992,6 +993,28 @@ class ModelSubclassCompiledTest(keras_parameterized.TestCase):
|
||||
loss = model.train_on_batch(x, y)
|
||||
self.assertGreater(loss, 0.1)
|
||||
|
||||
def test_no_loss_in_compile(self):
|
||||
|
||||
class InternalLossModel(keras.Model):
|
||||
|
||||
def __init__(self):
|
||||
super(InternalLossModel, self).__init__()
|
||||
self.dense = keras.layers.Dense(1)
|
||||
|
||||
def call(self, inputs):
|
||||
out = self.dense(inputs)
|
||||
self.add_loss(math_ops.reduce_sum(out))
|
||||
return out
|
||||
|
||||
model = InternalLossModel()
|
||||
x = np.ones((10, 10))
|
||||
model.predict(x)
|
||||
model.compile(
|
||||
optimizer='rmsprop',
|
||||
run_eagerly=testing_utils.should_run_eagerly())
|
||||
model.fit(x)
|
||||
model.evaluate(x)
|
||||
|
||||
|
||||
class GraphSpecificModelSubclassingTests(test.TestCase):
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user