Adds ability to set a "family" attribute in Tensorflow summaries, which
controls the "tab name" of the summary that is displayed. This solution keeps using name_scope to keep names unique, but then prefixes the tag with the family name if provided. PiperOrigin-RevId: 158278922
This commit is contained in:
parent
611c82b5be
commit
7b5302af0a
@ -35,6 +35,14 @@ def _Collect(val, collections, default_collections):
|
||||
ops.add_to_collection(key, val)
|
||||
|
||||
|
||||
# TODO(dandelion): As currently implemented, this op has several problems.
|
||||
# The 'summary_description' field is passed but not used by the kernel.
|
||||
# The 'name' field is used to creat a scope and passed down via name=scope,
|
||||
# but gen_logging_ops._tensor_summary ignores this parameter and uses the
|
||||
# kernel's op name as the name. This is ok because scope and the op name
|
||||
# are identical, but it's probably worthwhile to fix.
|
||||
# Finally, because of the complications above, this currently does not
|
||||
# support the family= attribute added to other summaries in cl/156791589.
|
||||
def tensor_summary( # pylint: disable=invalid-name
|
||||
name,
|
||||
tensor,
|
||||
|
@ -36,6 +36,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import contextlib as _contextlib
|
||||
import re as _re
|
||||
|
||||
from google.protobuf import json_format as _json_format
|
||||
@ -104,7 +105,46 @@ def _clean_tag(name):
|
||||
return name
|
||||
|
||||
|
||||
def scalar(name, tensor, collections=None):
|
||||
@_contextlib.contextmanager
|
||||
def _summary_scope(name, family=None, default_name=None, values=None):
|
||||
"""Enters a scope used for the summary and yields both the name and tag.
|
||||
|
||||
To ensure that the summary tag name is always unique, we create a name scope
|
||||
based on `name` and use the full scope name in the tag.
|
||||
|
||||
If `family` is set, then the tag name will be '<family>/<scope_name>', where
|
||||
`scope_name` is `<outer_scope>/<family>/<name>`. This ensures that `family`
|
||||
is always the prefix of the tag (and unmodified), while ensuring the scope
|
||||
respects the outer scope from this this summary was created.
|
||||
|
||||
Args:
|
||||
name: A name for the generated summary node.
|
||||
family: Optional; if provided, used as the prefix of the summary tag name.
|
||||
default_name: Optional; if provided, used as default name of the summary.
|
||||
values: Optional; passed as `values` parameter to name_scope.
|
||||
|
||||
Yields:
|
||||
A tuple `(tag, scope)`, both of which are unique and should be used for the
|
||||
tag and the scope for the summary to output.
|
||||
"""
|
||||
name = _clean_tag(name)
|
||||
family = _clean_tag(family)
|
||||
# Use family name in the scope to ensure uniqueness of scope/tag.
|
||||
scope_base_name = name if family is None else '{}/{}'.format(family, name)
|
||||
with _ops.name_scope(scope_base_name, default_name, values=values) as scope:
|
||||
if family is None:
|
||||
tag = scope.rstrip('/')
|
||||
else:
|
||||
# Prefix our scope with family again so it displays in the right tab.
|
||||
tag = '{}/{}'.format(family, scope.rstrip('/'))
|
||||
# Note: tag is not 100% unique if the user explicitly enters a scope with
|
||||
# the same name as family, then later enter it again before summaries.
|
||||
# This is very contrived though, and we opt here to let it be a runtime
|
||||
# exception if tags do indeed collide.
|
||||
yield (tag, scope)
|
||||
|
||||
|
||||
def scalar(name, tensor, collections=None, family=None):
|
||||
"""Outputs a `Summary` protocol buffer containing a single scalar value.
|
||||
|
||||
The generated Summary has a Tensor.proto containing the input Tensor.
|
||||
@ -115,6 +155,8 @@ def scalar(name, tensor, collections=None):
|
||||
tensor: A real numeric Tensor containing a single value.
|
||||
collections: Optional list of graph collections keys. The new summary op is
|
||||
added to these collections. Defaults to `[GraphKeys.SUMMARIES]`.
|
||||
family: Optional; if provided, used as the prefix of the summary tag name,
|
||||
which controls the tab name used for display on Tensorboard.
|
||||
|
||||
Returns:
|
||||
A scalar `Tensor` of type `string`. Which contains a `Summary` protobuf.
|
||||
@ -122,16 +164,14 @@ def scalar(name, tensor, collections=None):
|
||||
Raises:
|
||||
ValueError: If tensor has the wrong shape or type.
|
||||
"""
|
||||
name = _clean_tag(name)
|
||||
with _ops.name_scope(name, None, [tensor]) as scope:
|
||||
with _summary_scope(name, family, values=[tensor]) as (tag, scope):
|
||||
# pylint: disable=protected-access
|
||||
val = _gen_logging_ops._scalar_summary(
|
||||
tags=scope.rstrip('/'), values=tensor, name=scope)
|
||||
val = _gen_logging_ops._scalar_summary(tags=tag, values=tensor, name=scope)
|
||||
_collect(val, collections, [_ops.GraphKeys.SUMMARIES])
|
||||
return val
|
||||
|
||||
|
||||
def image(name, tensor, max_outputs=3, collections=None):
|
||||
def image(name, tensor, max_outputs=3, collections=None, family=None):
|
||||
"""Outputs a `Summary` protocol buffer with images.
|
||||
|
||||
The summary has up to `max_outputs` summary values containing images. The
|
||||
@ -169,24 +209,22 @@ def image(name, tensor, max_outputs=3, collections=None):
|
||||
max_outputs: Max number of batch elements to generate images for.
|
||||
collections: Optional list of ops.GraphKeys. The collections to add the
|
||||
summary to. Defaults to [_ops.GraphKeys.SUMMARIES]
|
||||
family: Optional; if provided, used as the prefix of the summary tag name,
|
||||
which controls the tab name used for display on Tensorboard.
|
||||
|
||||
Returns:
|
||||
A scalar `Tensor` of type `string`. The serialized `Summary` protocol
|
||||
buffer.
|
||||
"""
|
||||
name = _clean_tag(name)
|
||||
with _ops.name_scope(name, None, [tensor]) as scope:
|
||||
with _summary_scope(name, family, values=[tensor]) as (tag, scope):
|
||||
# pylint: disable=protected-access
|
||||
val = _gen_logging_ops._image_summary(
|
||||
tag=scope.rstrip('/'),
|
||||
tensor=tensor,
|
||||
max_images=max_outputs,
|
||||
name=scope)
|
||||
tag=tag, tensor=tensor, max_images=max_outputs, name=scope)
|
||||
_collect(val, collections, [_ops.GraphKeys.SUMMARIES])
|
||||
return val
|
||||
|
||||
|
||||
def histogram(name, values, collections=None):
|
||||
def histogram(name, values, collections=None, family=None):
|
||||
# pylint: disable=line-too-long
|
||||
"""Outputs a `Summary` protocol buffer with a histogram.
|
||||
|
||||
@ -208,22 +246,24 @@ def histogram(name, values, collections=None):
|
||||
build the histogram.
|
||||
collections: Optional list of graph collections keys. The new summary op is
|
||||
added to these collections. Defaults to `[GraphKeys.SUMMARIES]`.
|
||||
family: Optional; if provided, used as the prefix of the summary tag name,
|
||||
which controls the tab name used for display on Tensorboard.
|
||||
|
||||
Returns:
|
||||
A scalar `Tensor` of type `string`. The serialized `Summary` protocol
|
||||
buffer.
|
||||
"""
|
||||
# pylint: enable=line-too-long
|
||||
name = _clean_tag(name)
|
||||
with _ops.name_scope(name, 'HistogramSummary', [values]) as scope:
|
||||
with _summary_scope(name, family, values=[values],
|
||||
default_name='HistogramSummary') as (tag, scope):
|
||||
# pylint: disable=protected-access
|
||||
val = _gen_logging_ops._histogram_summary(
|
||||
tag=scope.rstrip('/'), values=values, name=scope)
|
||||
tag=tag, values=values, name=scope)
|
||||
_collect(val, collections, [_ops.GraphKeys.SUMMARIES])
|
||||
return val
|
||||
|
||||
|
||||
def audio(name, tensor, sample_rate, max_outputs=3, collections=None):
|
||||
def audio(name, tensor, sample_rate, max_outputs=3, collections=None,
|
||||
family=None):
|
||||
# pylint: disable=line-too-long
|
||||
"""Outputs a `Summary` protocol buffer with audio.
|
||||
|
||||
@ -250,23 +290,20 @@ def audio(name, tensor, sample_rate, max_outputs=3, collections=None):
|
||||
max_outputs: Max number of batch elements to generate audio for.
|
||||
collections: Optional list of ops.GraphKeys. The collections to add the
|
||||
summary to. Defaults to [_ops.GraphKeys.SUMMARIES]
|
||||
family: Optional; if provided, used as the prefix of the summary tag name,
|
||||
which controls the tab name used for display on Tensorboard.
|
||||
|
||||
Returns:
|
||||
A scalar `Tensor` of type `string`. The serialized `Summary` protocol
|
||||
buffer.
|
||||
"""
|
||||
# pylint: enable=line-too-long
|
||||
name = _clean_tag(name)
|
||||
with _ops.name_scope(name, None, [tensor]) as scope:
|
||||
with _summary_scope(name, family=family, values=[tensor]) as (tag, scope):
|
||||
# pylint: disable=protected-access
|
||||
sample_rate = _ops.convert_to_tensor(
|
||||
sample_rate, dtype=_dtypes.float32, name='sample_rate')
|
||||
val = _gen_logging_ops._audio_summary_v2(
|
||||
tag=scope.rstrip('/'),
|
||||
tensor=tensor,
|
||||
max_outputs=max_outputs,
|
||||
sample_rate=sample_rate,
|
||||
name=scope)
|
||||
tag=tag, tensor=tensor, max_outputs=max_outputs,
|
||||
sample_rate=sample_rate, name=scope)
|
||||
_collect(val, collections, [_ops.GraphKeys.SUMMARIES])
|
||||
return val
|
||||
|
||||
|
@ -19,11 +19,9 @@ from __future__ import print_function
|
||||
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
|
||||
from google.protobuf import json_format
|
||||
|
||||
from tensorflow.core.framework import summary_pb2
|
||||
from tensorflow.core.framework import types_pb2
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import meta_graph
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import variables
|
||||
@ -46,6 +44,29 @@ class ScalarSummaryTest(test.TestCase):
|
||||
self.assertEqual(values[0].tag, 'outer/inner')
|
||||
self.assertEqual(values[0].simple_value, 3.0)
|
||||
|
||||
def testScalarSummaryWithFamily(self):
|
||||
with self.test_session() as s:
|
||||
i = constant_op.constant(7)
|
||||
with ops.name_scope('outer'):
|
||||
im1 = summary_lib.scalar('inner', i, family='family')
|
||||
self.assertEquals(im1.op.name, 'outer/family/inner')
|
||||
im2 = summary_lib.scalar('inner', i, family='family')
|
||||
self.assertEquals(im2.op.name, 'outer/family/inner_1')
|
||||
sm1, sm2 = s.run([im1, im2])
|
||||
summary = summary_pb2.Summary()
|
||||
|
||||
summary.ParseFromString(sm1)
|
||||
values = summary.value
|
||||
self.assertEqual(len(values), 1)
|
||||
self.assertEqual(values[0].tag, 'family/outer/family/inner')
|
||||
self.assertEqual(values[0].simple_value, 7.0)
|
||||
|
||||
summary.ParseFromString(sm2)
|
||||
values = summary.value
|
||||
self.assertEqual(len(values), 1)
|
||||
self.assertEqual(values[0].tag, 'family/outer/family/inner_1')
|
||||
self.assertEqual(values[0].simple_value, 7.0)
|
||||
|
||||
def testSummarizingVariable(self):
|
||||
with self.test_session() as s:
|
||||
c = constant_op.constant(42.0)
|
||||
@ -75,6 +96,22 @@ class ScalarSummaryTest(test.TestCase):
|
||||
expected = sorted('outer/inner/image/{}'.format(i) for i in xrange(3))
|
||||
self.assertEqual(tags, expected)
|
||||
|
||||
def testImageSummaryWithFamily(self):
|
||||
with self.test_session() as s:
|
||||
i = array_ops.ones((5, 2, 3, 1))
|
||||
with ops.name_scope('outer'):
|
||||
im = summary_lib.image('inner', i, max_outputs=3, family='family')
|
||||
self.assertEquals(im.op.name, 'outer/family/inner')
|
||||
summary_str = s.run(im)
|
||||
summary = summary_pb2.Summary()
|
||||
summary.ParseFromString(summary_str)
|
||||
values = summary.value
|
||||
self.assertEqual(len(values), 3)
|
||||
tags = sorted(v.tag for v in values)
|
||||
expected = sorted('family/outer/family/inner/image/{}'.format(i)
|
||||
for i in xrange(3))
|
||||
self.assertEqual(tags, expected)
|
||||
|
||||
def testHistogramSummary(self):
|
||||
with self.test_session() as s:
|
||||
i = array_ops.ones((5, 4, 4, 3))
|
||||
@ -86,6 +123,48 @@ class ScalarSummaryTest(test.TestCase):
|
||||
self.assertEqual(len(summary.value), 1)
|
||||
self.assertEqual(summary.value[0].tag, 'outer/inner')
|
||||
|
||||
def testHistogramSummaryWithFamily(self):
|
||||
with self.test_session() as s:
|
||||
i = array_ops.ones((5, 4, 4, 3))
|
||||
with ops.name_scope('outer'):
|
||||
summ_op = summary_lib.histogram('inner', i, family='family')
|
||||
self.assertEquals(summ_op.op.name, 'outer/family/inner')
|
||||
summary_str = s.run(summ_op)
|
||||
summary = summary_pb2.Summary()
|
||||
summary.ParseFromString(summary_str)
|
||||
self.assertEqual(len(summary.value), 1)
|
||||
self.assertEqual(summary.value[0].tag, 'family/outer/family/inner')
|
||||
|
||||
def testAudioSummary(self):
|
||||
with self.test_session() as s:
|
||||
i = array_ops.ones((5, 3, 4))
|
||||
with ops.name_scope('outer'):
|
||||
aud = summary_lib.audio('inner', i, 0.2, max_outputs=3)
|
||||
summary_str = s.run(aud)
|
||||
summary = summary_pb2.Summary()
|
||||
summary.ParseFromString(summary_str)
|
||||
values = summary.value
|
||||
self.assertEqual(len(values), 3)
|
||||
tags = sorted(v.tag for v in values)
|
||||
expected = sorted('outer/inner/audio/{}'.format(i) for i in xrange(3))
|
||||
self.assertEqual(tags, expected)
|
||||
|
||||
def testAudioSummaryWithFamily(self):
|
||||
with self.test_session() as s:
|
||||
i = array_ops.ones((5, 3, 4))
|
||||
with ops.name_scope('outer'):
|
||||
aud = summary_lib.audio('inner', i, 0.2, max_outputs=3, family='family')
|
||||
self.assertEquals(aud.op.name, 'outer/family/inner')
|
||||
summary_str = s.run(aud)
|
||||
summary = summary_pb2.Summary()
|
||||
summary.ParseFromString(summary_str)
|
||||
values = summary.value
|
||||
self.assertEqual(len(values), 3)
|
||||
tags = sorted(v.tag for v in values)
|
||||
expected = sorted('family/outer/family/inner/audio/{}'.format(i)
|
||||
for i in xrange(3))
|
||||
self.assertEqual(tags, expected)
|
||||
|
||||
def testSummaryNameConversion(self):
|
||||
c = constant_op.constant(3)
|
||||
s = summary_lib.scalar('name with spaces', c)
|
||||
@ -97,6 +176,34 @@ class ScalarSummaryTest(test.TestCase):
|
||||
s3 = summary_lib.scalar('/name/with/leading/slash', c)
|
||||
self.assertEqual(s3.op.name, 'name/with/leading/slash')
|
||||
|
||||
def testSummaryWithFamilyMetaGraphExport(self):
|
||||
with ops.name_scope('outer'):
|
||||
i = constant_op.constant(11)
|
||||
summ = summary_lib.scalar('inner', i)
|
||||
self.assertEquals(summ.op.name, 'outer/inner')
|
||||
summ_f = summary_lib.scalar('inner', i, family='family')
|
||||
self.assertEquals(summ_f.op.name, 'outer/family/inner')
|
||||
|
||||
metagraph_def, _ = meta_graph.export_scoped_meta_graph(export_scope='outer')
|
||||
|
||||
with ops.Graph().as_default() as g:
|
||||
meta_graph.import_scoped_meta_graph(metagraph_def, graph=g,
|
||||
import_scope='new_outer')
|
||||
# The summaries should exist, but with outer scope renamed.
|
||||
new_summ = g.get_tensor_by_name('new_outer/inner:0')
|
||||
new_summ_f = g.get_tensor_by_name('new_outer/family/inner:0')
|
||||
|
||||
# However, the tags are unaffected.
|
||||
with self.test_session() as s:
|
||||
new_summ_str, new_summ_f_str = s.run([new_summ, new_summ_f])
|
||||
new_summ_pb = summary_pb2.Summary()
|
||||
new_summ_pb.ParseFromString(new_summ_str)
|
||||
self.assertEquals('outer/inner', new_summ_pb.value[0].tag)
|
||||
new_summ_f_pb = summary_pb2.Summary()
|
||||
new_summ_f_pb.ParseFromString(new_summ_f_str)
|
||||
self.assertEquals('family/outer/family/inner',
|
||||
new_summ_f_pb.value[0].tag)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -30,7 +30,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "audio"
|
||||
argspec: "args=[\'name\', \'tensor\', \'sample_rate\', \'max_outputs\', \'collections\'], varargs=None, keywords=None, defaults=[\'3\', \'None\'], "
|
||||
argspec: "args=[\'name\', \'tensor\', \'sample_rate\', \'max_outputs\', \'collections\', \'family\'], varargs=None, keywords=None, defaults=[\'3\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "get_summary_description"
|
||||
@ -38,11 +38,11 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "histogram"
|
||||
argspec: "args=[\'name\', \'values\', \'collections\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'name\', \'values\', \'collections\', \'family\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "image"
|
||||
argspec: "args=[\'name\', \'tensor\', \'max_outputs\', \'collections\'], varargs=None, keywords=None, defaults=[\'3\', \'None\'], "
|
||||
argspec: "args=[\'name\', \'tensor\', \'max_outputs\', \'collections\', \'family\'], varargs=None, keywords=None, defaults=[\'3\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "merge"
|
||||
@ -54,7 +54,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "scalar"
|
||||
argspec: "args=[\'name\', \'tensor\', \'collections\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'name\', \'tensor\', \'collections\', \'family\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "tensor_summary"
|
||||
|
Loading…
Reference in New Issue
Block a user