Fix LSTMs in TPUStrategy. We need to check the outer graph for the control flow context to find out whether we're in a tpu.replicate().
PiperOrigin-RevId: 263821933
This commit is contained in:
parent
72167ef255
commit
3a73493dfe
@ -1089,13 +1089,20 @@ ops.register_tensor_conversion_function(MirroredVariable,
|
||||
|
||||
|
||||
def _enclosing_tpu_context():
  """Returns the XLAControlFlowContext, which exists inside a tpu.rewrite().

  The search starts at the default graph and walks that graph's chain of
  control flow contexts (following `outer_context`). If no
  `XLAControlFlowContext` is found there, the search continues in the graph's
  `outer_graph` (when it has one), and so on outward.

  Returns:
    The first `control_flow_ops.XLAControlFlowContext` found, or `None` if
    neither the default graph nor any of its outer graphs has one.
  """
  graph = ops.get_default_graph()
  while graph is not None:
    # pylint: disable=protected-access
    context_ = graph._get_control_flow_context()
    # pylint: enable=protected-access
    while context_ is not None:
      if isinstance(context_, control_flow_ops.XLAControlFlowContext):
        return context_
      context_ = context_.outer_context
    # This may be a FuncGraph due to defuns or v2 control flow. We need to
    # find the original graph with the XLAControlFlowContext.
    # Graphs without an `outer_graph` attribute terminate the search.
    graph = getattr(graph, "outer_graph", None)
  return None
|
||||
|
||||
|
||||
def is_distributed_variable(v):
|
||||
|
@ -163,6 +163,23 @@ def core(num):
|
||||
return "device:TPU_REPLICATED_CORE:{}".format(num)
|
||||
|
||||
|
||||
def _enclosing_tpu_context_and_graph():
  """Finds the enclosing TPUReplicateContext and the graph that owns it.

  Starting from the default graph, each graph's control flow context chain is
  scanned (via `outer_context`) for a `TPUReplicateContext`. When a graph has
  none, the scan moves on to its `outer_graph`, if that attribute exists.

  Returns:
    A `(context, graph)` tuple: the `TPUReplicateContext` that was found and
    the graph whose control flow context chain contained it.

  Raises:
    ValueError: If no graph in the chain has a `TPUReplicateContext`.
  """
  current_graph = ops.get_default_graph()
  while current_graph is not None:
    # pylint: disable=protected-access
    candidate = current_graph._get_control_flow_context()
    # pylint: enable=protected-access
    # Climb the context chain until a TPUReplicateContext (or the top) is hit.
    while candidate is not None and not isinstance(candidate,
                                                   TPUReplicateContext):
      candidate = candidate.outer_context
    if candidate is not None:
      return candidate, current_graph
    # Nothing in this graph; fall back to the graph it is nested in, if any.
    current_graph = getattr(current_graph, "outer_graph", None)
  raise ValueError("get_replicated_var_handle() called without "
                   "TPUReplicateContext. This shouldn't happen. Please file "
                   "a bug.")
|
||||
|
||||
|
||||
class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
|
||||
"""A `ControlFlowContext` for nodes inside a TPU computation.
|
||||
|
||||
@ -230,14 +247,15 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
|
||||
# so the TPUReplicatedInput nodes go inside the TPUReplicateContext scope
|
||||
# instead.
|
||||
|
||||
# pylint: disable=protected-access
|
||||
graph = ops.get_default_graph()
|
||||
saved_context = graph._get_control_flow_context()
|
||||
graph._set_control_flow_context(self.outer_context)
|
||||
handle = tpu_ops.tpu_replicated_input(
|
||||
[v.handle for v in vars_], name=name + "/handle")
|
||||
graph._set_control_flow_context(saved_context)
|
||||
# pylint: enable=protected-access
|
||||
_, graph = _enclosing_tpu_context_and_graph()
|
||||
with graph.as_default():
|
||||
# pylint: disable=protected-access
|
||||
saved_context = graph._get_control_flow_context()
|
||||
graph._set_control_flow_context(self.outer_context)
|
||||
handle = tpu_ops.tpu_replicated_input(
|
||||
[v.handle for v in vars_], name=name + "/handle")
|
||||
graph._set_control_flow_context(saved_context)
|
||||
# pylint: enable=protected-access
|
||||
self._replicated_vars[name] = handle
|
||||
return handle
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user