Fix LSTMs in TPUStrategy. We need to check the outer graph for the control flow context to find out whether we're in a tpu.replicate().

PiperOrigin-RevId: 263821933
This commit is contained in:
Jonathan Hseu 2019-08-16 12:16:53 -07:00 committed by TensorFlower Gardener
parent 72167ef255
commit 3a73493dfe
2 changed files with 40 additions and 15 deletions

View File

@ -1089,13 +1089,20 @@ ops.register_tensor_conversion_function(MirroredVariable,
def _enclosing_tpu_context():
  """Returns the XLAControlFlowContext, which exists inside a tpu.rewrite().

  The search starts at the default graph and, for each graph, follows the
  control flow context chain through `outer_context`. If no
  `XLAControlFlowContext` is found on that graph, the search moves on to the
  graph's `outer_graph` (when it has one) and repeats.

  Returns:
    The first `control_flow_ops.XLAControlFlowContext` encountered, or `None`
    if neither the default graph nor any of its outer graphs has one.
  """
  graph = ops.get_default_graph()
  while graph is not None:
    # pylint: disable=protected-access
    context_ = graph._get_control_flow_context()
    # pylint: enable=protected-access
    while context_ is not None:
      if isinstance(context_, control_flow_ops.XLAControlFlowContext):
        return context_
      context_ = context_.outer_context
    # This may be a FuncGraph due to defuns or v2 control flow. We need to
    # find the original graph with the XLAControlFlowContext.
    # Graphs without an `outer_graph` attribute end the search.
    graph = getattr(graph, "outer_graph", None)
  return None
def is_distributed_variable(v):

View File

@ -163,6 +163,23 @@ def core(num):
return "device:TPU_REPLICATED_CORE:{}".format(num)
def _enclosing_tpu_context_and_graph():
  """Finds the enclosing TPUReplicateContext and the graph that holds it.

  Starting from the default graph, each graph's control flow context chain is
  followed through `outer_context`. When a graph has no `TPUReplicateContext`
  in its chain, the search continues on its `outer_graph`, if any.

  Returns:
    A `(context, graph)` tuple: the first `TPUReplicateContext` found and the
    graph whose control flow context chain contained it.

  Raises:
    ValueError: If no `TPUReplicateContext` is found on the default graph or
      on any graph reachable through `outer_graph`.
  """
  candidate_graph = ops.get_default_graph()
  while candidate_graph is not None:
    # pylint: disable=protected-access
    ctx = candidate_graph._get_control_flow_context()
    # pylint: enable=protected-access
    # Climb the context chain until a TPUReplicateContext shows up or the
    # chain runs out.
    while ctx is not None and not isinstance(ctx, TPUReplicateContext):
      ctx = ctx.outer_context
    if ctx is not None:
      return ctx, candidate_graph
    # Nothing on this graph; fall back to the graph it is nested in, if any.
    candidate_graph = getattr(candidate_graph, "outer_graph", None)
  raise ValueError("get_replicated_var_handle() called without "
                   "TPUReplicateContext. This shouldn't happen. Please file "
                   "a bug.")
class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
"""A `ControlFlowContext` for nodes inside a TPU computation.
@ -230,14 +247,15 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
# so the TPUReplicatedInput nodes go inside the TPUReplicateContext scope
# instead.
# pylint: disable=protected-access
graph = ops.get_default_graph()
saved_context = graph._get_control_flow_context()
graph._set_control_flow_context(self.outer_context)
handle = tpu_ops.tpu_replicated_input(
[v.handle for v in vars_], name=name + "/handle")
graph._set_control_flow_context(saved_context)
# pylint: enable=protected-access
_, graph = _enclosing_tpu_context_and_graph()
with graph.as_default():
# pylint: disable=protected-access
saved_context = graph._get_control_flow_context()
graph._set_control_flow_context(self.outer_context)
handle = tpu_ops.tpu_replicated_input(
[v.handle for v in vars_], name=name + "/handle")
graph._set_control_flow_context(saved_context)
# pylint: enable=protected-access
self._replicated_vars[name] = handle
return handle