Fix the "Another profiler is running" error raised when calling model.fit twice in graph mode.
Remove context.executing_eagerly() check for profiler, so that it works in graph mode. Use tf.cond when comparing tensors. Github issue: https://github.com/tensorflow/tensorflow/issues/37543 PiperOrigin-RevId: 300882803 Change-Id: I6ed4f9212984aa896d049fb2726c7868f3d427ea
This commit is contained in:
		
							parent
							
								
									1646cd9bb9
								
							
						
					
					
						commit
						01641aee30
					
				| @ -46,6 +46,7 @@ from tensorflow.python.keras.utils.generic_utils import Progbar | ||||
| from tensorflow.python.keras.utils.mode_keys import ModeKeys | ||||
| from tensorflow.python.lib.io import file_io | ||||
| from tensorflow.python.ops import array_ops | ||||
| from tensorflow.python.ops import control_flow_ops | ||||
| from tensorflow.python.ops import math_ops | ||||
| from tensorflow.python.ops import summary_ops_v2 | ||||
| from tensorflow.python.ops import variables | ||||
| @ -1948,9 +1949,7 @@ class TensorBoard(Callback): | ||||
|   def on_train_begin(self, logs=None): | ||||
|     self._init_batch_steps() | ||||
|     if self._start_batch == 1: | ||||
|       summary_ops_v2.trace_on(graph=True, profiler=False) | ||||
|       profiler.start(logdir=os.path.join(self._log_write_dir, 'train')) | ||||
|       self._is_tracing = True | ||||
|       self._enable_trace() | ||||
| 
 | ||||
|   def on_test_begin(self, logs=None): | ||||
|     self._set_default_writer(self._validation_run_name) | ||||
| @ -1976,14 +1975,14 @@ class TensorBoard(Callback): | ||||
|       self._log_metrics(logs, prefix='batch_', step=train_batches) | ||||
| 
 | ||||
|     self._increment_step(self._train_run_name) | ||||
| 
 | ||||
|     if context.executing_eagerly(): | ||||
|       if self._is_tracing and math_ops.greater_equal(train_batches, | ||||
|                                                      self._stop_batch): | ||||
|         self._log_trace() | ||||
|       elif (not self._is_tracing and | ||||
|             math_ops.equal(train_batches, self._start_batch - 1)): | ||||
|         self._enable_trace() | ||||
|     if self._is_tracing: | ||||
|       control_flow_ops.cond( | ||||
|           math_ops.greater_equal(train_batches, self._stop_batch), | ||||
|           lambda: self._log_trace_return_true(), lambda: False)  # pylint: disable=unnecessary-lambda | ||||
|     else: | ||||
|       control_flow_ops.cond( | ||||
|           math_ops.equal(train_batches, self._start_batch - 1), | ||||
|           lambda: self._enable_trace_return_true(), lambda: False)  # pylint: disable=unnecessary-lambda | ||||
| 
 | ||||
|   def on_test_batch_end(self, batch, logs=None): | ||||
|     if self.update_freq == 'epoch': | ||||
| @ -2020,22 +2019,49 @@ class TensorBoard(Callback): | ||||
|           self.log_dir, self.model._get_distribution_strategy())  # pylint: disable=protected-access | ||||
| 
 | ||||
|   def _enable_trace(self): | ||||
|     """Starts to collect trace graph to TensorBoard. | ||||
| 
 | ||||
|     Collects both trace and graph in eager mode, and trace only in graph mode. | ||||
|     """ | ||||
|     if context.executing_eagerly(): | ||||
|       # Graph must be traced in eager mode. | ||||
|       summary_ops_v2.trace_on(graph=True, profiler=False) | ||||
|     profiler.start(logdir=os.path.join(self._log_write_dir, 'train')) | ||||
|     self._is_tracing = True | ||||
| 
 | ||||
