From 96713dd4635040556fd5061159010b7934e74672 Mon Sep 17 00:00:00 2001
From: Shivani Agrawal <shivaniagrawal@google.com>
Date: Thu, 14 Mar 2019 17:01:07 -0700
Subject: [PATCH] Separating out summary_util for distribution strategy to
 avoid possible circular dependencies.

PiperOrigin-RevId: 238549028
---
 tensorflow/contrib/compiler/BUILD             |  2 +-
 tensorflow/contrib/compiler/xla.py            |  2 +-
 tensorflow/python/BUILD                       |  2 +-
 tensorflow/python/distribute/BUILD            | 10 ++++
 .../python/distribute/summary_op_util.py      | 48 +++++++++++++++++++
 tensorflow/python/ops/summary_op_util.py      | 26 ----------
 tensorflow/python/summary/summary.py          | 13 ++---
 7 files changed, 68 insertions(+), 35 deletions(-)
 create mode 100644 tensorflow/python/distribute/summary_op_util.py

diff --git a/tensorflow/contrib/compiler/BUILD b/tensorflow/contrib/compiler/BUILD
index 79c61589112..839682afdc6 100644
--- a/tensorflow/contrib/compiler/BUILD
+++ b/tensorflow/contrib/compiler/BUILD
@@ -64,9 +64,9 @@ py_library(
         "//tensorflow/python:control_flow_ops",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:platform",
-        "//tensorflow/python:summary_op_util",
         "//tensorflow/python:util",
         "//tensorflow/python:variable_scope",
+        "//tensorflow/python/distribute:summary_op_util",
         "//tensorflow/python/estimator:estimator_py",
     ],
 )
diff --git a/tensorflow/contrib/compiler/xla.py b/tensorflow/contrib/compiler/xla.py
index 238c6ab1366..2ccb27da12f 100644
--- a/tensorflow/contrib/compiler/xla.py
+++ b/tensorflow/contrib/compiler/xla.py
@@ -25,11 +25,11 @@ from six.moves import xrange  # pylint: disable=redefined-builtin
 from tensorflow.compiler.jit.ops import xla_ops
 from tensorflow.compiler.jit.ops import xla_ops_grad  # pylint: disable=unused-import
 from tensorflow.core.framework import attr_value_pb2
