Make it possible for a subclassed model to be compiled and trained with no
external loss.

PiperOrigin-RevId: 248035588

commit 09ca187c09 (parent 20359d70e4)
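For orientation before the diff: a minimal sketch of the workflow this change enables, written against the public tf.keras API rather than the internal modules patched below (the model and names here are illustrative, not taken from the change itself). A subclassed model that registers its own loss via `add_loss` can now be compiled with only an optimizer and trained without targets:

import numpy as np
import tensorflow as tf


class InternalLossModel(tf.keras.Model):
  """Toy subclassed model whose only loss comes from `add_loss`."""

  def __init__(self):
    super(InternalLossModel, self).__init__()
    self.dense = tf.keras.layers.Dense(1)

  def call(self, inputs):
    out = self.dense(inputs)
    # Register an internal loss; no external loss function is required.
    self.add_loss(tf.reduce_sum(out))
    return out


model = InternalLossModel()
model.compile(optimizer='rmsprop')  # note: no `loss` argument
x = np.ones((10, 10))
model.fit(x, epochs=1)  # trains against the internally added loss
model.evaluate(x)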
tensorflow/python/keras/engine/training.py

@@ -592,6 +592,7 @@ class Model(network.Network):
       epochs = kwargs.pop('nb_epoch')
     if kwargs:
       raise TypeError('Unrecognized keyword arguments: ' + str(kwargs))
+    self._assert_compile_was_called()
 
     # Case 1: distribution strategy.
     if self._distribution_strategy:
@@ -862,6 +863,8 @@ class Model(network.Network):
         ValueError: in case of invalid arguments.
     """
     _keras_api_gauge.get_cell('evaluate').set(True)
+    self._assert_compile_was_called()
+
     # Case 1: distribution strategy.
     if self._distribution_strategy:
       if K.in_multi_worker_mode():
@@ -1133,6 +1136,7 @@ class Model(network.Network):
     Raises:
         ValueError: In case of invalid user-provided arguments.
     """
+    self._assert_compile_was_called()
     # If at this point we are in the replica context, then it is okay to execute
     # the Eager code path. The expected way to get here is to call `fit` that
     # calls `train_on_batch` on each replica.
@@ -1213,6 +1217,7 @@ class Model(network.Network):
     Raises:
         ValueError: In case of invalid user-provided arguments.
     """
+    self._assert_compile_was_called()
     if (self._distribution_strategy and
         distribution_strategy_context.in_cross_replica_context()):
       raise NotImplementedError('`test_on_batch` is not supported for models '
@@ -2191,8 +2196,6 @@ class Model(network.Network):
     metrics_tensors = [
         self._all_metrics_tensors[m] for m in self.metrics_names[1:]
     ]
-    if not self._is_compiled:
-      raise RuntimeError('You must compile your model before using it.')
     self._check_trainable_weights_consistency()
     # If we have re-compiled the loss/weighted metric sub-graphs then create
     # train function even if one exists already. This is because
@@ -2229,8 +2232,6 @@ class Model(network.Network):
     metrics_tensors = [
         self._all_metrics_tensors[m] for m in self.metrics_names[1:]
     ]
-    if not self._is_compiled:
-      raise RuntimeError('You must compile your model before using it.')
     # If we have re-compiled the loss/weighted metric sub-graphs then create
     # test function even if one exists already. This is because
     # `_feed_sample_weights` list has been updated on re-copmpile.
@@ -2555,13 +2556,9 @@ class Model(network.Network):
       y_input = y
       dict_inputs = isinstance(self.inputs, dict)
 
-    if y_input is not None:
-      if not self.optimizer:
-        raise RuntimeError('You must compile a model before '
-                           'training/testing. '
-                           'Use `model.compile(optimizer, loss)`.')
-      if not self._is_compiled:
-        # On-the-fly compilation of the model.
+    if not self._is_compiled and self.optimizer:
+      # On-the-fly compilation of the model.
+      if y_input is not None:
         # We need to use `y` to set the model targets.
         if training_utils.has_tensors(y_input):
           y_input = training_utils.cast_if_floating_dtype(y_input)
@@ -2582,31 +2579,34 @@ class Model(network.Network):
                              'You passed: y=' + str(y))
           all_inputs.append(y_input)
 
       # Typecheck that all inputs are *either* value *or* symbolic.
       # TODO(fchollet): this check could be removed in Eager mode?
       if any(tensor_util.is_tensor(v) for v in all_inputs):
         if not all(tensor_util.is_tensor(v) for v in all_inputs):
           raise ValueError('Do not pass inputs that mix Numpy arrays and '
                            'TensorFlow tensors. '
                            'You passed: x=' + str(x) + '; y=' + str(y))
 
       if is_dataset or context.executing_eagerly():
         target_tensors = None
       else:
         # Handle target tensors if any passed.
+        if y_input is not None:
           if not isinstance(y_input, (list, tuple)):
             y_input = [y_input]
           target_tensors = [v for v in y_input if _is_symbolic_tensor(v)]
-        is_compile_called = True
-        self.compile(
-            optimizer=self.optimizer,
-            loss=self.loss,
-            metrics=self._compile_metrics,
-            weighted_metrics=self._compile_weighted_metrics,
-            loss_weights=self.loss_weights,
-            target_tensors=target_tensors,
-            run_eagerly=self.run_eagerly,
-            cloning=self._cloning)
+        else:
+          target_tensors = None
+      is_compile_called = True
+      self.compile(
+          optimizer=self.optimizer,
+          loss=self.loss,
+          metrics=self._compile_metrics,
+          weighted_metrics=self._compile_weighted_metrics,
+          loss_weights=self.loss_weights,
+          target_tensors=target_tensors,
+          run_eagerly=self.run_eagerly,
+          cloning=self._cloning)
 
       # In graph mode, if we had just set inputs and targets as symbolic tensors
       # by invoking build and compile on the model respectively, we do not have to
@@ -2921,6 +2921,16 @@ class Model(network.Network):
       return int(self._ckpt_saved_epoch) + 1
     return initial_epoch
 
+  def _assert_compile_was_called(self):
+    # Checks whether `compile` has been called. If it has been called,
+    # then the optimizer is set. This is different from whether the
+    # model is compiled
+    # (i.e. whether the model is built and its inputs/outputs are set).
+    if not self.optimizer:
+      raise RuntimeError('You must compile your model before '
+                         'training/testing. '
+                         'Use `model.compile(optimizer, loss)`.')
+
 
 class DistributedCallbackModel(Model):
   """Model that is used for callbacks with tf.distribute.Strategy."""
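The new `_assert_compile_was_called` helper centralizes a check that `fit`, `evaluate`, `train_on_batch`, and `test_on_batch` previously performed (or skipped) individually, and it keys on `self.optimizer` rather than `self._is_compiled`, so an internal-loss model compiled without a loss still passes. A small standalone sketch of the resulting behavior, using the public tf.keras API:

import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
try:
  model.fit(np.ones((1, 1)), np.ones((1, 1)))  # compile() never called
except RuntimeError as err:
  # Message raised by the helper above:
  # "You must compile your model before training/testing. ..."
  print(err)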
tensorflow/python/keras/engine/training_eager.py

@@ -339,12 +339,14 @@ def test_on_batch(model,
   if len(inputs) and tensor_util.is_tensor(inputs[0]):
     inputs = training_utils.cast_if_floating_to_model_input_dtypes(inputs,
                                                                    model)
-    targets = training_utils.cast_if_floating_dtype(targets)
+    if targets:
+      targets = training_utils.cast_if_floating_dtype(targets)
   else:
     inputs = training_utils.cast_if_floating_to_model_input_dtypes(
         [ops.convert_to_tensor(val) for val in inputs], model)
-    targets = training_utils.cast_if_floating_dtype(
-        [ops.convert_to_tensor(val) for val in targets])
+    if targets:
+      targets = training_utils.cast_if_floating_dtype(
+          [ops.convert_to_tensor(val) for val in targets])
   if sample_weights:
     sample_weights = [
         training_utils.cast_if_floating_dtype(ops.convert_to_tensor(val))
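The `if targets:` guards above matter because a model driven purely by internal losses reaches the eager `test_on_batch` path with no targets at all, where the unconditional casts assumed targets were present. A hedged sketch of the call pattern this unblocks, assuming eager execution and reusing the toy internal-loss model from the first example:

import numpy as np
import tensorflow as tf


class InternalLossModel(tf.keras.Model):

  def __init__(self):
    super(InternalLossModel, self).__init__()
    self.dense = tf.keras.layers.Dense(1)

  def call(self, inputs):
    out = self.dense(inputs)
    self.add_loss(tf.reduce_sum(out))
    return out


model = InternalLossModel()
model.compile(optimizer='rmsprop', run_eagerly=True)
x = np.ones((10, 10))
model.test_on_batch(x)  # no targets: the guarded cast is skipped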
tensorflow/python/keras/engine/training_test.py

@@ -1201,8 +1201,7 @@ class TrainingTest(keras_parameterized.TestCase):
     inputs = keras.Input(shape=(1,))
     outputs = keras.layers.Dense(1)(inputs)
     model = keras.Model(inputs, outputs)
-    with self.assertRaisesRegex(
-        RuntimeError, 'You must compile a model before training/testing.'):
+    with self.assertRaisesRegex(RuntimeError, 'must compile your model'):
       model.fit(np.ones((1, 1)), np.ones((1, 1)))
 
     class MyModel(keras.Model):
@@ -1212,10 +1211,7 @@ class TrainingTest(keras_parameterized.TestCase):
         return x
 
     model = MyModel()
-    model.compile(keras.optimizers.Adam(1e-3))
-
-    with self.assertRaisesRegex(RuntimeError,
-                                'You must compile your model before using it.'):
+    with self.assertRaisesRegex(RuntimeError, 'must compile your model'):
       model.fit(np.random.random((32, 1)), epochs=2)
 
 
tensorflow/python/keras/model_subclassing_test.py

@@ -31,6 +31,7 @@ from tensorflow.python.framework import test_util
 from tensorflow.python.keras import keras_parameterized
 from tensorflow.python.keras import testing_utils
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import embedding_ops
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import resource_variable_ops
@@ -992,6 +993,28 @@ class ModelSubclassCompiledTest(keras_parameterized.TestCase):
       loss = model.train_on_batch(x, y)
       self.assertGreater(loss, 0.1)
 
+  def test_no_loss_in_compile(self):
+
+    class InternalLossModel(keras.Model):
+
+      def __init__(self):
+        super(InternalLossModel, self).__init__()
+        self.dense = keras.layers.Dense(1)
+
+      def call(self, inputs):
+        out = self.dense(inputs)
+        self.add_loss(math_ops.reduce_sum(out))
+        return out
+
+    model = InternalLossModel()
+    x = np.ones((10, 10))
+    model.predict(x)
+    model.compile(
+        optimizer='rmsprop',
+        run_eagerly=testing_utils.should_run_eagerly())
+    model.fit(x)
+    model.evaluate(x)
+
 
 class GraphSpecificModelSubclassingTests(test.TestCase):
 