Reset `_global_train_batch` at the start of each training run.

PiperOrigin-RevId: 303016514
Change-Id: I6af560f7f6e94c359600c2913a9dd426f062b921
This commit is contained in:
A. Unique TensorFlower 2020-03-25 18:35:39 -07:00 committed by TensorFlower Gardener
parent 2eb1429580
commit c94d33c7d3
2 changed files with 36 additions and 10 deletions

View File

@ -2005,6 +2005,7 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector):
self._should_trace = not (self._start_batch == 0 and self._stop_batch == 0)
def on_train_begin(self, logs=None):
# Called at the start of training. `logs` is not read by the lines shown here.
# Zero the train-batch counter so each training run starts counting from 0.
# NOTE(review): presumably this matters when one callback instance is passed
# to several `fit` calls — confirm against where the counter is read.
self._global_train_batch = 0
# Pass the train writer and train step to `_push_writer`, which is defined
# outside this view.
self._push_writer(self._train_writer, self._train_step)
def on_train_end(self, logs=None):

View File

@ -2018,14 +2018,16 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase):
run_eagerly=testing_utils.should_run_eagerly())
return model
# NOTE(review): this span is a diff hunk shown without +/- markers or
# indentation, so the pre-change helper (`_get_trace_file`) and its
# replacement (`_count_trace_file`) appear interleaved. The per-line notes
# below are inferred from the hunk header and the updated call sites —
# confirm against the actual commit.
# Presumably the removed signature of the old helper.
def _get_trace_file(self, logdir):
# Presumably the added signature of the replacement helper.
def _count_trace_file(self, logdir):
# Files are searched under <logdir>/plugins/profile.
profile_dir = os.path.join(logdir, 'plugins', 'profile')
count = 0
for (dirpath, dirnames, filenames) in os.walk(profile_dir):
del dirpath # unused
del dirnames # unused
for filename in filenames:
if filename.endswith('.trace.json.gz'):
# Presumably removed: the old helper returned the path of the first matching
# file, or None when no file matched.
return os.path.join(dirpath, filename)
return None
# Presumably added: count each file whose name ends in '.trace.json.gz' and
# return the total.
count += 1
return count
def fitModelAndAssertKerasModelWritten(self, model):
x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
@ -2095,7 +2097,7 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase):
_ObservedSummary(logdir=self.train_dir, tag=u'batch_1'),
},
)
self.assertIsNotNone(self._get_trace_file(logdir=self.train_dir))
self.assertEqual(1, self._count_trace_file(logdir=self.train_dir))
def test_TensorBoard_autoTrace_tagNameWithBatchNum(self):
model = self._get_seq_model()
@ -2118,7 +2120,7 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase):
_ObservedSummary(logdir=self.train_dir, tag=u'batch_2'),
},
)
self.assertIsNotNone(self._get_trace_file(logdir=self.train_dir))
self.assertEqual(1, self._count_trace_file(logdir=self.train_dir))
def test_TensorBoard_autoTrace_profileBatchRangeSingle(self):
model = self._get_seq_model()
@ -2142,7 +2144,30 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase):
_ObservedSummary(logdir=self.train_dir, tag=u'batch_2'),
},
)
self.assertIsNotNone(self._get_trace_file(logdir=self.train_dir))
self.assertEqual(1, self._count_trace_file(logdir=self.train_dir))
def test_TensorBoard_autoTrace_profileBatchRangeTwice(self):
  """Fits twice with one TensorBoard callback and expects two trace files.

  The same callback instance (created with `profile_batch='10,10'`) is passed
  to two consecutive `model.fit` calls with identical arguments. Afterwards
  `_count_trace_file` must report exactly 2 for `self.train_dir`.
  """
  seq_model = self._get_seq_model()
  features, labels = np.ones((10, 10, 10, 1)), np.ones((10, 1))
  tb_cbk = keras.callbacks.TensorBoard(
      self.logdir, histogram_freq=1, profile_batch='10,10', write_graph=False)

  # Both training runs use exactly the same arguments and share `tb_cbk`.
  for _ in range(2):
    seq_model.fit(
        features,
        labels,
        batch_size=3,
        epochs=10,
        validation_data=(features, labels),
        callbacks=[tb_cbk])

  trace_file_count = self._count_trace_file(logdir=self.train_dir)
  self.assertEqual(2, trace_file_count)
# Test case that replicates a GitHub issue.
# https://github.com/tensorflow/tensorflow/issues/37543
@ -2162,7 +2187,7 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase):
callbacks=[keras.callbacks.TensorBoard(logdir, profile_batch=1)],
)
# Verifies trace exists in the first logdir.
self.assertIsNotNone(self._get_trace_file(logdir=logdir))
self.assertEqual(1, self._count_trace_file(logdir=logdir))
logdir = os.path.join(self.get_temp_dir(), 'tb2')
model.fit(
np.zeros((64, 1)),
@ -2171,7 +2196,7 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase):
callbacks=[keras.callbacks.TensorBoard(logdir, profile_batch=2)],
)
# Verifies trace exists in the second logdir.
self.assertIsNotNone(self._get_trace_file(logdir=logdir))
self.assertEqual(1, self._count_trace_file(logdir=logdir))
def test_TensorBoard_autoTrace_profileBatchRange(self):
model = self._get_seq_model()
@ -2195,7 +2220,7 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase):
_ObservedSummary(logdir=self.train_dir, tag=u'batch_3'),
},
)
self.assertIsNotNone(self._get_trace_file(logdir=self.train_dir))
self.assertEqual(1, self._count_trace_file(logdir=self.train_dir))
def test_TensorBoard_autoTrace_profileInvalidBatchRange(self):
with self.assertRaises(ValueError):
@ -2237,7 +2262,7 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase):
# Enabled trace only on the 10000th batch, thus it should be empty.
self.assertEmpty(summary_file.tensors)
self.assertIsNone(self._get_trace_file(logdir=self.train_dir))
self.assertEqual(0, self._count_trace_file(logdir=self.train_dir))
class MostRecentlyModifiedFileMatchingPatternTest(test.TestCase):