Adds scalar summary for the centered biases.
Change: 123889304
This commit is contained in:
parent
38a3a365b3
commit
73defa5d9d
@ -21,24 +21,24 @@ from __future__ import print_function
|
||||
import functools
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.ops import standard_ops
|
||||
|
||||
__all__ = ['summarize_tensor', 'summarize_activation', 'summarize_tensors',
|
||||
__all__ = ['assert_summary_tag_unique', 'is_summary_tag_unique',
|
||||
'summarize_tensor', 'summarize_activation', 'summarize_tensors',
|
||||
'summarize_collection', 'summarize_variables', 'summarize_weights',
|
||||
'summarize_biases', 'summarize_activations']
|
||||
'summarize_biases', 'summarize_activations',]
|
||||
|
||||
# TODO(wicke): add more unit tests for summarization functions.
|
||||
|
||||
|
||||
def _assert_summary_tag_unique(tag):
|
||||
for summary in ops.get_collection(ops.GraphKeys.SUMMARIES):
|
||||
old_tag = tensor_util.constant_value(summary.op.inputs[0])
|
||||
if tag.encode() == old_tag:
|
||||
raise ValueError('Conflict with summary tag: %s exists on summary %s %s' %
|
||||
(tag, summary, old_tag))
|
||||
def assert_summary_tag_unique(tag):
|
||||
if not is_summary_tag_unique(tag):
|
||||
raise ValueError('Conflict with summary tag: %s already exists' % tag)
|
||||
|
||||
|
||||
def _add_scalar_summary(tensor, tag=None):
|
||||
@ -56,7 +56,7 @@ def _add_scalar_summary(tensor, tag=None):
|
||||
"""
|
||||
tensor.get_shape().assert_has_rank(0)
|
||||
tag = tag or tensor.op.name
|
||||
_assert_summary_tag_unique(tag)
|
||||
assert_summary_tag_unique(tag)
|
||||
return standard_ops.scalar_summary(tag, tensor, name='%s_summary' % tag)
|
||||
|
||||
|
||||
@ -74,10 +74,26 @@ def _add_histogram_summary(tensor, tag=None):
|
||||
ValueError: If the tag is already in use.
|
||||
"""
|
||||
tag = tag or tensor.op.name
|
||||
_assert_summary_tag_unique(tag)
|
||||
assert_summary_tag_unique(tag)
|
||||
return standard_ops.histogram_summary(tag, tensor, name='%s_summary' % tag)
|
||||
|
||||
|
||||
def is_summary_tag_unique(tag):
|
||||
"""Checks if a summary tag is unique.
|
||||
|
||||
Args:
|
||||
tag: The tag to use
|
||||
|
||||
Returns:
|
||||
True if the summary tag is unique.
|
||||
"""
|
||||
existing_tags = [tensor_util.constant_value(summary.op.inputs[0])
|
||||
for summary in ops.get_collection(ops.GraphKeys.SUMMARIES)]
|
||||
existing_tags = [name.tolist() if isinstance(name, np.ndarray) else name
|
||||
for name in existing_tags]
|
||||
return tag.encode() not in existing_tags
|
||||
|
||||
|
||||
def summarize_activation(op):
|
||||
"""Summarize an activation.
|
||||
|
||||
|
@ -267,7 +267,9 @@ class _DNNLinearCombinedBaseEstimator(estimator.BaseEstimator):
|
||||
collections=[self._centered_bias_weight_collection,
|
||||
ops.GraphKeys.VARIABLES],
|
||||
name="centered_bias_weight")
|
||||
# TODO(zakaria): Create summaries for centered_bias
|
||||
logging_ops.scalar_summary(
|
||||
["centered_bias_%d" % cb for cb in range(self._num_label_columns())],
|
||||
array_ops.reshape(centered_bias, [-1]))
|
||||
return centered_bias
|
||||
|
||||
def _centered_bias_step(self, targets, weight_tensor):
|
||||
|
@ -35,7 +35,6 @@ from tensorflow.contrib.learn.python.learn import monitors as monitors_lib
|
||||
from tensorflow.python.client import session as tf_session
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import data_flow_ops
|
||||
from tensorflow.python.ops import logging_ops
|
||||
@ -429,14 +428,8 @@ def evaluate(graph,
|
||||
global_step_tensor = contrib_variables.assert_or_get_global_step(
|
||||
graph, global_step_tensor)
|
||||
|
||||
# Add scalar summaries for every tensor in evaluation dict if there is not
|
||||
# one existing already or it's a string.
|
||||
existing_tags = [tensor_util.constant_value(summary.op.inputs[0])
|
||||
for summary in ops.get_collection(ops.GraphKeys.SUMMARIES)]
|
||||
existing_tags = [name.tolist() if isinstance(name, np.ndarray) else name
|
||||
for name in existing_tags]
|
||||
for key, value in eval_dict.items():
|
||||
if key.encode() in existing_tags:
|
||||
if not summaries.is_summary_tag_unique(key):
|
||||
continue
|
||||
if isinstance(value, ops.Tensor):
|
||||
summaries.summarize_tensor(value, tag=key)
|
||||
|
Loading…
x
Reference in New Issue
Block a user