From c94d33c7d3fe32aa46decebe6fb261c2ff5012c3 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 25 Mar 2020 18:35:39 -0700 Subject: [PATCH] Reset the global_train_batch in each training. PiperOrigin-RevId: 303016514 Change-Id: I6af560f7f6e94c359600c2913a9dd426f062b921 --- tensorflow/python/keras/callbacks.py | 1 + tensorflow/python/keras/callbacks_test.py | 45 ++++++++++++++++++----- 2 files changed, 36 insertions(+), 10 deletions(-) diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py index 7c5124e923e..1ed713b04d1 100644 --- a/tensorflow/python/keras/callbacks.py +++ b/tensorflow/python/keras/callbacks.py @@ -2005,6 +2005,7 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector): self._should_trace = not (self._start_batch == 0 and self._stop_batch == 0) def on_train_begin(self, logs=None): + self._global_train_batch = 0 self._push_writer(self._train_writer, self._train_step) def on_train_end(self, logs=None): diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py index 5de4cacfa8a..a9b1cd6a9f8 100644 --- a/tensorflow/python/keras/callbacks_test.py +++ b/tensorflow/python/keras/callbacks_test.py @@ -2018,14 +2018,16 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase): run_eagerly=testing_utils.should_run_eagerly()) return model - def _get_trace_file(self, logdir): + def _count_trace_file(self, logdir): profile_dir = os.path.join(logdir, 'plugins', 'profile') + count = 0 for (dirpath, dirnames, filenames) in os.walk(profile_dir): + del dirpath # unused del dirnames # unused for filename in filenames: if filename.endswith('.trace.json.gz'): - return os.path.join(dirpath, filename) - return None + count += 1 + return count def fitModelAndAssertKerasModelWritten(self, model): x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1)) @@ -2095,7 +2097,7 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase): _ObservedSummary(logdir=self.train_dir, tag=u'batch_1'), }, ) - self.assertIsNotNone(self._get_trace_file(logdir=self.train_dir)) + self.assertEqual(1, self._count_trace_file(logdir=self.train_dir)) def test_TensorBoard_autoTrace_tagNameWithBatchNum(self): model = self._get_seq_model() @@ -2118,7 +2120,7 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase): _ObservedSummary(logdir=self.train_dir, tag=u'batch_2'), }, ) - self.assertIsNotNone(self._get_trace_file(logdir=self.train_dir)) + self.assertEqual(1, self._count_trace_file(logdir=self.train_dir)) def test_TensorBoard_autoTrace_profileBatchRangeSingle(self): model = self._get_seq_model() @@ -2142,7 +2144,30 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase): _ObservedSummary(logdir=self.train_dir, tag=u'batch_2'), }, ) - self.assertIsNotNone(self._get_trace_file(logdir=self.train_dir)) + self.assertEqual(1, self._count_trace_file(logdir=self.train_dir)) + + def test_TensorBoard_autoTrace_profileBatchRangeTwice(self): + model = self._get_seq_model() + x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1)) + tb_cbk = keras.callbacks.TensorBoard( + self.logdir, histogram_freq=1, profile_batch='10,10', write_graph=False) + + model.fit( + x, + y, + batch_size=3, + epochs=10, + validation_data=(x, y), + callbacks=[tb_cbk]) + + model.fit( + x, + y, + batch_size=3, + epochs=10, + validation_data=(x, y), + callbacks=[tb_cbk]) + self.assertEqual(2, self._count_trace_file(logdir=self.train_dir)) # Test case that replicates a Github issue. # https://github.com/tensorflow/tensorflow/issues/37543 @@ -2162,7 +2187,7 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase): callbacks=[keras.callbacks.TensorBoard(logdir, profile_batch=1)], ) # Verifies trace exists in the first logdir. - self.assertIsNotNone(self._get_trace_file(logdir=logdir)) + self.assertEqual(1, self._count_trace_file(logdir=logdir)) logdir = os.path.join(self.get_temp_dir(), 'tb2') model.fit( np.zeros((64, 1)), @@ -2171,7 +2196,7 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase): callbacks=[keras.callbacks.TensorBoard(logdir, profile_batch=2)], ) # Verifies trace exists in the second logdir. - self.assertIsNotNone(self._get_trace_file(logdir=logdir)) + self.assertEqual(1, self._count_trace_file(logdir=logdir)) def test_TensorBoard_autoTrace_profileBatchRange(self): model = self._get_seq_model() @@ -2195,7 +2220,7 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase): _ObservedSummary(logdir=self.train_dir, tag=u'batch_3'), }, ) - self.assertIsNotNone(self._get_trace_file(logdir=self.train_dir)) + self.assertEqual(1, self._count_trace_file(logdir=self.train_dir)) def test_TensorBoard_autoTrace_profileInvalidBatchRange(self): with self.assertRaises(ValueError): @@ -2237,7 +2262,7 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase): # Enabled trace only on the 10000th batch, thus it should be empty. self.assertEmpty(summary_file.tensors) - self.assertIsNone(self._get_trace_file(logdir=self.train_dir)) + self.assertEqual(0, self._count_trace_file(logdir=self.train_dir)) class MostRecentlyModifiedFileMatchingPatternTest(test.TestCase):