Make TPUStrategy work with tf.function(experimental_compile=True).

This involves two changes:

1. Only create the replicated variable handle inside a TPUReplicateContext.
2. If a function annotated with experimental_compile=True is called inside an
   XLAControlFlowContext, don't create a new, nested XLAControlFlowContext.

PiperOrigin-RevId: 296086034
Change-Id: I821f3b3cd5ba69cd4c7bdb9c28e13e4b4c83f967
commit 0dd277c674
parent 38168415ea
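For readers skimming the diff, here is a minimal, hypothetical sketch of the user-facing pattern this commit enables: a tf.function with experimental_compile=True called from a replica function under TPUStrategy. It is not part of the commit; it uses TF 2.1-era public API names, assumes a reachable TPU worker, and the new unit test further down exercises the same path through a Keras layer.

import tensorflow as tf

# Assumes a reachable TPU worker; resolver/initialization is standard TPU setup.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="")
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.experimental.TPUStrategy(resolver)

with strategy.scope():
  w = tf.Variable(tf.ones([3, 4]))  # mirrored across TPU replicas

@tf.function(experimental_compile=True)
def compiled_matmul(x):
  # Change 1: reading `w` resolves the replicated handle only when this runs
  # inside a TPUReplicateContext, not merely inside any XLA context.
  return tf.matmul(x, w)

@tf.function
def run():
  def step_fn(x):
    # Change 2: this call site is already inside an XLA context, so the
    # compiled function no longer enters a second XLAControlFlowContext.
    return compiled_matmul(x)
  return strategy.experimental_run_v2(step_fn, args=(tf.ones([2, 3]),))

print(strategy.experimental_local_results(run()))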
@@ -620,6 +620,7 @@ py_library(
         "//tensorflow/python:training",
         "//tensorflow/python:util",
         "//tensorflow/python/eager:context",
+        "//tensorflow/python/tpu:tpu_lib",
         "//tensorflow/python/training/tracking:base",
         "@six_archive//:six",
     ],
@@ -354,6 +354,50 @@ class KerasModelsTest(test.TestCase, parameterized.TestCase):
     with distribution.scope():
       model = CustomModel()
 
+    @def_function.function
+    def train_step(iterator):
+
+      def step_fn(inputs):
+        images, targets = inputs
+        with backprop.GradientTape() as tape:
+          outputs = model(images)
+          loss = math_ops.reduce_sum(outputs - targets)
+        grads = tape.gradient(loss, model.variables)
+        return grads
+
+      outputs = distribution.experimental_run_v2(
+          step_fn, args=(next(iterator),))
+      return nest.map_structure(distribution.experimental_local_results,
+                                outputs)
+
+    train_step(input_iterator)
+
+  @combinations.generate(
+      combinations.combine(
+          distribution=strategy_combinations.tpu_strategies, mode=["eager"]))
+  def test_tf_function_experimental_compile(self, distribution):
+    dataset = self._get_dataset()
+    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))
+
+    class CustomDense(keras.layers.Layer):
+
+      def __init__(self, num_outputs):
+        super(CustomDense, self).__init__()
+        self.num_outputs = num_outputs
+
+      def build(self, input_shape):
+        self.kernel = self.add_variable(
+            "kernel", shape=[int(input_shape[-1]), self.num_outputs])
+
+      @def_function.function(experimental_compile=True)
+      def call(self, inputs):
+        return math_ops.matmul(inputs, self.kernel)
+
+    with distribution.scope():
+      x = keras.layers.Input(shape=(3,))
+      y = CustomDense(4)(x)
+      model = keras.Model(x, y)
+
     @def_function.function
     def train_step(iterator):
       def step_fn(inputs):
@@ -38,6 +38,7 @@ from tensorflow.python.ops import gen_resource_variable_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variable_scope as vs
 from tensorflow.python.ops import variables as variables_lib
+from tensorflow.python.tpu import tpu
 from tensorflow.python.training import saver
 from tensorflow.python.training.tracking import base as trackable
 from tensorflow.python.util import nest
@@ -938,14 +939,14 @@ ops.register_tensor_conversion_function(Mirrored,
 
 
 def _enclosing_tpu_context():
-  """Returns the XLAControlFlowContext, which exists inside a tpu.rewrite()."""
+  """Returns the TPUReplicateContext, which exists inside a tpu.rewrite()."""
   graph = ops.get_default_graph()
   while graph is not None:
     # pylint: disable=protected-access
     context_ = graph._get_control_flow_context()
     # pylint: enable=protected-access
     while context_ is not None:
-      if isinstance(context_, control_flow_ops.XLAControlFlowContext):
+      if isinstance(context_, tpu.TPUReplicateContext):
        return context_
       context_ = context_.outer_context
     # This may be a FuncGraph due to defuns or v2 control flow. We need to
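The hunk above implements change 1: the walk over control-flow contexts now looks specifically for a TPUReplicateContext instead of any XLAControlFlowContext, so a function that is merely traced under experimental_compile=True no longer looks like TPU replication. A simplified, hypothetical sketch of how such a helper is typically consumed (the _read_variable name and attribute accesses are illustrative, not code from the diff):

def _read_variable(tpu_mirrored_var):
  # Hypothetical consumer of _enclosing_tpu_context(); illustrates why the
  # narrower isinstance check matters.
  if _enclosing_tpu_context() is not None:
    # Inside tpu.replicate()/tpu.rewrite(): use the replicated handle so XLA
    # resolves the per-replica copy of the variable.
    return tpu_mirrored_var.handle
  # Not inside a TPUReplicateContext. After this change, a function that is
  # merely traced under an XLAControlFlowContext (experimental_compile=True)
  # also lands here and uses the closest per-device variable's handle.
  return tpu_mirrored_var.primary.handle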
@@ -689,6 +689,7 @@ py_library(
         ":lift_to_graph",
         "//tensorflow/python:cond_v2", # TODO(b/118513001): Imported via control_flow_ops; remove.
         "//tensorflow/python:control_flow_ops",
+        "//tensorflow/python:control_flow_util",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:resource_variable_ops",
         "//tensorflow/python:util",
@@ -31,6 +31,7 @@ from tensorflow.python.framework import func_graph as func_graph_module
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import control_flow_util
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.platform import tf_logging as logging
@@ -563,9 +564,12 @@ class Function(object):
       return self._python_function(*args, **kwds)
 
     tracing_count = self._get_tracing_count()
-    if self._experimental_compile:
+    if self._experimental_compile and (
+        not control_flow_util.GraphOrParentsInXlaContext(
+            ops.get_default_graph())):
       # V2 control flow relies on XLAControlFlowContext to generate a
-      # XLA-compatible function graph.
+      # XLA-compatible function graph. If the function is already called inside
+      # an XLA context, we don't create nested XLA context.
       xla_context = control_flow_ops.XLAControlFlowContext()
       try:
         xla_context.Enter()
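The hunk above is change 2: tracing under experimental_compile only enters a fresh XLAControlFlowContext when the default graph (and its parents) is not already in an XLA context. A hedged, simplified mirror of that guard, with a hypothetical maybe_compile wrapper standing in for the real Function.__call__ body:

from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util


def maybe_compile(fn, *args, **kwargs):
  # Simplified, illustrative mirror of the guarded block above; maybe_compile
  # itself is not a function from the diff.
  if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
    # Already under an XLA context (e.g. inside tpu.replicate()): do not nest
    # another XLAControlFlowContext, just run the traced body.
    return fn(*args, **kwargs)
  xla_context = control_flow_ops.XLAControlFlowContext()
  try:
    xla_context.Enter()
    return fn(*args, **kwargs)
  finally:
    xla_context.Exit()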