+from tensorflow.python.distribute import summary_op_util
 from tensorflow.python.estimator import model_fn as model_fn_lib
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import summary_op_util
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import compat
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 818ee60fb43..e3c026c81c5 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -5342,7 +5342,6 @@ py_library(
     srcs_version = "PY2AND3",
     visibility = ["//visibility:public"],
     deps = [
-        ":distribute",
         ":framework",
         ":framework_for_generated_wrappers",
         ":platform",
@@ -5371,6 +5370,7 @@ py_library(
         ":summary_ops_gen",
         ":summary_ops_v2",
         ":util",
+        "//tensorflow/python/distribute:summary_op_util",
         "//tensorflow/python/eager:context",
         "@six_archive//:six",
     ],
diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD
index 308bd428d05..9f9e285cce2 100644
--- a/tensorflow/python/distribute/BUILD
+++ b/tensorflow/python/distribute/BUILD
@@ -469,6 +469,16 @@ py_test(
     ],
 )
 
+py_library(
+    name = "summary_op_util",
+    srcs = ["summary_op_util.py"],
+    deps = [
+        ":distribute_lib",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:tensor_util",
+    ],
+)
+
 py_library(
     name = "values",
     srcs = ["values.py"],
diff --git a/tensorflow/python/distribute/summary_op_util.py b/tensorflow/python/distribute/summary_op_util.py
new file mode 100644
index 00000000000..1c7086b365b
--- /dev/null
+++ b/tensorflow/python/distribute/summary_op_util.py
@@ -0,0 +1,48 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#==============================================================================
+"""Contains utility functions used by summary ops in distribution strategy."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+from tensorflow.python.distribute import distribution_strategy_context
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
+
+
+def skip_summary():
+  """Determines if summary should be skipped.
+
+  If using multiple replicas in distributed strategy, skip summaries on all
+  replicas except the first one (replica_id=0).
+
+  Returns:
+    True if the summary is skipped; False otherwise.
+  """
+
+  # TODO(priyag): Add a new optional argument that will provide multiple
+  # alternatives to override default behavior. (e.g. run on last replica,
+  # compute sum or mean across replicas).
+  replica_context = distribution_strategy_context.get_replica_context()
+  if not replica_context:
+    return False
+  # TODO(b/118385803): when replica_id of _TPUReplicaContext is properly
+  # initialized, remember to change here as well.
+  replica_id = replica_context.replica_id_in_sync_group
+  if isinstance(replica_id, ops.Tensor):
+    replica_id = tensor_util.constant_value(replica_id)
+  return replica_id and replica_id > 0
diff --git a/tensorflow/python/ops/summary_op_util.py b/tensorflow/python/ops/summary_op_util.py
index 93d8d50842b..37b80d5e20b 100644
--- a/tensorflow/python/ops/summary_op_util.py
+++ b/tensorflow/python/ops/summary_op_util.py
@@ -21,9 +21,7 @@ from __future__ import print_function
 import contextlib
 import re
 
-from tensorflow.python.distribute import distribution_strategy_context
 from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_util
 from tensorflow.python.platform import tf_logging
 
 
@@ -44,30 +42,6 @@ def collect(val, collections, default_collections):
 _INVALID_TAG_CHARACTERS = re.compile(r'[^-/\w\.]')
 
 
-def skip_summary():
-  """Determines if summary should be skipped.
-
-  If using multiple replicas in distributed strategy, skip summaries on all
-  replicas except the first one (replica_id=0).
-
-  Returns:
-    True if the summary is skipped; False otherwise.
-  """
-
-  # TODO(priyag): Add a new optional argument that will provide multiple
-  # alternatives to override default behavior. (e.g. run on last replica,
-  # compute sum or mean across replicas).
-  replica_context = distribution_strategy_context.get_replica_context()
-  if not replica_context:
-    return False
-  # TODO(b/118385803): when replica_id of _TPUReplicaContext is properly
-  # initialized, remember to change here as well.
-  replica_id = replica_context.replica_id_in_sync_group
-  if isinstance(replica_id, ops.Tensor):
-    replica_id = tensor_util.constant_value(replica_id)
-  return replica_id and replica_id > 0
-
-
 def clean_tag(name):
   """Cleans a tag. Removes illegal characters for instance.
 
diff --git a/tensorflow/python/summary/summary.py b/tensorflow/python/summary/summary.py
index a01feb3dde0..4802dbb6572 100644
--- a/tensorflow/python/summary/summary.py
+++ b/tensorflow/python/summary/summary.py
@@ -35,6 +35,7 @@ from tensorflow.core.util.event_pb2 import SessionLog
 from tensorflow.core.util.event_pb2 import TaggedRunMetadata
 # pylint: enable=unused-import
 
+from tensorflow.python.distribute import summary_op_util as _distribute_summary_op_util
 from tensorflow.python.eager import context as _context
 from tensorflow.python.framework import constant_op as _constant_op
 from tensorflow.python.framework import dtypes as _dtypes
@@ -74,7 +75,7 @@ def scalar(name, tensor, collections=None, family=None):
   Raises:
     ValueError: If tensor has the wrong shape or type.
   """
-  if _summary_op_util.skip_summary():
+  if _distribute_summary_op_util.skip_summary():
     return _constant_op.constant('')
   with _summary_op_util.summary_scope(
       name, family, values=[tensor]) as (tag, scope):
@@ -129,7 +130,7 @@ def image(name, tensor, max_outputs=3, collections=None, family=None):
     A scalar `Tensor` of type `string`. The serialized `Summary` protocol
     buffer.
   """
-  if _summary_op_util.skip_summary():
+  if _distribute_summary_op_util.skip_summary():
     return _constant_op.constant('')
   with _summary_op_util.summary_scope(
       name, family, values=[tensor]) as (tag, scope):
@@ -169,7 +170,7 @@ def histogram(name, values, collections=None, family=None):
     A scalar `Tensor` of type `string`. The serialized `Summary` protocol
     buffer.
   """
-  if _summary_op_util.skip_summary():
+  if _distribute_summary_op_util.skip_summary():
     return _constant_op.constant('')
   with _summary_op_util.summary_scope(
       name, family, values=[values],
@@ -216,7 +217,7 @@ def audio(name, tensor, sample_rate, max_outputs=3, collections=None,
     A scalar `Tensor` of type `string`. The serialized `Summary` protocol
     buffer.
   """
-  if _summary_op_util.skip_summary():
+  if _distribute_summary_op_util.skip_summary():
     return _constant_op.constant('')
   with _summary_op_util.summary_scope(
       name, family=family, values=[tensor]) as (tag, scope):
@@ -313,7 +314,7 @@ def tensor_summary(name,
 
   serialized_summary_metadata = summary_metadata.SerializeToString()
 
-  if _summary_op_util.skip_summary():
+  if _distribute_summary_op_util.skip_summary():
     return _constant_op.constant('')
   with _summary_op_util.summary_scope(
       name, family, values=[tensor]) as (tag, scope):
@@ -363,7 +364,7 @@ def merge(inputs, collections=None, name=None):
     raise RuntimeError(
         'Merging tf.summary.* ops is not compatible with eager execution. '
         'Use tf.contrib.summary instead.')
-  if _summary_op_util.skip_summary():
+  if _distribute_summary_op_util.skip_summary():
     return _constant_op.constant('')
   name = _summary_op_util.clean_tag(name)
   with _ops.name_scope(name, 'Merge', inputs):