Return a iterator from summary_iterator instead of using a generator, to allow reusing the iterator after end of file.

PiperOrigin-RevId: 321401390
Change-Id: I2d08d6312cead7f97fb572360a631f8c8754d418
This commit is contained in:
Michael Banfield 2020-07-15 11:24:33 -07:00 committed by TensorFlower Gardener
parent 4c7d80b96a
commit ac988f3bb8
3 changed files with 94 additions and 6 deletions

View File

@ -7275,6 +7275,7 @@ py_tests(
size = "small", size = "small",
srcs = [ srcs = [
"summary/plugin_asset_test.py", "summary/plugin_asset_test.py",
"summary/summary_iterator_test.py",
"summary/summary_test.py", "summary/summary_test.py",
"summary/writer/writer_test.py", "summary/writer/writer_test.py",
], ],

View File

@ -24,10 +24,26 @@ from tensorflow.python.lib.io import tf_record
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
class _SummaryIterator(object):
"""Yields `Event` protocol buffers from a given path."""
def __init__(self, path):
self._tf_record_iterator = tf_record.tf_record_iterator(path)
def __iter__(self):
return self
def __next__(self):
r = next(self._tf_record_iterator)
return event_pb2.Event.FromString(r)
next = __next__
@tf_export(v1=['train.summary_iterator']) @tf_export(v1=['train.summary_iterator'])
def summary_iterator(path): def summary_iterator(path):
# pylint: disable=line-too-long # pylint: disable=line-too-long
"""An iterator for reading `Event` protocol buffers from an event file. """Returns a iterator for reading `Event` protocol buffers from an event file.
You can use this function to read events written to an event file. It returns You can use this function to read events written to an event file. It returns
a Python iterator that yields `Event` protocol buffers. a Python iterator that yields `Event` protocol buffers.
@ -51,6 +67,18 @@ def summary_iterator(path):
if v.tag == 'loss': if v.tag == 'loss':
print(v.simple_value) print(v.simple_value)
``` ```
Example: Continuously check for new summary values.
```python
summaries = tf.compat.v1.train.summary_iterator(path to events file)
while True:
for e in summaries:
for v in e.summary.value:
if v.tag == 'loss':
print(v.simple_value)
# Wait for a bit before checking the file for any new events
time.sleep(wait time)
```
See the protocol buffer definitions of See the protocol buffer definitions of
[Event](https://www.tensorflow.org/code/tensorflow/core/util/event.proto) [Event](https://www.tensorflow.org/code/tensorflow/core/util/event.proto)
@ -61,9 +89,7 @@ def summary_iterator(path):
Args: Args:
path: The path to an event file created by a `SummaryWriter`. path: The path to an event file created by a `SummaryWriter`.
Yields: Returns:
`Event` protocol buffers. A iterator that yields `Event` protocol buffers
""" """
# pylint: enable=line-too-long return _SummaryIterator(path)
for r in tf_record.tf_record_iterator(path):
yield event_pb2.Event.FromString(r)

View File

@ -0,0 +1,61 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.python.summary.summary_iterator."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import glob
import os.path
from tensorflow.core.util import event_pb2
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
from tensorflow.python.summary import summary_iterator
from tensorflow.python.summary.writer import writer
class SummaryIteratorTestCase(test.TestCase):
@test_util.run_deprecated_v1
def testSummaryIteratorEventsAddedAfterEndOfFile(self):
test_dir = os.path.join(self.get_temp_dir(), "events")
with writer.FileWriter(test_dir) as w:
session_log_start = event_pb2.SessionLog.START
w.add_session_log(event_pb2.SessionLog(status=session_log_start), 1)
w.flush()
path = glob.glob(os.path.join(test_dir, "event*"))[0]
rr = summary_iterator.summary_iterator(path)
# The first event should list the file_version.
ev = next(rr)
self.assertEqual("brain.Event:2", ev.file_version)
# The next event should be the START message.
ev = next(rr)
self.assertEqual(1, ev.step)
self.assertEqual(session_log_start, ev.session_log.status)
# Reached EOF.
self.assertRaises(StopIteration, lambda: next(rr))
w.add_session_log(event_pb2.SessionLog(status=session_log_start), 2)
w.flush()
# The new event is read, after previously seeing EOF.
ev = next(rr)
self.assertEqual(2, ev.step)
self.assertEqual(session_log_start, ev.session_log.status)
# Get EOF again.
self.assertRaises(StopIteration, lambda: next(rr))
if __name__ == "__main__":
test.main()