Cache DataHandler in `model.evaluate` to avoid function retracing between epochs in `model.fit`.
PiperOrigin-RevId: 316576844 Change-Id: Icf85ce6830b69a003c1f2ebf41f8c70258504afd
This commit is contained in:
parent
304daeb37f
commit
0cc6210daa
|
@ -1040,6 +1040,10 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
|
|||
data_adapter.train_validation_split(
|
||||
(x, y, sample_weight), validation_split=validation_split))
|
||||
|
||||
if validation_data:
|
||||
val_x, val_y, val_sample_weight = (
|
||||
data_adapter.unpack_x_y_sample_weight(validation_data))
|
||||
|
||||
with self.distribute_strategy.scope(), \
|
||||
training_utils.RespectCompiledTrainableState(self):
|
||||
# Creates a `tf.data.Dataset` and handles batch and epoch iteration.
|
||||
|
@ -1102,8 +1106,21 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
|
|||
|
||||
# Run validation.
|
||||
if validation_data and self._should_eval(epoch, validation_freq):
|
||||
val_x, val_y, val_sample_weight = (
|
||||
data_adapter.unpack_x_y_sample_weight(validation_data))
|
||||
# Create data_handler for evaluation and cache it.
|
||||
if getattr(self, '_eval_data_handler', None) is None:
|
||||
self._eval_data_handler = data_adapter.DataHandler(
|
||||
x=val_x,
|
||||
y=val_y,
|
||||
sample_weight=val_sample_weight,
|
||||
batch_size=validation_batch_size or batch_size,
|
||||
steps_per_epoch=validation_steps,
|
||||
initial_epoch=0,
|
||||
epochs=1,
|
||||
max_queue_size=max_queue_size,
|
||||
workers=workers,
|
||||
use_multiprocessing=use_multiprocessing,
|
||||
model=self,
|
||||
steps_per_execution=self._steps_per_execution)
|
||||
val_logs = self.evaluate(
|
||||
x=val_x,
|
||||
y=val_y,
|
||||
|
@ -1123,6 +1140,9 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
|
|||
if self.stop_training:
|
||||
break
|
||||
|
||||
# If eval data_handler exists, delete it after all epochs are done.
|
||||
if getattr(self, '_eval_data_handler', None) is not None:
|
||||
del self._eval_data_handler
|
||||
callbacks.on_train_end(logs=training_logs)
|
||||
return self.history
|
||||
|
||||
|
@ -1318,20 +1338,23 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
|
|||
_disallow_inside_tf_function('evaluate')
|
||||
|
||||
with self.distribute_strategy.scope():
|
||||
# Creates a `tf.data.Dataset` and handles batch and epoch iteration.
|
||||
data_handler = data_adapter.DataHandler(
|
||||
x=x,
|
||||
y=y,
|
||||
sample_weight=sample_weight,
|
||||
batch_size=batch_size,
|
||||
steps_per_epoch=steps,
|
||||
initial_epoch=0,
|
||||
epochs=1,
|
||||
max_queue_size=max_queue_size,
|
||||
workers=workers,
|
||||
use_multiprocessing=use_multiprocessing,
|
||||
model=self,
|
||||
steps_per_execution=self._steps_per_execution)
|
||||
if getattr(self, '_eval_data_handler', None) is not None:
|
||||
data_handler = self._eval_data_handler
|
||||
else:
|
||||
# Creates a `tf.data.Dataset` and handles batch and epoch iteration.
|
||||
data_handler = data_adapter.DataHandler(
|
||||
x=x,
|
||||
y=y,
|
||||
sample_weight=sample_weight,
|
||||
batch_size=batch_size,
|
||||
steps_per_epoch=steps,
|
||||
initial_epoch=0,
|
||||
epochs=1,
|
||||
max_queue_size=max_queue_size,
|
||||
workers=workers,
|
||||
use_multiprocessing=use_multiprocessing,
|
||||
model=self,
|
||||
steps_per_execution=self._steps_per_execution)
|
||||
|
||||
# Container that configures and calls `tf.keras.Callback`s.
|
||||
if not isinstance(callbacks, callbacks_module.CallbackList):
|
||||
|
|
|
@ -59,6 +59,7 @@ from tensorflow.python.ops import template
|
|||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables as variables_lib
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training.rmsprop import RMSPropOptimizer
|
||||
|
||||
try:
|
||||
|
@ -3538,5 +3539,26 @@ class TestAutoUpdates(keras_parameterized.TestCase):
|
|||
self.assertAllEqual(self.evaluate(bn.moving_variance), np.ones((10,)))
|
||||
|
||||
|
||||
class TestFunctionTracing(keras_parameterized.TestCase):
  """Checks how many new FuncGraphs are logged during a multi-epoch `fit`."""

  @keras_parameterized.run_all_keras_modes(
      always_skip_v1=True, always_skip_eager=True)
  def test_no_tracing_between_epoch(self):
    """Fits a small model with validation data and counts tracing log lines.

    The test captures every log record emitted while `model.fit` runs for 10
    epochs with `validation_data` and asserts on the number of records that
    contain the "Creating new FuncGraph" message.
    """
    if sys.version_info[0] < 3:
      self.skipTest('self.assertLogs() call is not available in Python 2.')

    model = sequential.Sequential([layers_module.Dense(4, activation='relu')])
    model.compile(loss='mse', optimizer='rmsprop')
    x = np.random.random((10, 6))
    y = np.random.random((10, 4))

    # Log verbosity is process-global: remember the current level and restore
    # it afterwards so that this test does not change the logging behavior of
    # tests that run after it.
    previous_verbosity = logging.get_verbosity()
    logging.set_verbosity(1)
    try:
      with self.assertLogs(level=1) as logs:
        model.fit(x, y, epochs=10, batch_size=5, validation_data=(x, y))
    finally:
      logging.set_verbosity(previous_verbosity)

    new_func_graph = 'INFO:absl:Creating new FuncGraph for Python function'
    # NOTE(review): the expected count of 9 is kept exactly as in the original
    # test; confirm it against the tracing behavior of the TF version in use.
    self.assertEqual(sum(new_func_graph in log for log in logs.output), 9)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
|
Loading…
Reference in New Issue