diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 4df088dbafc..c0a54ca1dff 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -2100,6 +2100,7 @@ py_library( deps = [ ":client", ":framework", + ":platform", ":protos_all_py", ":pywrap_tensorflow", ":summary_ops", diff --git a/tensorflow/python/summary/summary.py b/tensorflow/python/summary/summary.py index ac80b07634a..5dbde1c5477 100644 --- a/tensorflow/python/summary/summary.py +++ b/tensorflow/python/summary/summary.py @@ -44,6 +44,7 @@ from tensorflow.python.ops import gen_logging_ops as _gen_logging_ops # pylint: disable=unused-import from tensorflow.python.ops.summary_ops import tensor_summary # pylint: enable=unused-import +from tensorflow.python.platform import tf_logging as _logging from tensorflow.python.util.all_util import remove_undocumented from tensorflow.python.util import compat as _compat @@ -55,6 +56,19 @@ def _collect(val, collections, default_collections): _ops.add_to_collection(key, val) +def _clean_tag(name): + # In the past, the first argument to summary ops was a tag, which allowed + # spaces. Since now we pass in the name, spaces are disallowed; to ease the + # transition and support backwards compatbility, we will convert the spaces + # to underscores (and also warn about it). + if name is not None and ' ' in name: + _logging.warning( + 'Summary tag name %s contains spaces; replacing with underscores.' % + name) + name = name.replace(' ', '_') + return name + + def scalar(name, tensor, collections=None): """Outputs a `Summary` protocol buffer containing a single scalar value. @@ -73,6 +87,7 @@ def scalar(name, tensor, collections=None): Raises: ValueError: If tensor has the wrong shape or type. """ + name = _clean_tag(name) with _ops.name_scope(name, None, [tensor]) as scope: # pylint: disable=protected-access val = _gen_logging_ops._scalar_summary( @@ -124,6 +139,7 @@ def image(name, tensor, max_outputs=3, collections=None): A scalar `Tensor` of type `string`. The serialized `Summary` protocol buffer. """ + name = _clean_tag(name) with _ops.name_scope(name, None, [tensor]) as scope: # pylint: disable=protected-access val = _gen_logging_ops._image_summary( @@ -158,6 +174,7 @@ def histogram(name, values, collections=None): buffer. """ # pylint: enable=line-too-long + name = _clean_tag(name) with _ops.name_scope(name, 'HistogramSummary', [values]) as scope: # pylint: disable=protected-access val = _gen_logging_ops._histogram_summary( @@ -199,6 +216,7 @@ def audio(name, tensor, sample_rate, max_outputs=3, collections=None): buffer. """ # pylint: enable=line-too-long + name = _clean_tag(name) with _ops.name_scope(name, None, [tensor]) as scope: # pylint: disable=protected-access sample_rate = _ops.convert_to_tensor( @@ -237,6 +255,7 @@ def merge(inputs, collections=None, name=None): buffer resulting from the merging. """ # pylint: enable=line-too-long + name = _clean_tag(name) with _ops.name_scope(name, 'Merge', inputs): # pylint: disable=protected-access val = _gen_logging_ops._merge_summary(inputs=inputs, name=name) diff --git a/tensorflow/python/summary/summary_test.py b/tensorflow/python/summary/summary_test.py index 8a01cf6bb56..85332333cff 100644 --- a/tensorflow/python/summary/summary_test.py +++ b/tensorflow/python/summary/summary_test.py @@ -65,6 +65,11 @@ class ScalarSummaryTest(tf.test.TestCase): self.assertEqual(len(summary.value), 1) self.assertEqual(summary.value[0].tag, 'outer/inner') + def testSummaryNameConversion(self): + c = tf.constant(3) + s = tf.summary.scalar('name with spaces', c) + self.assertEqual(s.op.name, 'name_with_spaces') + if __name__ == '__main__': tf.test.main()