From 0dd277c6746fa71a314d53e39e0cd1fe4aa931ff Mon Sep 17 00:00:00 2001 From: Ruoxin Sang Date: Wed, 19 Feb 2020 16:36:25 -0800 Subject: [PATCH] Make TPUStrategy work with tf.function(experimental_compile=True). This involves two changes: 1. Only create replicated var handle inside TPUReplicateContext. 2. If the function annotated with experimental_compile=True is called inside a XLAControlFlowContext, don't create a new XLAControlFlowContext. PiperOrigin-RevId: 296086034 Change-Id: I821f3b3cd5ba69cd4c7bdb9c28e13e4b4c83f967 --- tensorflow/python/distribute/BUILD | 1 + .../custom_training_loop_models_test.py | 44 +++++++++++++++++++ tensorflow/python/distribute/values.py | 5 ++- tensorflow/python/eager/BUILD | 1 + tensorflow/python/eager/def_function.py | 8 +++- 5 files changed, 55 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index bc6865c8617..a4e2795ce2e 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -620,6 +620,7 @@ py_library( "//tensorflow/python:training", "//tensorflow/python:util", "//tensorflow/python/eager:context", + "//tensorflow/python/tpu:tpu_lib", "//tensorflow/python/training/tracking:base", "@six_archive//:six", ], diff --git a/tensorflow/python/distribute/custom_training_loop_models_test.py b/tensorflow/python/distribute/custom_training_loop_models_test.py index dcce40a2f80..6fafa43677c 100644 --- a/tensorflow/python/distribute/custom_training_loop_models_test.py +++ b/tensorflow/python/distribute/custom_training_loop_models_test.py @@ -354,6 +354,50 @@ class KerasModelsTest(test.TestCase, parameterized.TestCase): with distribution.scope(): model = CustomModel() + @def_function.function + def train_step(iterator): + + def step_fn(inputs): + images, targets = inputs + with backprop.GradientTape() as tape: + outputs = model(images) + loss = math_ops.reduce_sum(outputs - targets) + grads = tape.gradient(loss, model.variables) + return grads + + outputs = distribution.experimental_run_v2( + step_fn, args=(next(iterator),)) + return nest.map_structure(distribution.experimental_local_results, + outputs) + + train_step(input_iterator) + + @combinations.generate( + combinations.combine( + distribution=strategy_combinations.tpu_strategies, mode=["eager"])) + def test_tf_function_experimental_compile(self, distribution): + dataset = self._get_dataset() + input_iterator = iter(distribution.experimental_distribute_dataset(dataset)) + + class CustomDense(keras.layers.Layer): + + def __init__(self, num_outputs): + super(CustomDense, self).__init__() + self.num_outputs = num_outputs + + def build(self, input_shape): + self.kernel = self.add_variable( + "kernel", shape=[int(input_shape[-1]), self.num_outputs]) + + @def_function.function(experimental_compile=True) + def call(self, inputs): + return math_ops.matmul(inputs, self.kernel) + + with distribution.scope(): + x = keras.layers.Input(shape=(3,)) + y = CustomDense(4)(x) + model = keras.Model(x, y) + @def_function.function def train_step(iterator): def step_fn(inputs): diff --git a/tensorflow/python/distribute/values.py b/tensorflow/python/distribute/values.py index baf3b8295dc..74e9c600cee 100644 --- a/tensorflow/python/distribute/values.py +++ b/tensorflow/python/distribute/values.py @@ -38,6 +38,7 @@ from tensorflow.python.ops import gen_resource_variable_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.ops import variables as variables_lib +from tensorflow.python.tpu import tpu from tensorflow.python.training import saver from tensorflow.python.training.tracking import base as trackable from tensorflow.python.util import nest @@ -938,14 +939,14 @@ ops.register_tensor_conversion_function(Mirrored, def _enclosing_tpu_context(): - """Returns the XLAControlFlowContext, which exists inside a tpu.rewrite().""" + """Returns the TPUReplicateContext, which exists inside a tpu.rewrite().""" graph = ops.get_default_graph() while graph is not None: # pylint: disable=protected-access context_ = graph._get_control_flow_context() # pylint: enable=protected-access while context_ is not None: - if isinstance(context_, control_flow_ops.XLAControlFlowContext): + if isinstance(context_, tpu.TPUReplicateContext): return context_ context_ = context_.outer_context # This may be a FuncGraph due to defuns or v2 control flow. We need to diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index 65d07846cea..7aef5da11f2 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -689,6 +689,7 @@ py_library( ":lift_to_graph", "//tensorflow/python:cond_v2", # TODO(b/118513001): Imported via control_flow_ops; remove. "//tensorflow/python:control_flow_ops", + "//tensorflow/python:control_flow_util", "//tensorflow/python:framework_ops", "//tensorflow/python:resource_variable_ops", "//tensorflow/python:util", diff --git a/tensorflow/python/eager/def_function.py b/tensorflow/python/eager/def_function.py index a2bcb91918b..76af2d32c3e 100644 --- a/tensorflow/python/eager/def_function.py +++ b/tensorflow/python/eager/def_function.py @@ -31,6 +31,7 @@ from tensorflow.python.framework import func_graph as func_graph_module from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import tf_logging as logging @@ -563,9 +564,12 @@ class Function(object): return self._python_function(*args, **kwds) tracing_count = self._get_tracing_count() - if self._experimental_compile: + if self._experimental_compile and ( + not control_flow_util.GraphOrParentsInXlaContext( + ops.get_default_graph())): # V2 control flow relies on XLAControlFlowContext to generate a - # XLA-compatible function graph. + # XLA-compatible function graph. If the function is already called inside + # an XLA context, we don't create nested XLA context. xla_context = control_flow_ops.XLAControlFlowContext() try: xla_context.Enter()