Merge pull request #16546 from borispf/est-raw-summary-py3

Fix raw summary metrics for Estimators in python 3
This commit is contained in:
Martin Wicke 2018-02-15 17:19:52 -08:00 committed by GitHub
commit 72f215bad3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 9 additions and 8 deletions

View File

@ -1114,7 +1114,7 @@ def _write_dict_to_summary(output_dir,
isinstance(dictionary[key], np.int32) or isinstance(dictionary[key], np.int32) or
isinstance(dictionary[key], int)): isinstance(dictionary[key], int)):
summary_proto.value.add(tag=key, simple_value=int(dictionary[key])) summary_proto.value.add(tag=key, simple_value=int(dictionary[key]))
elif isinstance(dictionary[key], six.string_types): elif isinstance(dictionary[key], six.binary_type):
try: try:
summ = summary_pb2.Summary.FromString(dictionary[key]) summ = summary_pb2.Summary.FromString(dictionary[key])
for i, _ in enumerate(summ.value): for i, _ in enumerate(summ.value):

View File

@ -80,18 +80,18 @@ def dummy_model_fn(features, labels, params):
_, _, _ = features, labels, params _, _, _ = features, labels, params
def check_eventfile_for_keyword(keyword, est): def check_eventfile_for_keyword(keyword, dir_):
"""Checks event files for the keyword.""" """Checks event files for the keyword."""
writer_cache.FileWriterCache.clear() writer_cache.FileWriterCache.clear()
# Get last Event written. # Get last Event written.
event_paths = glob.glob(os.path.join(est.model_dir, 'events*')) event_paths = glob.glob(os.path.join(dir_, 'events*'))
last_event = None last_event = None
for last_event in summary_iterator.summary_iterator(event_paths[-1]): for last_event in summary_iterator.summary_iterator(event_paths[-1]):
if last_event.summary is not None: if last_event.summary is not None:
if last_event.summary.value: for value in last_event.summary.value:
if keyword in last_event.summary.value[0].tag: if keyword in value.tag:
return True return True
return False return False
@ -610,7 +610,7 @@ class EstimatorTrainTest(test.TestCase):
# Make sure nothing is stuck in limbo. # Make sure nothing is stuck in limbo.
writer_cache.FileWriterCache.clear() writer_cache.FileWriterCache.clear()
if check_eventfile_for_keyword('loss', est): if check_eventfile_for_keyword('loss', est.model_dir):
return return
self.fail('{} should be part of reported summaries.'.format('loss')) self.fail('{} should be part of reported summaries.'.format('loss'))
@ -1290,8 +1290,9 @@ class EstimatorEvaluateTest(test.TestCase):
# Make sure nothing is stuck in limbo. # Make sure nothing is stuck in limbo.
writer_cache.FileWriterCache.clear() writer_cache.FileWriterCache.clear()
# Get last Event written. # Get last evaluation Event written.
if check_eventfile_for_keyword('image', est): if check_eventfile_for_keyword('image',
os.path.join(est.model_dir, 'eval')):
return return
self.fail('{} should be part of reported summaries.'.format('image')) self.fail('{} should be part of reported summaries.'.format('image'))