Cache DataHandler in `model.evaluate` to avoid function retracing between epochs in `model.fit`.

PiperOrigin-RevId: 316576844
Change-Id: Icf85ce6830b69a003c1f2ebf41f8c70258504afd
This commit is contained in:
Yanhui Liang 2020-06-15 17:13:45 -07:00 committed by TensorFlower Gardener
parent 304daeb37f
commit 0cc6210daa
2 changed files with 61 additions and 16 deletions

View File

@ -1040,6 +1040,10 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
data_adapter.train_validation_split(
(x, y, sample_weight), validation_split=validation_split))
if validation_data:
val_x, val_y, val_sample_weight = (
data_adapter.unpack_x_y_sample_weight(validation_data))
with self.distribute_strategy.scope(), \
training_utils.RespectCompiledTrainableState(self):
# Creates a `tf.data.Dataset` and handles batch and epoch iteration.
@ -1102,8 +1106,21 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
# Run validation.
if validation_data and self._should_eval(epoch, validation_freq):
val_x, val_y, val_sample_weight = (
data_adapter.unpack_x_y_sample_weight(validation_data))
# Create data_handler for evaluation and cache it.
if getattr(self, '_eval_data_handler', None) is None:
self._eval_data_handler = data_adapter.DataHandler(
x=val_x,
y=val_y,
sample_weight=val_sample_weight,
batch_size=validation_batch_size or batch_size,
steps_per_epoch=validation_steps,
initial_epoch=0,
epochs=1,
max_queue_size=max_queue_size,
workers=workers,
use_multiprocessing=use_multiprocessing,
model=self,
steps_per_execution=self._steps_per_execution)
val_logs = self.evaluate(
x=val_x,
y=val_y,
@ -1123,6 +1140,9 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
if self.stop_training:
break
# If eval data_hanlder exists, delete it after all epochs are done.
if getattr(self, '_eval_data_handler', None) is not None:
del self._eval_data_handler
callbacks.on_train_end(logs=training_logs)
return self.history
@ -1318,20 +1338,23 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
_disallow_inside_tf_function('evaluate')
with self.distribute_strategy.scope():
# Creates a `tf.data.Dataset` and handles batch and epoch iteration.
data_handler = data_adapter.DataHandler(
x=x,
y=y,
sample_weight=sample_weight,
batch_size=batch_size,
steps_per_epoch=steps,
initial_epoch=0,
epochs=1,
max_queue_size=max_queue_size,
workers=workers,
use_multiprocessing=use_multiprocessing,
model=self,
steps_per_execution=self._steps_per_execution)
if getattr(self, '_eval_data_handler', None) is not None:
data_handler = self._eval_data_handler
else:
# Creates a `tf.data.Dataset` and handles batch and epoch iteration.
data_handler = data_adapter.DataHandler(
x=x,
y=y,
sample_weight=sample_weight,
batch_size=batch_size,
steps_per_epoch=steps,
initial_epoch=0,
epochs=1,
max_queue_size=max_queue_size,
workers=workers,
use_multiprocessing=use_multiprocessing,
model=self,
steps_per_execution=self._steps_per_execution)
# Container that configures and calls `tf.keras.Callback`s.
if not isinstance(callbacks, callbacks_module.CallbackList):

View File

@ -59,6 +59,7 @@ from tensorflow.python.ops import template
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.rmsprop import RMSPropOptimizer
try:
@ -3538,5 +3539,26 @@ class TestAutoUpdates(keras_parameterized.TestCase):
self.assertAllEqual(self.evaluate(bn.moving_variance), np.ones((10,)))
class TestFunctionTracing(keras_parameterized.TestCase):
@keras_parameterized.run_all_keras_modes(
always_skip_v1=True, always_skip_eager=True)
def test_no_tracing_between_epoch(self):
if sys.version_info[0] < 3:
self.skipTest('self.assertLogs() call is not available in Python 2.')
model = sequential.Sequential([layers_module.Dense(4, activation='relu')])
model.compile(loss='mse', optimizer='rmsprop')
x = np.random.random((10, 6))
y = np.random.random((10, 4))
logging.set_verbosity(1)
with self.assertLogs(level=1) as logs:
model.fit(x, y, epochs=10, batch_size=5, validation_data=(x, y))
new_func_graph = 'INFO:absl:Creating new FuncGraph for Python function'
self.assertEqual(sum(new_func_graph in log for log in logs.output), 9)
if __name__ == '__main__':
test.main()