Merge pull request #16546 from borispf/est-raw-summary-py3

Fix raw summary metrics for Estimators in python 3
This commit is contained in:
Martin Wicke 2018-02-15 17:19:52 -08:00 committed by GitHub
commit 72f215bad3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 9 additions and 8 deletions

View File

@ -1114,7 +1114,7 @@ def _write_dict_to_summary(output_dir,
isinstance(dictionary[key], np.int32) or
isinstance(dictionary[key], int)):
summary_proto.value.add(tag=key, simple_value=int(dictionary[key]))
elif isinstance(dictionary[key], six.string_types):
elif isinstance(dictionary[key], six.binary_type):
try:
summ = summary_pb2.Summary.FromString(dictionary[key])
for i, _ in enumerate(summ.value):

View File

@ -80,18 +80,18 @@ def dummy_model_fn(features, labels, params):
_, _, _ = features, labels, params
def check_eventfile_for_keyword(keyword, est):
def check_eventfile_for_keyword(keyword, dir_):
"""Checks event files for the keyword."""
writer_cache.FileWriterCache.clear()
# Get last Event written.
event_paths = glob.glob(os.path.join(est.model_dir, 'events*'))
event_paths = glob.glob(os.path.join(dir_, 'events*'))
last_event = None
for last_event in summary_iterator.summary_iterator(event_paths[-1]):
if last_event.summary is not None:
if last_event.summary.value:
if keyword in last_event.summary.value[0].tag:
for value in last_event.summary.value:
if keyword in value.tag:
return True
return False
@ -610,7 +610,7 @@ class EstimatorTrainTest(test.TestCase):
# Make sure nothing is stuck in limbo.
writer_cache.FileWriterCache.clear()
if check_eventfile_for_keyword('loss', est):
if check_eventfile_for_keyword('loss', est.model_dir):
return
self.fail('{} should be part of reported summaries.'.format('loss'))
@ -1290,8 +1290,9 @@ class EstimatorEvaluateTest(test.TestCase):
# Make sure nothing is stuck in limbo.
writer_cache.FileWriterCache.clear()
# Get last Event written.
if check_eventfile_for_keyword('image', est):
# Get last evaluation Event written.
if check_eventfile_for_keyword('image',
os.path.join(est.model_dir, 'eval')):
return
self.fail('{} should be part of reported summaries.'.format('image'))