Extend AutoGraph to functions called by Distribution Strategy's experimental_run_v2.

PiperOrigin-RevId: 253722908
Author: Dan Moldovan, 2019-06-17 21:51:02 -07:00 (committed by TensorFlower Gardener)
Parent: 06a1ca8390
Commit: e906cdde1b
4 changed files with 39 additions and 5 deletions
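
In short: the function passed to Strategy.experimental_run_v2 is now run through
AutoGraph, so data-dependent Python control flow inside it works under
tf.function. A condensed sketch of the pattern this enables, mirroring the test
added below (MirroredStrategy chosen purely for illustration):

    import tensorflow as tf

    strategy = tf.distribute.MirroredStrategy()

    def train_step(data):
      # Tensor-valued condition: AutoGraph rewrites this into graph control flow.
      if tf.reduce_sum(data) < 0:
        return -data
      return data

    @tf.function
    def f_train_step(input_data):
      # experimental_run_v2 now converts train_step before running it per replica.
      return strategy.experimental_local_results(
          strategy.experimental_run_v2(train_step, args=(input_data,)))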

tensorflow/python/autograph/impl/api.py

@@ -201,7 +201,6 @@ def convert(recursive=False, optional_features=None, force_conversion=True):
   def decorator(f):
     """Decorator implementation."""
 
-    @functools.wraps(f)
     def wrapper(*args, **kwargs):
       """Wrapper that calls the converted version of f."""
       with ag_ctx.ControlStatusCtx(
@@ -220,12 +219,15 @@ def convert(recursive=False, optional_features=None, force_conversion=True):
           else:
             raise
 
-    wrapper = tf_decorator.make_decorator(f, wrapper)
+    if inspect.isfunction(f) or inspect.ismethod(f):
+      wrapper = functools.update_wrapper(wrapper, f)
+
+    decorated_wrapper = tf_decorator.make_decorator(f, wrapper)
 
     # Sometimes the decorator is just desugared, making it impossible to detect.
     # This attribute makes detection easier.
-    setattr(wrapper, '__ag_compiled', True)
-    return wrapper
+    setattr(decorated_wrapper, '__ag_compiled', True)
+    return decorated_wrapper
 
   return decorator
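
The isfunction/ismethod guard replaces the unconditional functools.wraps above:
update_wrapper copies metadata such as __name__ and __doc__, which arbitrary
callables (for example, objects defining __call__) may lack, and under Python 2
functools.wraps raises AttributeError in that case. A minimal sketch of the
distinction (Callable is a hypothetical example class, not from the patch):

    import functools
    import inspect

    class Callable(object):  # hypothetical: a callable object with no __name__
      def __call__(self, x):
        return x

    def wrapper(*args, **kwargs):
      pass

    f = Callable()
    if inspect.isfunction(f) or inspect.ismethod(f):  # False here, so the copy is skipped
      wrapper = functools.update_wrapper(wrapper, f)

Renaming the result to decorated_wrapper also ensures the __ag_compiled marker
is set on the object that is actually returned to callers.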

tensorflow/python/distribute/custom_training_loop_test.py

@@ -97,6 +97,32 @@ class InputIterationTest(test.TestCase, parameterized.TestCase):
       results.append(output)
     self._validate_outputs(results)
 
+  @combinations.generate(
+      combinations.combine(
+          distribution=strategy_combinations.strategies_minus_tpu +
+          [strategy_combinations.tpu_strategy_one_step],
+          mode=["eager"]
+      ))
+  def testRunInFunctionAutoGraphApplication(self, distribution):
+    dataset = self._get_dataset()
+
+    def train_step(data):
+      if math_ops.reduce_sum(data) < 0:
+        return -data
+      return data
+
+    @def_function.function
+    def f_train_step(input_data):
+      return distribution.experimental_local_results(
+          distribution.experimental_run_v2(train_step, args=(input_data,)))
+
+    dist_dataset = distribution.experimental_distribute_dataset(dataset)
+    results = []
+    for x in dist_dataset:
+      output = f_train_step(x)
+      results.append(output)
+    self._validate_outputs(results)
+
   @combinations.generate(
       combinations.combine(
           distribution=strategy_combinations.strategies_minus_tpu,
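
The condition in train_step is tensor-valued, so the Python branch cannot
execute in graph mode as-is; AutoGraph rewrites it into graph control flow.
Conceptually, the converted function behaves like the following (an
illustrative equivalent using the same internal modules as the test, not the
literal generated code):

    from tensorflow.python.ops import control_flow_ops
    from tensorflow.python.ops import math_ops

    def train_step_converted(data):
      # A tf.cond on the tensor predicate replaces the Python `if`.
      return control_flow_ops.cond(
          math_ops.reduce_sum(data) < 0,
          lambda: -data,
          lambda: data)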

tensorflow/python/distribute/distribute_lib.py

@@ -101,6 +101,8 @@
 import threading
 import weakref
 import six
+from tensorflow.python.autograph.core import ag_ctx
+from tensorflow.python.autograph.impl import api as autograph
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.distribute import device_util
 from tensorflow.python.distribute import distribution_strategy_context
@@ -716,6 +718,7 @@ class Strategy(object):
       (for example, if running on a single replica).
     """
     with self.scope():
+      fn = autograph.tf_convert(fn, ag_ctx.control_status_ctx())
       return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
 
   def reduce(self, reduce_op, value, axis):
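
tf_convert wraps fn so that conversion follows the caller's AutoGraph status as
captured by ag_ctx.control_status_ctx(): inside a tf.function with AutoGraph
enabled the body is converted; where AutoGraph is explicitly disabled, fn runs
unmodified. A minimal usage sketch with the same internal names as the diff
(illustrative only; exact wrapper behavior may differ across versions):

    from tensorflow.python.autograph.core import ag_ctx
    from tensorflow.python.autograph.impl import api as autograph

    def replica_fn(x):
      if x < 0:  # tensor condition; valid in a graph only after conversion
        return -x
      return x

    # Capture the conversion status at the experimental_run_v2 call site; the
    # returned callable applies (or skips) conversion when it is invoked.
    wrapped = autograph.tf_convert(replica_fn, ag_ctx.control_status_ctx())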

tensorflow/python/distribute/tpu_strategy.py

@@ -25,6 +25,8 @@
 import weakref
 import numpy as np
+from tensorflow.python.autograph.core import ag_ctx
+from tensorflow.python.autograph.impl import api as autograph
 from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
 from tensorflow.python.distribute import device_util
 from tensorflow.python.distribute import distribute_lib
@@ -160,10 +162,10 @@ class TPUStrategy(distribute_lib.Strategy):
   # This implementation runs a single step. It does not use infeed or outfeed.
   def experimental_run_v2(self, fn, args=(), kwargs=None):
     """See base class."""
+    fn = autograph.tf_convert(fn, ag_ctx.control_status_ctx())
     return self.extended.tpu_run(fn, args, kwargs)
 
 
 @tf_export(v1=["distribute.experimental.TPUStrategy"])
 class TPUStrategyV1(distribute_lib.StrategyV1):
   """TPU distribution strategy implementation."""
@@ -199,6 +201,7 @@ class TPUStrategyV1(distribute_lib.StrategyV1):
   # This implementation runs a single step. It does not use infeed or outfeed.
   def experimental_run_v2(self, fn, args=(), kwargs=None):
     """See base class."""
+    fn = autograph.tf_convert(fn, ag_ctx.control_status_ctx())
     return self.extended.tpu_run(fn, args, kwargs)
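
Note: the same tf_convert line appears three times because TPUStrategy and
TPUStrategyV1 override experimental_run_v2 and dispatch to extended.tpu_run,
bypassing the base-class path patched in distribute_lib.py above; each entry
point therefore has to convert fn itself before it is staged for the TPU.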