From 0cc6210daa35247daf7f4cc98c115de611d2d05f Mon Sep 17 00:00:00 2001 From: Yanhui Liang Date: Mon, 15 Jun 2020 17:13:45 -0700 Subject: [PATCH] Cache DataHandler in `model.evaluate` to avoid function retracing between epochs in `model.fit`. PiperOrigin-RevId: 316576844 Change-Id: Icf85ce6830b69a003c1f2ebf41f8c70258504afd --- tensorflow/python/keras/engine/training.py | 55 +++++++++++++------ .../python/keras/engine/training_test.py | 22 ++++++++ 2 files changed, 61 insertions(+), 16 deletions(-) diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index b7a4795d768..5567e1733a7 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -1040,6 +1040,10 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector): data_adapter.train_validation_split( (x, y, sample_weight), validation_split=validation_split)) + if validation_data: + val_x, val_y, val_sample_weight = ( + data_adapter.unpack_x_y_sample_weight(validation_data)) + with self.distribute_strategy.scope(), \ training_utils.RespectCompiledTrainableState(self): # Creates a `tf.data.Dataset` and handles batch and epoch iteration. @@ -1102,8 +1106,21 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector): # Run validation. if validation_data and self._should_eval(epoch, validation_freq): - val_x, val_y, val_sample_weight = ( - data_adapter.unpack_x_y_sample_weight(validation_data)) + # Create data_handler for evaluation and cache it. 
+ if getattr(self, '_eval_data_handler', None) is None: + self._eval_data_handler = data_adapter.DataHandler( + x=val_x, + y=val_y, + sample_weight=val_sample_weight, + batch_size=validation_batch_size or batch_size, + steps_per_epoch=validation_steps, + initial_epoch=0, + epochs=1, + max_queue_size=max_queue_size, + workers=workers, + use_multiprocessing=use_multiprocessing, + model=self, + steps_per_execution=self._steps_per_execution) val_logs = self.evaluate( x=val_x, y=val_y, @@ -1123,6 +1140,9 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector): if self.stop_training: break + # If eval data_handler exists, delete it after all epochs are done. + if getattr(self, '_eval_data_handler', None) is not None: + del self._eval_data_handler callbacks.on_train_end(logs=training_logs) return self.history @@ -1318,20 +1338,23 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector): _disallow_inside_tf_function('evaluate') with self.distribute_strategy.scope(): - # Creates a `tf.data.Dataset` and handles batch and epoch iteration. - data_handler = data_adapter.DataHandler( - x=x, - y=y, - sample_weight=sample_weight, - batch_size=batch_size, - steps_per_epoch=steps, - initial_epoch=0, - epochs=1, - max_queue_size=max_queue_size, - workers=workers, - use_multiprocessing=use_multiprocessing, - model=self, - steps_per_execution=self._steps_per_execution) + if getattr(self, '_eval_data_handler', None) is not None: + data_handler = self._eval_data_handler + else: + # Creates a `tf.data.Dataset` and handles batch and epoch iteration. + data_handler = data_adapter.DataHandler( + x=x, + y=y, + sample_weight=sample_weight, + batch_size=batch_size, + steps_per_epoch=steps, + initial_epoch=0, + epochs=1, + max_queue_size=max_queue_size, + workers=workers, + use_multiprocessing=use_multiprocessing, + model=self, + steps_per_execution=self._steps_per_execution) # Container that configures and calls `tf.keras.Callback`s. 
if not isinstance(callbacks, callbacks_module.CallbackList): diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py index aa01463582c..5cf15926bfb 100644 --- a/tensorflow/python/keras/engine/training_test.py +++ b/tensorflow/python/keras/engine/training_test.py @@ -59,6 +59,7 @@ from tensorflow.python.ops import template from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables as variables_lib from tensorflow.python.platform import test +from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training.rmsprop import RMSPropOptimizer try: @@ -3538,5 +3539,26 @@ class TestAutoUpdates(keras_parameterized.TestCase): self.assertAllEqual(self.evaluate(bn.moving_variance), np.ones((10,))) +class TestFunctionTracing(keras_parameterized.TestCase): + + @keras_parameterized.run_all_keras_modes( + always_skip_v1=True, always_skip_eager=True) + def test_no_tracing_between_epoch(self): + if sys.version_info[0] < 3: + self.skipTest('self.assertLogs() call is not available in Python 2.') + + model = sequential.Sequential([layers_module.Dense(4, activation='relu')]) + model.compile(loss='mse', optimizer='rmsprop') + x = np.random.random((10, 6)) + y = np.random.random((10, 4)) + + logging.set_verbosity(1) + with self.assertLogs(level=1) as logs: + model.fit(x, y, epochs=10, batch_size=5, validation_data=(x, y)) + + new_func_graph = 'INFO:absl:Creating new FuncGraph for Python function' + self.assertEqual(sum(new_func_graph in log for log in logs.output), 9) + + if __name__ == '__main__': test.main()