|   def _enable_trace_return_true(self): | ||||
|     """Starts to collect trace graph to TensorBoard and returns True. | ||||
| 
 | ||||
|     Returns: | ||||
|       True. | ||||
|     """ | ||||
|     self._enable_trace() | ||||
|     return True | ||||
| 
 | ||||
|   def _log_trace(self): | ||||
|     """Logs the trace graph to TensorBoard.""" | ||||
|     """Logs the trace graph to TensorBoard. | ||||
| 
 | ||||
|     Logs both trace and graph in eager mode, and trace only in graph mode. | ||||
|     """ | ||||
|     profiler.stop() | ||||
|     if context.executing_eagerly(): | ||||
|       # Graph must be traced in eager mode. | ||||
|       with self._get_writer(self._train_run_name).as_default(), \ | ||||
|           summary_ops_v2.always_record_summaries(): | ||||
|         # TODO(b/126388999): Remove step info in the summary name. | ||||
|         step = K.get_value(self._total_batches_seen[self._train_run_name]) | ||||
|         summary_ops_v2.trace_export(name='batch_%d' % step, step=step) | ||||
|         profiler.stop() | ||||
|     self._is_tracing = False | ||||
| 
 | ||||
|   def _log_trace_return_true(self): | ||||
|     """Logs the trace graph to TensorBoard and returns True. | ||||
| 
 | ||||
|     Returns: | ||||
|       True. | ||||
|     """ | ||||
|     self._log_trace() | ||||
|     return True | ||||
| 
 | ||||
|   def _log_metrics(self, logs, prefix, step): | ||||
|     """Writes metrics out as custom scalar summaries. | ||||
| 
 | ||||
|  | ||||
| @ -36,6 +36,7 @@ from tensorflow.core.framework import summary_pb2 | ||||
| from tensorflow.python import keras | ||||
| from tensorflow.python.data.ops import dataset_ops | ||||
| from tensorflow.python.eager import context | ||||
| from tensorflow.python.framework import ops | ||||
| from tensorflow.python.framework import random_seed | ||||
| from tensorflow.python.keras import keras_parameterized | ||||
| from tensorflow.python.keras import testing_utils | ||||
| @ -2065,6 +2066,31 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase): | ||||
|     ) | ||||
|     self.assertIsNotNone(self._get_trace_file(logdir=self.train_dir)) | ||||
| 
 | ||||
|   # Test case that replicates a Github issue. | ||||
|   # https://github.com/tensorflow/tensorflow/issues/37543 | ||||
|   def test_TensorBoard_autoTrace_profileTwiceGraphMode(self): | ||||
|     ops.disable_eager_execution() | ||||
|     inp = keras.Input((1,)) | ||||
|     out = keras.layers.Dense(units=1)(inp) | ||||
|     model = keras.Model(inp, out) | ||||
| 
 | ||||
|     model.compile(gradient_descent.SGD(1), 'mse') | ||||
| 
 | ||||
|     model.fit( | ||||
|         np.zeros((64, 1)), | ||||
|         np.zeros((64, 1)), | ||||
|         callbacks=[keras.callbacks.TensorBoard(self.logdir, profile_batch=1)], | ||||
|     ) | ||||
|     # Verifies trace exists in the first train_dir. | ||||
|     self.assertIsNotNone(self._get_trace_file(logdir=self.train_dir)) | ||||
|     model.fit( | ||||
|         np.zeros((64, 1)), | ||||
|         np.zeros((64, 1)), | ||||
|         callbacks=[keras.callbacks.TensorBoard(self.logdir, profile_batch=2)], | ||||
|     ) | ||||
|     # Verifies trace exists in the second train_dir. | ||||
|     self.assertIsNotNone(self._get_trace_file(logdir=self.train_dir)) | ||||
| 
 | ||||
|   def test_TensorBoard_autoTrace_profileBatchRange(self): | ||||
|     model = self._get_seq_model() | ||||
|     x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1)) | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user