diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index d68bb928233..42fafc5d9cc 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -7275,6 +7275,7 @@ py_tests( size = "small", srcs = [ "summary/plugin_asset_test.py", + "summary/summary_iterator_test.py", "summary/summary_test.py", "summary/writer/writer_test.py", ], diff --git a/tensorflow/python/summary/summary_iterator.py b/tensorflow/python/summary/summary_iterator.py index 5840a7a124e..35c6fa03039 100644 --- a/tensorflow/python/summary/summary_iterator.py +++ b/tensorflow/python/summary/summary_iterator.py @@ -24,10 +24,26 @@ from tensorflow.python.lib.io import tf_record from tensorflow.python.util.tf_export import tf_export +class _SummaryIterator(object): + """Yields `Event` protocol buffers from a given path.""" + + def __init__(self, path): + self._tf_record_iterator = tf_record.tf_record_iterator(path) + + def __iter__(self): + return self + + def __next__(self): + r = next(self._tf_record_iterator) + return event_pb2.Event.FromString(r) + + next = __next__ + + @tf_export(v1=['train.summary_iterator']) def summary_iterator(path): # pylint: disable=line-too-long - """An iterator for reading `Event` protocol buffers from an event file. + """Returns a iterator for reading `Event` protocol buffers from an event file. You can use this function to read events written to an event file. It returns a Python iterator that yields `Event` protocol buffers. @@ -51,6 +67,18 @@ def summary_iterator(path): if v.tag == 'loss': print(v.simple_value) ``` + Example: Continuously check for new summary values. + + ```python + summaries = tf.compat.v1.train.summary_iterator(path to events file) + while True: + for e in summaries: + for v in e.summary.value: + if v.tag == 'loss': + print(v.simple_value) + # Wait for a bit before checking the file for any new events + time.sleep(wait time) + ``` See the protocol buffer definitions of [Event](https://www.tensorflow.org/code/tensorflow/core/util/event.proto) @@ -61,9 +89,7 @@ def summary_iterator(path): Args: path: The path to an event file created by a `SummaryWriter`. - Yields: - `Event` protocol buffers. + Returns: + A iterator that yields `Event` protocol buffers """ - # pylint: enable=line-too-long - for r in tf_record.tf_record_iterator(path): - yield event_pb2.Event.FromString(r) + return _SummaryIterator(path) diff --git a/tensorflow/python/summary/summary_iterator_test.py b/tensorflow/python/summary/summary_iterator_test.py new file mode 100644 index 00000000000..d41d8d4c775 --- /dev/null +++ b/tensorflow/python/summary/summary_iterator_test.py @@ -0,0 +1,61 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.python.summary.summary_iterator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import glob +import os.path + +from tensorflow.core.util import event_pb2 +from tensorflow.python.framework import test_util +from tensorflow.python.platform import test +from tensorflow.python.summary import summary_iterator +from tensorflow.python.summary.writer import writer + + +class SummaryIteratorTestCase(test.TestCase): + + @test_util.run_deprecated_v1 + def testSummaryIteratorEventsAddedAfterEndOfFile(self): + test_dir = os.path.join(self.get_temp_dir(), "events") + with writer.FileWriter(test_dir) as w: + session_log_start = event_pb2.SessionLog.START + w.add_session_log(event_pb2.SessionLog(status=session_log_start), 1) + w.flush() + path = glob.glob(os.path.join(test_dir, "event*"))[0] + rr = summary_iterator.summary_iterator(path) + # The first event should list the file_version. + ev = next(rr) + self.assertEqual("brain.Event:2", ev.file_version) + # The next event should be the START message. + ev = next(rr) + self.assertEqual(1, ev.step) + self.assertEqual(session_log_start, ev.session_log.status) + # Reached EOF. + self.assertRaises(StopIteration, lambda: next(rr)) + w.add_session_log(event_pb2.SessionLog(status=session_log_start), 2) + w.flush() + # The new event is read, after previously seeing EOF. + ev = next(rr) + self.assertEqual(2, ev.step) + self.assertEqual(session_log_start, ev.session_log.status) + # Get EOF again. + self.assertRaises(StopIteration, lambda: next(rr)) + +if __name__ == "__main__": + test.main()