Make the individual call of Model.evaluate use no cached data, while Model.fit still uses cached data for evaluation.
				
					
				
			PiperOrigin-RevId: 337194749 Change-Id: I1a99eadac1831918a21ab994a1d7a5e84b0623c6
This commit is contained in:
		
							parent
							
								
									2882df7cf2
								
							
						
					
					
						commit
						5cb9129e4b
					
				| @ -225,6 +225,9 @@ | ||||
|         argument. | ||||
|     *   Added `tf.metrics.log_cosh` and `tf.metrics.logcosh` API entrypoints | ||||
|         with the same implementation as their `tf.losses` equivalent. | ||||
|     *   For Keras model, the individual call of `Model.evaluate` uses no cached | ||||
|         data for evaluation, while `Model.fit` uses cached data when | ||||
|         `validation_data` arg is provided for better performance. | ||||
| *   `tf.function` / AutoGraph: | ||||
|     *   Added `experimental_follow_type_hints` argument for `tf.function`. When | ||||
|         True, the function may use type annotations to optimize the tracing | ||||
|  | ||||
| @ -58,6 +58,7 @@ from tensorflow.python.keras.saving.saved_model import json_utils | ||||
| from tensorflow.python.keras.saving.saved_model import model_serialization | ||||
| from tensorflow.python.keras.utils import generic_utils | ||||
| from tensorflow.python.keras.utils import layer_utils | ||||
| from tensorflow.python.keras.utils import tf_inspect | ||||
| from tensorflow.python.keras.utils import tf_utils | ||||
| from tensorflow.python.keras.utils import version_utils | ||||
| from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite | ||||
| @ -1099,6 +1100,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector): | ||||
|         if validation_data and self._should_eval(epoch, validation_freq): | ||||
|           # Create data_handler for evaluation and cache it. | ||||
|           if getattr(self, '_eval_data_handler', None) is None: | ||||
|             self._fit_frame = tf_inspect.currentframe() | ||||
|             self._eval_data_handler = data_adapter.DataHandler( | ||||
|                 x=val_x, | ||||
|                 y=val_y, | ||||
| @ -1134,6 +1136,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector): | ||||
|       # If eval data_hanlder exists, delete it after all epochs are done. | ||||
|       if getattr(self, '_eval_data_handler', None) is not None: | ||||
|         del self._eval_data_handler | ||||
|         del self._fit_frame | ||||
|       callbacks.on_train_end(logs=training_logs) | ||||
|       return self.history | ||||
| 
 | ||||
| @ -1327,7 +1330,10 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector): | ||||
|     _disallow_inside_tf_function('evaluate') | ||||
| 
 | ||||
|     with self.distribute_strategy.scope(): | ||||
|       if getattr(self, '_eval_data_handler', None) is not None: | ||||
|       # Use cached evaluation data only when it's called in `Model.fit` | ||||
|       if (getattr(self, '_fit_frame', None) is not None | ||||
|           and tf_inspect.currentframe().f_back is self._fit_frame | ||||
|           and getattr(self, '_eval_data_handler', None) is not None): | ||||
|         data_handler = self._eval_data_handler | ||||
|       else: | ||||
|         # Creates a `tf.data.Dataset` and handles batch and epoch iteration. | ||||
|  | ||||
| @ -3642,16 +3642,20 @@ class TestAutoUpdates(keras_parameterized.TestCase): | ||||
| 
 | ||||
| class TestFunctionTracing(keras_parameterized.TestCase): | ||||
| 
 | ||||
|   def _seq_model_and_data(self): | ||||
|     model = sequential.Sequential([layers_module.Dense(4, activation='relu')]) | ||||
|     model.compile(loss='mse', optimizer='rmsprop') | ||||
|     x = np.random.random((10, 6)) | ||||
|     y = np.random.random((10, 4)) | ||||
|     return model, x, y | ||||
| 
 | ||||
|   @keras_parameterized.run_all_keras_modes( | ||||
|       always_skip_v1=True, always_skip_eager=True) | ||||
|   def test_no_tracing_between_epoch(self): | ||||
|     if sys.version_info[0] < 3: | ||||
|       self.skipTest('self.assertLogs() call is not available in Python 2.') | ||||
| 
 | ||||
|     model = sequential.Sequential([layers_module.Dense(4, activation='relu')]) | ||||
|     model.compile(loss='mse', optimizer='rmsprop') | ||||
|     x = np.random.random((10, 6)) | ||||
|     y = np.random.random((10, 4)) | ||||
|     model, x, y = self._seq_model_and_data() | ||||
| 
 | ||||
|     logging.set_verbosity(1) | ||||
|     with self.assertLogs(level=1) as logs: | ||||
| @ -3660,6 +3664,21 @@ class TestFunctionTracing(keras_parameterized.TestCase): | ||||
|     new_func_graph = 'INFO:absl:Creating new FuncGraph for Python function' | ||||
|     self.assertEqual(sum(new_func_graph in log for log in logs.output), 9) | ||||
| 
 | ||||
|   @keras_parameterized.run_all_keras_modes( | ||||
|       always_skip_v1=True, always_skip_eager=True) | ||||
|   def test_evaluate_no_cached_data(self): | ||||
|     if sys.version_info[0] < 3: | ||||
|       self.skipTest('self.assertLogs() call is not available in Python 2.') | ||||
| 
 | ||||
|     model, x, y = self._seq_model_and_data() | ||||
| 
 | ||||
|     new_func_graph = 'INFO:absl:Creating new FuncGraph for Python function' | ||||
|     logging.set_verbosity(1) | ||||
|     with self.assertLogs(level=1) as eval_logs: | ||||
|       for _ in range(6): | ||||
|         model.evaluate(x, y, batch_size=5) | ||||
|     self.assertEqual(sum(new_func_graph in log for log in eval_logs.output), 20) | ||||
| 
 | ||||
| 
 | ||||
| class TestBuildCustomModel(keras_parameterized.TestCase): | ||||
| 
 | ||||
|  | ||||
| @ -90,6 +90,11 @@ else: | ||||
|     return _convert_maybe_argspec_to_fullargspec(getargspec(target)) | ||||
| 
 | ||||
| 
 | ||||
| def currentframe(): | ||||
|   """TFDecorator-aware replacement for inspect.currentframe.""" | ||||
|   return _inspect.stack()[1][0] | ||||
| 
 | ||||
| 
 | ||||
| def getargspec(obj): | ||||
|   """TFDecorator-aware replacement for `inspect.getargspec`. | ||||
| 
 | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user