Make TPUStrategy work with tf.function(experimental_compile=True). This involves two changes:

1. Only create the replicated variable handle inside a TPUReplicateContext.
2. If a function annotated with experimental_compile=True is called inside an XLAControlFlowContext, don't create a new XLAControlFlowContext.

PiperOrigin-RevId: 296086034
Change-Id: I821f3b3cd5ba69cd4c7bdb9c28e13e4b4c83f967
This commit is contained in:
Ruoxin Sang 2020-02-19 16:36:25 -08:00 committed by TensorFlower Gardener
parent 38168415ea
commit 0dd277c674
5 changed files with 55 additions and 4 deletions

View File

@ -620,6 +620,7 @@ py_library(
"//tensorflow/python:training", "//tensorflow/python:training",
"//tensorflow/python:util", "//tensorflow/python:util",
"//tensorflow/python/eager:context", "//tensorflow/python/eager:context",
"//tensorflow/python/tpu:tpu_lib",
"//tensorflow/python/training/tracking:base", "//tensorflow/python/training/tracking:base",
"@six_archive//:six", "@six_archive//:six",
], ],

View File

@ -354,6 +354,50 @@ class KerasModelsTest(test.TestCase, parameterized.TestCase):
with distribution.scope(): with distribution.scope():
model = CustomModel() model = CustomModel()
@def_function.function
def train_step(iterator):
def step_fn(inputs):
images, targets = inputs
with backprop.GradientTape() as tape:
outputs = model(images)
loss = math_ops.reduce_sum(outputs - targets)
grads = tape.gradient(loss, model.variables)
return grads
outputs = distribution.experimental_run_v2(
step_fn, args=(next(iterator),))
return nest.map_structure(distribution.experimental_local_results,
outputs)
train_step(input_iterator)
@combinations.generate(
combinations.combine(
distribution=strategy_combinations.tpu_strategies, mode=["eager"]))
def test_tf_function_experimental_compile(self, distribution):
dataset = self._get_dataset()
input_iterator = iter(distribution.experimental_distribute_dataset(dataset))
class CustomDense(keras.layers.Layer):
def __init__(self, num_outputs):
super(CustomDense, self).__init__()
self.num_outputs = num_outputs
def build(self, input_shape):
self.kernel = self.add_variable(
"kernel", shape=[int(input_shape[-1]), self.num_outputs])
@def_function.function(experimental_compile=True)
def call(self, inputs):
return math_ops.matmul(inputs, self.kernel)
with distribution.scope():
x = keras.layers.Input(shape=(3,))
y = CustomDense(4)(x)
model = keras.Model(x, y)
@def_function.function @def_function.function
def train_step(iterator): def train_step(iterator):
def step_fn(inputs): def step_fn(inputs):

View File

@ -38,6 +38,7 @@ from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope as vs from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as variables_lib from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.tpu import tpu
from tensorflow.python.training import saver from tensorflow.python.training import saver
from tensorflow.python.training.tracking import base as trackable from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import nest from tensorflow.python.util import nest
@ -938,14 +939,14 @@ ops.register_tensor_conversion_function(Mirrored,
def _enclosing_tpu_context(): def _enclosing_tpu_context():
"""Returns the XLAControlFlowContext, which exists inside a tpu.rewrite().""" """Returns the TPUReplicateContext, which exists inside a tpu.rewrite()."""
graph = ops.get_default_graph() graph = ops.get_default_graph()
while graph is not None: while graph is not None:
# pylint: disable=protected-access # pylint: disable=protected-access
context_ = graph._get_control_flow_context() context_ = graph._get_control_flow_context()
# pylint: enable=protected-access # pylint: enable=protected-access
while context_ is not None: while context_ is not None:
if isinstance(context_, control_flow_ops.XLAControlFlowContext): if isinstance(context_, tpu.TPUReplicateContext):
return context_ return context_
context_ = context_.outer_context context_ = context_.outer_context
# This may be a FuncGraph due to defuns or v2 control flow. We need to # This may be a FuncGraph due to defuns or v2 control flow. We need to

View File

@ -689,6 +689,7 @@ py_library(
":lift_to_graph", ":lift_to_graph",
"//tensorflow/python:cond_v2", # TODO(b/118513001): Imported via control_flow_ops; remove. "//tensorflow/python:cond_v2", # TODO(b/118513001): Imported via control_flow_ops; remove.
"//tensorflow/python:control_flow_ops", "//tensorflow/python:control_flow_ops",
"//tensorflow/python:control_flow_util",
"//tensorflow/python:framework_ops", "//tensorflow/python:framework_ops",
"//tensorflow/python:resource_variable_ops", "//tensorflow/python:resource_variable_ops",
"//tensorflow/python:util", "//tensorflow/python:util",

View File

@ -31,6 +31,7 @@ from tensorflow.python.framework import func_graph as func_graph_module
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
@ -563,9 +564,12 @@ class Function(object):
return self._python_function(*args, **kwds) return self._python_function(*args, **kwds)
tracing_count = self._get_tracing_count() tracing_count = self._get_tracing_count()
if self._experimental_compile: if self._experimental_compile and (
not control_flow_util.GraphOrParentsInXlaContext(
ops.get_default_graph())):
# V2 control flow relies on XLAControlFlowContext to generate a # V2 control flow relies on XLAControlFlowContext to generate a
# XLA-compatible function graph. # XLA-compatible function graph. If the function is already called inside
# an XLA context, we don't create nested XLA context.
xla_context = control_flow_ops.XLAControlFlowContext() xla_context = control_flow_ops.XLAControlFlowContext()
try: try:
xla_context.Enter() xla_context.Enter()