From e906cdde1b8e52d359a1d07e2f11948dd5b9ae81 Mon Sep 17 00:00:00 2001
From: Dan Moldovan
Date: Mon, 17 Jun 2019 21:51:02 -0700
Subject: [PATCH] Extend AutoGraph to functions called by Distribution
 Strategy's experimental_run_v2.

PiperOrigin-RevId: 253722908
---
 tensorflow/python/autograph/impl/api.py      | 10 ++++---
 .../distribute/custom_training_loop_test.py  | 26 +++++++++++++++++++
 .../python/distribute/distribute_lib.py      |  3 +++
 tensorflow/python/distribute/tpu_strategy.py |  5 +++-
 4 files changed, 39 insertions(+), 5 deletions(-)

diff --git a/tensorflow/python/autograph/impl/api.py b/tensorflow/python/autograph/impl/api.py
index dc92f1a990f..36a80df3572 100644
--- a/tensorflow/python/autograph/impl/api.py
+++ b/tensorflow/python/autograph/impl/api.py
@@ -201,7 +201,6 @@ def convert(recursive=False, optional_features=None, force_conversion=True):
   def decorator(f):
     """Decorator implementation."""
 
-    @functools.wraps(f)
     def wrapper(*args, **kwargs):
       """Wrapper that calls the converted version of f."""
       with ag_ctx.ControlStatusCtx(
@@ -220,12 +219,15 @@ def convert(recursive=False, optional_features=None, force_conversion=True):
         else:
           raise
 
-    wrapper = tf_decorator.make_decorator(f, wrapper)
+    if inspect.isfunction(f) or inspect.ismethod(f):
+      wrapper = functools.update_wrapper(wrapper, f)
+
+    decorated_wrapper = tf_decorator.make_decorator(f, wrapper)
 
     # Sometimes the decorator is just desugared, making it impossible to detect.
     # This attribute makes detection easier.
-    setattr(wrapper, '__ag_compiled', True)
-    return wrapper
+    setattr(decorated_wrapper, '__ag_compiled', True)
+    return decorated_wrapper
 
   return decorator
 
diff --git a/tensorflow/python/distribute/custom_training_loop_test.py b/tensorflow/python/distribute/custom_training_loop_test.py
index d4dc28c3680..b0a31541d86 100644
--- a/tensorflow/python/distribute/custom_training_loop_test.py
+++ b/tensorflow/python/distribute/custom_training_loop_test.py
@@ -97,6 +97,32 @@ class InputIterationTest(test.TestCase, parameterized.TestCase):
       results.append(output)
     self._validate_outputs(results)
 
+  @combinations.generate(
+      combinations.combine(
+          distribution=strategy_combinations.strategies_minus_tpu +
+          [strategy_combinations.tpu_strategy_one_step],
+          mode=["eager"]
+      ))
+  def testRunInFunctionAutoGraphApplication(self, distribution):
+    dataset = self._get_dataset()
+
+    def train_step(data):
+      if math_ops.reduce_sum(data) < 0:
+        return -data
+      return data
+
+    @def_function.function
+    def f_train_step(input_data):
+      return distribution.experimental_local_results(
+          distribution.experimental_run_v2(train_step, args=(input_data,)))
+
+    dist_dataset = distribution.experimental_distribute_dataset(dataset)
+    results = []
+    for x in dist_dataset:
+      output = f_train_step(x)
+      results.append(output)
+    self._validate_outputs(results)
+
   @combinations.generate(
       combinations.combine(
           distribution=strategy_combinations.strategies_minus_tpu,
diff --git a/tensorflow/python/distribute/distribute_lib.py b/tensorflow/python/distribute/distribute_lib.py
index c17382e26db..496ef8f39a9 100644
--- a/tensorflow/python/distribute/distribute_lib.py
+++ b/tensorflow/python/distribute/distribute_lib.py
@@ -101,6 +101,8 @@ import threading
 import weakref
 
 import six
+from tensorflow.python.autograph.core import ag_ctx
+from tensorflow.python.autograph.impl import api as autograph
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.distribute import device_util
 from tensorflow.python.distribute import distribution_strategy_context
@@ -716,6 +718,7 @@ class Strategy(object):
      (for example, if running on a single replica).
     """
     with self.scope():
+      fn = autograph.tf_convert(fn, ag_ctx.control_status_ctx())
       return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
 
   def reduce(self, reduce_op, value, axis):
diff --git a/tensorflow/python/distribute/tpu_strategy.py b/tensorflow/python/distribute/tpu_strategy.py
index 40bbdd978e9..b4096ec25b4 100644
--- a/tensorflow/python/distribute/tpu_strategy.py
+++ b/tensorflow/python/distribute/tpu_strategy.py
@@ -25,6 +25,8 @@ import weakref
 
 import numpy as np
 
+from tensorflow.python.autograph.core import ag_ctx
+from tensorflow.python.autograph.impl import api as autograph
 from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
 from tensorflow.python.distribute import device_util
 from tensorflow.python.distribute import distribute_lib
@@ -160,10 +162,10 @@ class TPUStrategy(distribute_lib.Strategy):
   # This implementation runs a single step. It does not use infeed or outfeed.
   def experimental_run_v2(self, fn, args=(), kwargs=None):
     """See base class."""
+    fn = autograph.tf_convert(fn, ag_ctx.control_status_ctx())
     return self.extended.tpu_run(fn, args, kwargs)
 
-
 @tf_export(v1=["distribute.experimental.TPUStrategy"])
 class TPUStrategyV1(distribute_lib.StrategyV1):
   """TPU distribution strategy implementation."""
@@ -199,6 +201,7 @@ class TPUStrategyV1(distribute_lib.StrategyV1):
   # This implementation runs a single step. It does not use infeed or outfeed.
   def experimental_run_v2(self, fn, args=(), kwargs=None):
     """See base class."""
+    fn = autograph.tf_convert(fn, ag_ctx.control_status_ctx())
    return self.extended.tpu_run(fn, args, kwargs)
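
For context, a minimal usage sketch of the behavior this patch enables. The MirroredStrategy setup, the tensor values, and the names run_step/train_step are illustrative assumptions, not part of the change:

# Sketch only (not part of the patch). Assumes TF at roughly this revision
# with tf.distribute.MirroredStrategy available. Before this change, the
# data-dependent Python `if` in train_step would fail when traced inside
# tf.function, because AutoGraph converted only the outer function, not
# functions dispatched through the strategy; with this patch,
# experimental_run_v2 routes fn through autograph.tf_convert, so the branch
# is rewritten to graph-safe control flow (tf.cond).
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

def train_step(data):
  # Branch on a tensor value; converted by AutoGraph under the new behavior.
  if tf.reduce_sum(data) < 0:
    return -data
  return data

@tf.function
def run_step(data):
  # Returns per-replica results; mirrors f_train_step in the new test above.
  return strategy.experimental_run_v2(train_step, args=(data,))

print(run_step(tf.constant([[-1.0], [-2.0]])))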