diff --git a/tensorflow/python/distribute/values.py b/tensorflow/python/distribute/values.py
index cbed12b36e7..7c62d71ebff 100644
--- a/tensorflow/python/distribute/values.py
+++ b/tensorflow/python/distribute/values.py
@@ -1089,13 +1089,20 @@ ops.register_tensor_conversion_function(MirroredVariable,
 
 
 def _enclosing_tpu_context():
-  # pylint: disable=protected-access
-  tpu_context = ops.get_default_graph()._get_control_flow_context()
-  # pylint: enable=protected-access
-  while tpu_context is not None and not isinstance(
-      tpu_context, control_flow_ops.XLAControlFlowContext):
-    tpu_context = tpu_context.outer_context
-  return tpu_context
+  """Returns the XLAControlFlowContext, which exists inside a tpu.rewrite()."""
+  graph = ops.get_default_graph()
+  while graph is not None:
+    # pylint: disable=protected-access
+    context_ = graph._get_control_flow_context()
+    # pylint: enable=protected-access
+    while context_ is not None:
+      if isinstance(context_, control_flow_ops.XLAControlFlowContext):
+        return context_
+      context_ = context_.outer_context
+    # This may be a FuncGraph due to defuns or v2 control flow. We need to
+    # find the original graph with the XLAControlFlowContext.
+    graph = getattr(graph, "outer_graph", None)
+  return None
 
 
 def is_distributed_variable(v):
diff --git a/tensorflow/python/tpu/tpu.py b/tensorflow/python/tpu/tpu.py
index f8f7d3d2177..f65068b3f7f 100644
--- a/tensorflow/python/tpu/tpu.py
+++ b/tensorflow/python/tpu/tpu.py
@@ -163,6 +163,23 @@ def core(num):
   return "device:TPU_REPLICATED_CORE:{}".format(num)
 
 
+def _enclosing_tpu_context_and_graph():
+  """Returns the TPUReplicateContext and its associated graph."""
+  graph = ops.get_default_graph()
+  while graph is not None:
+    # pylint: disable=protected-access
+    context_ = graph._get_control_flow_context()
+    # pylint: enable=protected-access
+    while context_ is not None:
+      if isinstance(context_, TPUReplicateContext):
+        return context_, graph
+      context_ = context_.outer_context
+    graph = getattr(graph, "outer_graph", None)
+  raise ValueError("get_replicated_var_handle() called without "
+                   "TPUReplicateContext. This shouldn't happen. Please file "
+                   "a bug.")
+
+
 class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
   """A `ControlFlowContext` for nodes inside a TPU computation.
 
@@ -230,14 +247,15 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
     # so the TPUReplicatedInput nodes go inside the TPUReplicateContext scope
     # instead.
 
-    # pylint: disable=protected-access
-    graph = ops.get_default_graph()
-    saved_context = graph._get_control_flow_context()
-    graph._set_control_flow_context(self.outer_context)
-    handle = tpu_ops.tpu_replicated_input(
-        [v.handle for v in vars_], name=name + "/handle")
-    graph._set_control_flow_context(saved_context)
-    # pylint: enable=protected-access
+    _, graph = _enclosing_tpu_context_and_graph()
+    with graph.as_default():
+      # pylint: disable=protected-access
+      saved_context = graph._get_control_flow_context()
+      graph._set_control_flow_context(self.outer_context)
+      handle = tpu_ops.tpu_replicated_input(
+          [v.handle for v in vars_], name=name + "/handle")
+      graph._set_control_flow_context(saved_context)
+      # pylint: enable=protected-access
     self._replicated_vars[name] = handle
     return handle
 
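
Both hunks apply the same search pattern: scan the current graph's chain of control-flow contexts for the target context type, and if it isn't found, hop to the enclosing graph via `outer_graph` (present on the `FuncGraph`s created by defuns and v2 control flow) and repeat. A minimal standalone sketch of that traversal, using hypothetical stand-in classes rather than real TensorFlow graphs; only the loop structure mirrors the diff:

```python
# Sketch of the nested search in _enclosing_tpu_context /
# _enclosing_tpu_context_and_graph. Graph and Context below are
# hypothetical stand-ins, not TensorFlow classes.


class Context(object):

  def __init__(self, outer_context=None, is_xla=False):
    self.outer_context = outer_context  # next-enclosing context, or None
    self.is_xla = is_xla


class Graph(object):

  def __init__(self, control_flow_context=None, outer_graph=None):
    self._control_flow_context = control_flow_context
    self.outer_graph = outer_graph  # None for the root graph


def find_xla_context(graph):
  """Walks contexts inner-to-outer, then graphs inner-to-outer."""
  while graph is not None:
    context_ = graph._control_flow_context
    while context_ is not None:  # search this graph's context chain
      if context_.is_xla:
        return context_
      context_ = context_.outer_context
    # Not found here: an inner graph's context chain does not include the
    # outer graph's contexts, so continue the search one graph out.
    graph = getattr(graph, "outer_graph", None)
  return None


# The XLA context lives on the root graph; starting the search from an
# inner FuncGraph-like graph still finds it.
root = Graph(Context(is_xla=True))
inner = Graph(Context(is_xla=False), outer_graph=root)
assert find_xla_context(inner) is root._control_flow_context
```

This is why the old single-graph loop in values.py returned `None` inside a defun under `tpu.rewrite()`: it only followed `outer_context` within the innermost graph and never reached the graph that actually holds the `XLAControlFlowContext`.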