Extend AutoGraph to functions called by Distribution Strategy's experimental_run_v2.

PiperOrigin-RevId: 253722908
This commit is contained in:
Dan Moldovan 2019-06-17 21:51:02 -07:00 committed by TensorFlower Gardener
parent 06a1ca8390
commit e906cdde1b
4 changed files with 39 additions and 5 deletions

View File

@ -201,7 +201,6 @@ def convert(recursive=False, optional_features=None, force_conversion=True):
def decorator(f): def decorator(f):
"""Decorator implementation.""" """Decorator implementation."""
@functools.wraps(f)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
"""Wrapper that calls the converted version of f.""" """Wrapper that calls the converted version of f."""
with ag_ctx.ControlStatusCtx( with ag_ctx.ControlStatusCtx(
@ -220,12 +219,15 @@ def convert(recursive=False, optional_features=None, force_conversion=True):
else: else:
raise raise
wrapper = tf_decorator.make_decorator(f, wrapper) if inspect.isfunction(f) or inspect.ismethod(f):
wrapper = functools.update_wrapper(wrapper, f)
decorated_wrapper = tf_decorator.make_decorator(f, wrapper)
# Sometimes the decorator is just desugared, making it impossible to detect. # Sometimes the decorator is just desugared, making it impossible to detect.
# This attribute makes detection easier. # This attribute makes detection easier.
setattr(wrapper, '__ag_compiled', True) setattr(decorated_wrapper, '__ag_compiled', True)
return wrapper return decorated_wrapper
return decorator return decorator

View File

@ -97,6 +97,32 @@ class InputIterationTest(test.TestCase, parameterized.TestCase):
results.append(output) results.append(output)
self._validate_outputs(results) self._validate_outputs(results)
@combinations.generate(
combinations.combine(
distribution=strategy_combinations.strategies_minus_tpu +
[strategy_combinations.tpu_strategy_one_step],
mode=["eager"]
))
def testRunInFunctionAutoGraphApplication(self, distribution):
dataset = self._get_dataset()
def train_step(data):
if math_ops.reduce_sum(data) < 0:
return -data
return data
@def_function.function
def f_train_step(input_data):
return distribution.experimental_local_results(
distribution.experimental_run_v2(train_step, args=(input_data,)))
dist_dataset = distribution.experimental_distribute_dataset(dataset)
results = []
for x in dist_dataset:
output = f_train_step(x)
results.append(output)
self._validate_outputs(results)
@combinations.generate( @combinations.generate(
combinations.combine( combinations.combine(
distribution=strategy_combinations.strategies_minus_tpu, distribution=strategy_combinations.strategies_minus_tpu,

View File

@ -101,6 +101,8 @@ import threading
import weakref import weakref
import six import six
from tensorflow.python.autograph.core import ag_ctx
from tensorflow.python.autograph.impl import api as autograph
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import device_util from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.distribute import distribution_strategy_context
@ -716,6 +718,7 @@ class Strategy(object):
(for example, if running on a single replica). (for example, if running on a single replica).
""" """
with self.scope(): with self.scope():
fn = autograph.tf_convert(fn, ag_ctx.control_status_ctx())
return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs) return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
def reduce(self, reduce_op, value, axis): def reduce(self, reduce_op, value, axis):

View File

@ -25,6 +25,8 @@ import weakref
import numpy as np import numpy as np
from tensorflow.python.autograph.core import ag_ctx
from tensorflow.python.autograph.impl import api as autograph
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import device_util from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import distribute_lib
@ -160,10 +162,10 @@ class TPUStrategy(distribute_lib.Strategy):
# This implementation runs a single step. It does not use infeed or outfeed. # This implementation runs a single step. It does not use infeed or outfeed.
def experimental_run_v2(self, fn, args=(), kwargs=None): def experimental_run_v2(self, fn, args=(), kwargs=None):
"""See base class.""" """See base class."""
fn = autograph.tf_convert(fn, ag_ctx.control_status_ctx())
return self.extended.tpu_run(fn, args, kwargs) return self.extended.tpu_run(fn, args, kwargs)
@tf_export(v1=["distribute.experimental.TPUStrategy"]) @tf_export(v1=["distribute.experimental.TPUStrategy"])
class TPUStrategyV1(distribute_lib.StrategyV1): class TPUStrategyV1(distribute_lib.StrategyV1):
"""TPU distribution strategy implementation.""" """TPU distribution strategy implementation."""
@ -199,6 +201,7 @@ class TPUStrategyV1(distribute_lib.StrategyV1):
# This implementation runs a single step. It does not use infeed or outfeed. # This implementation runs a single step. It does not use infeed or outfeed.
def experimental_run_v2(self, fn, args=(), kwargs=None): def experimental_run_v2(self, fn, args=(), kwargs=None):
"""See base class.""" """See base class."""
fn = autograph.tf_convert(fn, ag_ctx.control_status_ctx())
return self.extended.tpu_run(fn, args, kwargs) return self.extended.tpu_run(fn, args, kwargs)