Merge pull request #16546 from borispf/est-raw-summary-py3
Fix raw summary metrics for Estimators in python 3
This commit is contained in:
commit
72f215bad3
@ -1114,7 +1114,7 @@ def _write_dict_to_summary(output_dir,
|
|||||||
isinstance(dictionary[key], np.int32) or
|
isinstance(dictionary[key], np.int32) or
|
||||||
isinstance(dictionary[key], int)):
|
isinstance(dictionary[key], int)):
|
||||||
summary_proto.value.add(tag=key, simple_value=int(dictionary[key]))
|
summary_proto.value.add(tag=key, simple_value=int(dictionary[key]))
|
||||||
elif isinstance(dictionary[key], six.string_types):
|
elif isinstance(dictionary[key], six.binary_type):
|
||||||
try:
|
try:
|
||||||
summ = summary_pb2.Summary.FromString(dictionary[key])
|
summ = summary_pb2.Summary.FromString(dictionary[key])
|
||||||
for i, _ in enumerate(summ.value):
|
for i, _ in enumerate(summ.value):
|
||||||
|
@ -80,18 +80,18 @@ def dummy_model_fn(features, labels, params):
|
|||||||
_, _, _ = features, labels, params
|
_, _, _ = features, labels, params
|
||||||
|
|
||||||
|
|
||||||
def check_eventfile_for_keyword(keyword, est):
|
def check_eventfile_for_keyword(keyword, dir_):
|
||||||
"""Checks event files for the keyword."""
|
"""Checks event files for the keyword."""
|
||||||
|
|
||||||
writer_cache.FileWriterCache.clear()
|
writer_cache.FileWriterCache.clear()
|
||||||
|
|
||||||
# Get last Event written.
|
# Get last Event written.
|
||||||
event_paths = glob.glob(os.path.join(est.model_dir, 'events*'))
|
event_paths = glob.glob(os.path.join(dir_, 'events*'))
|
||||||
last_event = None
|
last_event = None
|
||||||
for last_event in summary_iterator.summary_iterator(event_paths[-1]):
|
for last_event in summary_iterator.summary_iterator(event_paths[-1]):
|
||||||
if last_event.summary is not None:
|
if last_event.summary is not None:
|
||||||
if last_event.summary.value:
|
for value in last_event.summary.value:
|
||||||
if keyword in last_event.summary.value[0].tag:
|
if keyword in value.tag:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
@ -610,7 +610,7 @@ class EstimatorTrainTest(test.TestCase):
|
|||||||
# Make sure nothing is stuck in limbo.
|
# Make sure nothing is stuck in limbo.
|
||||||
writer_cache.FileWriterCache.clear()
|
writer_cache.FileWriterCache.clear()
|
||||||
|
|
||||||
if check_eventfile_for_keyword('loss', est):
|
if check_eventfile_for_keyword('loss', est.model_dir):
|
||||||
return
|
return
|
||||||
self.fail('{} should be part of reported summaries.'.format('loss'))
|
self.fail('{} should be part of reported summaries.'.format('loss'))
|
||||||
|
|
||||||
@ -1290,8 +1290,9 @@ class EstimatorEvaluateTest(test.TestCase):
|
|||||||
# Make sure nothing is stuck in limbo.
|
# Make sure nothing is stuck in limbo.
|
||||||
writer_cache.FileWriterCache.clear()
|
writer_cache.FileWriterCache.clear()
|
||||||
|
|
||||||
# Get last Event written.
|
# Get last evaluation Event written.
|
||||||
if check_eventfile_for_keyword('image', est):
|
if check_eventfile_for_keyword('image',
|
||||||
|
os.path.join(est.model_dir, 'eval')):
|
||||||
return
|
return
|
||||||
self.fail('{} should be part of reported summaries.'.format('image'))
|
self.fail('{} should be part of reported summaries.'.format('image'))
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user