Update keras callbacks test to ensure it doesn't break when we switch to tf.summary.

PiperOrigin-RevId: 354123842
Change-Id: Ib3bd6c2f08f2af1cbb8045b91e9138ba52bc86a6
This commit is contained in:
Scott Zhu 2021-01-27 10:45:42 -08:00 committed by TensorFlower Gardener
parent da0aa4aa32
commit 19d4fc086b

View File

@ -1836,6 +1836,7 @@ class _SummaryFile(object):
self.histograms = set()
self.tensors = set()
self.graph_defs = []
self.convert_from_v2_summary_proto = False
def list_summaries(logdir):
@ -1883,11 +1884,17 @@ def list_summaries(logdir):
'Unexpected summary kind %r in event file %s:\n%r'
% (kind, path, event))
elif kind == 'tensor' and tag != 'keras':
# Check for V2 scalar summaries, which have a different PB
# structure.
if event.summary.value[
0].metadata.plugin_data.plugin_name == 'scalars':
container = result.scalars
# Convert the tf2 summary proto to old style for type checking.
plugin_name = value.metadata.plugin_data.plugin_name
container = {
'images': result.images,
'histograms': result.histograms,
'scalars': result.scalars,
}.get(plugin_name)
if container is not None:
result.convert_from_v2_summary_proto = True
else:
container = result.tensors
container.add(_ObservedSummary(logdir=dirpath, tag=tag))
return result
@ -2143,14 +2150,21 @@ class TestTensorBoardV2(keras_parameterized.TestCase):
_ObservedSummary(logdir=self.train_dir, tag='kernel_0'),
},
)
if summary_file.convert_from_v2_summary_proto:
expected = {
_ObservedSummary(logdir=self.train_dir, tag='bias_0'),
_ObservedSummary(logdir=self.train_dir, tag='kernel_0'),
}
else:
expected = {
_ObservedSummary(logdir=self.train_dir, tag='bias_0/image/0'),
_ObservedSummary(logdir=self.train_dir, tag='kernel_0/image/0'),
_ObservedSummary(logdir=self.train_dir, tag='kernel_0/image/1'),
_ObservedSummary(logdir=self.train_dir, tag='kernel_0/image/2'),
}
self.assertEqual(
self._strip_layer_names(summary_file.images, model_type),
{
_ObservedSummary(logdir=self.train_dir, tag='bias_0/image/0'),
_ObservedSummary(logdir=self.train_dir, tag='kernel_0/image/0'),
_ObservedSummary(logdir=self.train_dir, tag='kernel_0/image/1'),
_ObservedSummary(logdir=self.train_dir, tag='kernel_0/image/2'),
},
expected
)
def test_TensorBoard_projector_callback(self):