Reset the global_train_batch in each training.
PiperOrigin-RevId: 303016514 Change-Id: I6af560f7f6e94c359600c2913a9dd426f062b921
This commit is contained in:
parent
2eb1429580
commit
c94d33c7d3
@ -2005,6 +2005,7 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector):
|
|||||||
self._should_trace = not (self._start_batch == 0 and self._stop_batch == 0)
|
self._should_trace = not (self._start_batch == 0 and self._stop_batch == 0)
|
||||||
|
|
||||||
def on_train_begin(self, logs=None):
|
def on_train_begin(self, logs=None):
|
||||||
|
self._global_train_batch = 0
|
||||||
self._push_writer(self._train_writer, self._train_step)
|
self._push_writer(self._train_writer, self._train_step)
|
||||||
|
|
||||||
def on_train_end(self, logs=None):
|
def on_train_end(self, logs=None):
|
||||||
|
@ -2018,14 +2018,16 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase):
|
|||||||
run_eagerly=testing_utils.should_run_eagerly())
|
run_eagerly=testing_utils.should_run_eagerly())
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def _get_trace_file(self, logdir):
|
def _count_trace_file(self, logdir):
|
||||||
profile_dir = os.path.join(logdir, 'plugins', 'profile')
|
profile_dir = os.path.join(logdir, 'plugins', 'profile')
|
||||||
|
count = 0
|
||||||
for (dirpath, dirnames, filenames) in os.walk(profile_dir):
|
for (dirpath, dirnames, filenames) in os.walk(profile_dir):
|
||||||
|
del dirpath # unused
|
||||||
del dirnames # unused
|
del dirnames # unused
|
||||||
for filename in filenames:
|
for filename in filenames:
|
||||||
if filename.endswith('.trace.json.gz'):
|
if filename.endswith('.trace.json.gz'):
|
||||||
return os.path.join(dirpath, filename)
|
count += 1
|
||||||
return None
|
return count
|
||||||
|
|
||||||
def fitModelAndAssertKerasModelWritten(self, model):
|
def fitModelAndAssertKerasModelWritten(self, model):
|
||||||
x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
|
x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
|
||||||
@ -2095,7 +2097,7 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase):
|
|||||||
_ObservedSummary(logdir=self.train_dir, tag=u'batch_1'),
|
_ObservedSummary(logdir=self.train_dir, tag=u'batch_1'),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
self.assertIsNotNone(self._get_trace_file(logdir=self.train_dir))
|
self.assertEqual(1, self._count_trace_file(logdir=self.train_dir))
|
||||||
|
|
||||||
def test_TensorBoard_autoTrace_tagNameWithBatchNum(self):
|
def test_TensorBoard_autoTrace_tagNameWithBatchNum(self):
|
||||||
model = self._get_seq_model()
|
model = self._get_seq_model()
|
||||||
@ -2118,7 +2120,7 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase):
|
|||||||
_ObservedSummary(logdir=self.train_dir, tag=u'batch_2'),
|
_ObservedSummary(logdir=self.train_dir, tag=u'batch_2'),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
self.assertIsNotNone(self._get_trace_file(logdir=self.train_dir))
|
self.assertEqual(1, self._count_trace_file(logdir=self.train_dir))
|
||||||
|
|
||||||
def test_TensorBoard_autoTrace_profileBatchRangeSingle(self):
|
def test_TensorBoard_autoTrace_profileBatchRangeSingle(self):
|
||||||
model = self._get_seq_model()
|
model = self._get_seq_model()
|
||||||
@ -2142,7 +2144,30 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase):
|
|||||||
_ObservedSummary(logdir=self.train_dir, tag=u'batch_2'),
|
_ObservedSummary(logdir=self.train_dir, tag=u'batch_2'),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
self.assertIsNotNone(self._get_trace_file(logdir=self.train_dir))
|
self.assertEqual(1, self._count_trace_file(logdir=self.train_dir))
|
||||||
|
|
||||||
|
def test_TensorBoard_autoTrace_profileBatchRangeTwice(self):
|
||||||
|
model = self._get_seq_model()
|
||||||
|
x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
|
||||||
|
tb_cbk = keras.callbacks.TensorBoard(
|
||||||
|
self.logdir, histogram_freq=1, profile_batch='10,10', write_graph=False)
|
||||||
|
|
||||||
|
model.fit(
|
||||||
|
x,
|
||||||
|
y,
|
||||||
|
batch_size=3,
|
||||||
|
epochs=10,
|
||||||
|
validation_data=(x, y),
|
||||||
|
callbacks=[tb_cbk])
|
||||||
|
|
||||||
|
model.fit(
|
||||||
|
x,
|
||||||
|
y,
|
||||||
|
batch_size=3,
|
||||||
|
epochs=10,
|
||||||
|
validation_data=(x, y),
|
||||||
|
callbacks=[tb_cbk])
|
||||||
|
self.assertEqual(2, self._count_trace_file(logdir=self.train_dir))
|
||||||
|
|
||||||
# Test case that replicates a Github issue.
|
# Test case that replicates a Github issue.
|
||||||
# https://github.com/tensorflow/tensorflow/issues/37543
|
# https://github.com/tensorflow/tensorflow/issues/37543
|
||||||
@ -2162,7 +2187,7 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase):
|
|||||||
callbacks=[keras.callbacks.TensorBoard(logdir, profile_batch=1)],
|
callbacks=[keras.callbacks.TensorBoard(logdir, profile_batch=1)],
|
||||||
)
|
)
|
||||||
# Verifies trace exists in the first logdir.
|
# Verifies trace exists in the first logdir.
|
||||||
self.assertIsNotNone(self._get_trace_file(logdir=logdir))
|
self.assertEqual(1, self._count_trace_file(logdir=logdir))
|
||||||
logdir = os.path.join(self.get_temp_dir(), 'tb2')
|
logdir = os.path.join(self.get_temp_dir(), 'tb2')
|
||||||
model.fit(
|
model.fit(
|
||||||
np.zeros((64, 1)),
|
np.zeros((64, 1)),
|
||||||
@ -2171,7 +2196,7 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase):
|
|||||||
callbacks=[keras.callbacks.TensorBoard(logdir, profile_batch=2)],
|
callbacks=[keras.callbacks.TensorBoard(logdir, profile_batch=2)],
|
||||||
)
|
)
|
||||||
# Verifies trace exists in the second logdir.
|
# Verifies trace exists in the second logdir.
|
||||||
self.assertIsNotNone(self._get_trace_file(logdir=logdir))
|
self.assertEqual(1, self._count_trace_file(logdir=logdir))
|
||||||
|
|
||||||
def test_TensorBoard_autoTrace_profileBatchRange(self):
|
def test_TensorBoard_autoTrace_profileBatchRange(self):
|
||||||
model = self._get_seq_model()
|
model = self._get_seq_model()
|
||||||
@ -2195,7 +2220,7 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase):
|
|||||||
_ObservedSummary(logdir=self.train_dir, tag=u'batch_3'),
|
_ObservedSummary(logdir=self.train_dir, tag=u'batch_3'),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
self.assertIsNotNone(self._get_trace_file(logdir=self.train_dir))
|
self.assertEqual(1, self._count_trace_file(logdir=self.train_dir))
|
||||||
|
|
||||||
def test_TensorBoard_autoTrace_profileInvalidBatchRange(self):
|
def test_TensorBoard_autoTrace_profileInvalidBatchRange(self):
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
@ -2237,7 +2262,7 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase):
|
|||||||
|
|
||||||
# Enabled trace only on the 10000th batch, thus it should be empty.
|
# Enabled trace only on the 10000th batch, thus it should be empty.
|
||||||
self.assertEmpty(summary_file.tensors)
|
self.assertEmpty(summary_file.tensors)
|
||||||
self.assertIsNone(self._get_trace_file(logdir=self.train_dir))
|
self.assertEqual(0, self._count_trace_file(logdir=self.train_dir))
|
||||||
|
|
||||||
|
|
||||||
class MostRecentlyModifiedFileMatchingPatternTest(test.TestCase):
|
class MostRecentlyModifiedFileMatchingPatternTest(test.TestCase):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user