Cache DataHandler in `model.evaluate` to avoid function retracing between epochs in `model.fit`.
PiperOrigin-RevId: 316576844 Change-Id: Icf85ce6830b69a003c1f2ebf41f8c70258504afd
This commit is contained in:
parent
304daeb37f
commit
0cc6210daa
|
@ -1040,6 +1040,10 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
|
||||||
data_adapter.train_validation_split(
|
data_adapter.train_validation_split(
|
||||||
(x, y, sample_weight), validation_split=validation_split))
|
(x, y, sample_weight), validation_split=validation_split))
|
||||||
|
|
||||||
|
if validation_data:
|
||||||
|
val_x, val_y, val_sample_weight = (
|
||||||
|
data_adapter.unpack_x_y_sample_weight(validation_data))
|
||||||
|
|
||||||
with self.distribute_strategy.scope(), \
|
with self.distribute_strategy.scope(), \
|
||||||
training_utils.RespectCompiledTrainableState(self):
|
training_utils.RespectCompiledTrainableState(self):
|
||||||
# Creates a `tf.data.Dataset` and handles batch and epoch iteration.
|
# Creates a `tf.data.Dataset` and handles batch and epoch iteration.
|
||||||
|
@ -1102,8 +1106,21 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
|
||||||
|
|
||||||
# Run validation.
|
# Run validation.
|
||||||
if validation_data and self._should_eval(epoch, validation_freq):
|
if validation_data and self._should_eval(epoch, validation_freq):
|
||||||
val_x, val_y, val_sample_weight = (
|
# Create data_handler for evaluation and cache it.
|
||||||
data_adapter.unpack_x_y_sample_weight(validation_data))
|
if getattr(self, '_eval_data_handler', None) is None:
|
||||||
|
self._eval_data_handler = data_adapter.DataHandler(
|
||||||
|
x=val_x,
|
||||||
|
y=val_y,
|
||||||
|
sample_weight=val_sample_weight,
|
||||||
|
batch_size=validation_batch_size or batch_size,
|
||||||
|
steps_per_epoch=validation_steps,
|
||||||
|
initial_epoch=0,
|
||||||
|
epochs=1,
|
||||||
|
max_queue_size=max_queue_size,
|
||||||
|
workers=workers,
|
||||||
|
use_multiprocessing=use_multiprocessing,
|
||||||
|
model=self,
|
||||||
|
steps_per_execution=self._steps_per_execution)
|
||||||
val_logs = self.evaluate(
|
val_logs = self.evaluate(
|
||||||
x=val_x,
|
x=val_x,
|
||||||
y=val_y,
|
y=val_y,
|
||||||
|
@ -1123,6 +1140,9 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
|
||||||
if self.stop_training:
|
if self.stop_training:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# If eval data_hanlder exists, delete it after all epochs are done.
|
||||||
|
if getattr(self, '_eval_data_handler', None) is not None:
|
||||||
|
del self._eval_data_handler
|
||||||
callbacks.on_train_end(logs=training_logs)
|
callbacks.on_train_end(logs=training_logs)
|
||||||
return self.history
|
return self.history
|
||||||
|
|
||||||
|
@ -1318,20 +1338,23 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
|
||||||
_disallow_inside_tf_function('evaluate')
|
_disallow_inside_tf_function('evaluate')
|
||||||
|
|
||||||
with self.distribute_strategy.scope():
|
with self.distribute_strategy.scope():
|
||||||
# Creates a `tf.data.Dataset` and handles batch and epoch iteration.
|
if getattr(self, '_eval_data_handler', None) is not None:
|
||||||
data_handler = data_adapter.DataHandler(
|
data_handler = self._eval_data_handler
|
||||||
x=x,
|
else:
|
||||||
y=y,
|
# Creates a `tf.data.Dataset` and handles batch and epoch iteration.
|
||||||
sample_weight=sample_weight,
|
data_handler = data_adapter.DataHandler(
|
||||||
batch_size=batch_size,
|
x=x,
|
||||||
steps_per_epoch=steps,
|
y=y,
|
||||||
initial_epoch=0,
|
sample_weight=sample_weight,
|
||||||
epochs=1,
|
batch_size=batch_size,
|
||||||
max_queue_size=max_queue_size,
|
steps_per_epoch=steps,
|
||||||
workers=workers,
|
initial_epoch=0,
|
||||||
use_multiprocessing=use_multiprocessing,
|
epochs=1,
|
||||||
model=self,
|
max_queue_size=max_queue_size,
|
||||||
steps_per_execution=self._steps_per_execution)
|
workers=workers,
|
||||||
|
use_multiprocessing=use_multiprocessing,
|
||||||
|
model=self,
|
||||||
|
steps_per_execution=self._steps_per_execution)
|
||||||
|
|
||||||
# Container that configures and calls `tf.keras.Callback`s.
|
# Container that configures and calls `tf.keras.Callback`s.
|
||||||
if not isinstance(callbacks, callbacks_module.CallbackList):
|
if not isinstance(callbacks, callbacks_module.CallbackList):
|
||||||
|
|
|
@ -59,6 +59,7 @@ from tensorflow.python.ops import template
|
||||||
from tensorflow.python.ops import variable_scope
|
from tensorflow.python.ops import variable_scope
|
||||||
from tensorflow.python.ops import variables as variables_lib
|
from tensorflow.python.ops import variables as variables_lib
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
from tensorflow.python.training.rmsprop import RMSPropOptimizer
|
from tensorflow.python.training.rmsprop import RMSPropOptimizer
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -3538,5 +3539,26 @@ class TestAutoUpdates(keras_parameterized.TestCase):
|
||||||
self.assertAllEqual(self.evaluate(bn.moving_variance), np.ones((10,)))
|
self.assertAllEqual(self.evaluate(bn.moving_variance), np.ones((10,)))
|
||||||
|
|
||||||
|
|
||||||
|
class TestFunctionTracing(keras_parameterized.TestCase):
|
||||||
|
|
||||||
|
@keras_parameterized.run_all_keras_modes(
|
||||||
|
always_skip_v1=True, always_skip_eager=True)
|
||||||
|
def test_no_tracing_between_epoch(self):
|
||||||
|
if sys.version_info[0] < 3:
|
||||||
|
self.skipTest('self.assertLogs() call is not available in Python 2.')
|
||||||
|
|
||||||
|
model = sequential.Sequential([layers_module.Dense(4, activation='relu')])
|
||||||
|
model.compile(loss='mse', optimizer='rmsprop')
|
||||||
|
x = np.random.random((10, 6))
|
||||||
|
y = np.random.random((10, 4))
|
||||||
|
|
||||||
|
logging.set_verbosity(1)
|
||||||
|
with self.assertLogs(level=1) as logs:
|
||||||
|
model.fit(x, y, epochs=10, batch_size=5, validation_data=(x, y))
|
||||||
|
|
||||||
|
new_func_graph = 'INFO:absl:Creating new FuncGraph for Python function'
|
||||||
|
self.assertEqual(sum(new_func_graph in log for log in logs.output), 9)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test.main()
|
test.main()
|
||||||
|
|
Loading…
Reference in New Issue