Reset `_global_train_batch` to 0 at the start of each training run (in `on_train_begin`).
PiperOrigin-RevId: 303016514 Change-Id: I6af560f7f6e94c359600c2913a9dd426f062b921
This commit is contained in:
parent
2eb1429580
commit
c94d33c7d3
@ -2005,6 +2005,7 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector):
|
||||
self._should_trace = not (self._start_batch == 0 and self._stop_batch == 0)
|
||||
|
||||
def on_train_begin(self, logs=None):
  """Resets the batch counter and pushes the train writer at training start.

  Args:
    logs: Not read by this method; accepted as part of the hook signature.
  """
  # Restart the batch count for every training run so that a value left over
  # from an earlier run with the same callback instance is not carried over.
  self._global_train_batch = 0
  self._push_writer(self._train_writer, self._train_step)
|
||||
|
||||
def on_train_end(self, logs=None):
|
||||
|
@ -2018,14 +2018,16 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase):
|
||||
run_eagerly=testing_utils.should_run_eagerly())
|
||||
return model
|
||||
|
||||
def _get_trace_file(self, logdir):
  """Returns the path of the first trace file found under `logdir`.

  Walks `<logdir>/plugins/profile` and stops at the first file whose name
  ends with '.trace.json.gz'.

  Args:
    logdir: Directory that contains the `plugins/profile` subtree.

  Returns:
    The full path of the first matching file, or None if there is none
    (including when the profile directory does not exist).
  """
  profile_dir = os.path.join(logdir, 'plugins', 'profile')
  for dirpath, _, filenames in os.walk(profile_dir):
    for filename in filenames:
      if filename.endswith('.trace.json.gz'):
        return os.path.join(dirpath, filename)
  return None


def _count_trace_file(self, logdir):
  """Counts the trace files found under `logdir`.

  Walks `<logdir>/plugins/profile` recursively and counts every file whose
  name ends with '.trace.json.gz'.

  Args:
    logdir: Directory that contains the `plugins/profile` subtree.

  Returns:
    The number of matching files; 0 if the profile directory does not exist.
  """
  profile_dir = os.path.join(logdir, 'plugins', 'profile')
  count = 0
  for _, _, filenames in os.walk(profile_dir):
    for filename in filenames:
      if filename.endswith('.trace.json.gz'):
        count += 1
  return count
|
||||
|
||||
def fitModelAndAssertKerasModelWritten(self, model):
|
||||
x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
|
||||
@ -2095,7 +2097,7 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase):
|
||||
_ObservedSummary(logdir=self.train_dir, tag=u'batch_1'),
|
||||
},
|
||||
)
|
||||
self.assertIsNotNone(self._get_trace_file(logdir=self.train_dir))
|
||||
self.assertEqual(1, self._count_trace_file(logdir=self.train_dir))
|
||||
|
||||
def test_TensorBoard_autoTrace_tagNameWithBatchNum(self):
|
||||
model = self._get_seq_model()
|
||||
@ -2118,7 +2120,7 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase):
|
||||
_ObservedSummary(logdir=self.train_dir, tag=u'batch_2'),
|
||||
},
|
||||
)
|
||||
self.assertIsNotNone(self._get_trace_file(logdir=self.train_dir))
|
||||
self.assertEqual(1, self._count_trace_file(logdir=self.train_dir))
|
||||
|
||||
def test_TensorBoard_autoTrace_profileBatchRangeSingle(self):
|
||||
model = self._get_seq_model()
|
||||
@ -2142,7 +2144,30 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase):
|
||||
_ObservedSummary(logdir=self.train_dir, tag=u'batch_2'),
|
||||
},
|
||||
)
|
||||
self.assertIsNotNone(self._get_trace_file(logdir=self.train_dir))
|
||||
self.assertEqual(1, self._count_trace_file(logdir=self.train_dir))
|
||||
|
||||
def test_TensorBoard_autoTrace_profileBatchRangeTwice(self):
  """Fits twice with one TensorBoard callback and expects two trace files."""
  model = self._get_seq_model()
  x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
  tb_cbk = keras.callbacks.TensorBoard(
      self.logdir, histogram_freq=1, profile_batch='10,10', write_graph=False)

  # Reuse the same callback instance for two consecutive training runs.
  for _ in range(2):
    model.fit(
        x,
        y,
        batch_size=3,
        epochs=10,
        validation_data=(x, y),
        callbacks=[tb_cbk])

  # Two trace files are expected in total after the two runs.
  self.assertEqual(2, self._count_trace_file(logdir=self.train_dir))
|
||||
|
||||
# Test case that replicates a Github issue.
|
||||
# https://github.com/tensorflow/tensorflow/issues/37543
|
||||
@ -2162,7 +2187,7 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase):
|
||||
callbacks=[keras.callbacks.TensorBoard(logdir, profile_batch=1)],
|
||||
)
|
||||
# Verifies trace exists in the first logdir.
|
||||
self.assertIsNotNone(self._get_trace_file(logdir=logdir))
|
||||
self.assertEqual(1, self._count_trace_file(logdir=logdir))
|
||||
logdir = os.path.join(self.get_temp_dir(), 'tb2')
|
||||
model.fit(
|
||||
np.zeros((64, 1)),
|
||||
@ -2171,7 +2196,7 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase):
|
||||
callbacks=[keras.callbacks.TensorBoard(logdir, profile_batch=2)],
|
||||
)
|
||||
# Verifies trace exists in the second logdir.
|
||||
self.assertIsNotNone(self._get_trace_file(logdir=logdir))
|
||||
self.assertEqual(1, self._count_trace_file(logdir=logdir))
|
||||
|
||||
def test_TensorBoard_autoTrace_profileBatchRange(self):
|
||||
model = self._get_seq_model()
|
||||
@ -2195,7 +2220,7 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase):
|
||||
_ObservedSummary(logdir=self.train_dir, tag=u'batch_3'),
|
||||
},
|
||||
)
|
||||
self.assertIsNotNone(self._get_trace_file(logdir=self.train_dir))
|
||||
self.assertEqual(1, self._count_trace_file(logdir=self.train_dir))
|
||||
|
||||
def test_TensorBoard_autoTrace_profileInvalidBatchRange(self):
|
||||
with self.assertRaises(ValueError):
|
||||
@ -2237,7 +2262,7 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase):
|
||||
|
||||
# Enabled trace only on the 10000th batch, thus it should be empty.
|
||||
self.assertEmpty(summary_file.tensors)
|
||||
self.assertIsNone(self._get_trace_file(logdir=self.train_dir))
|
||||
self.assertEqual(0, self._count_trace_file(logdir=self.train_dir))
|
||||
|
||||
|
||||
class MostRecentlyModifiedFileMatchingPatternTest(test.TestCase):
|
||||
|
Loading…
x
Reference in New Issue
Block a user