Extend AutoGraph to functions called by Distribution Strategy's experimental_run_v2.
PiperOrigin-RevId: 253722908
This commit is contained in:
parent
06a1ca8390
commit
e906cdde1b
@ -201,7 +201,6 @@ def convert(recursive=False, optional_features=None, force_conversion=True):
|
||||
def decorator(f):
|
||||
"""Decorator implementation."""
|
||||
|
||||
@functools.wraps(f)
|
||||
def wrapper(*args, **kwargs):
|
||||
"""Wrapper that calls the converted version of f."""
|
||||
with ag_ctx.ControlStatusCtx(
|
||||
@ -220,12 +219,15 @@ def convert(recursive=False, optional_features=None, force_conversion=True):
|
||||
else:
|
||||
raise
|
||||
|
||||
wrapper = tf_decorator.make_decorator(f, wrapper)
|
||||
if inspect.isfunction(f) or inspect.ismethod(f):
|
||||
wrapper = functools.update_wrapper(wrapper, f)
|
||||
|
||||
decorated_wrapper = tf_decorator.make_decorator(f, wrapper)
|
||||
|
||||
# Sometimes the decorator is just desugared, making it impossible to detect.
|
||||
# This attribute makes detection easier.
|
||||
setattr(wrapper, '__ag_compiled', True)
|
||||
return wrapper
|
||||
setattr(decorated_wrapper, '__ag_compiled', True)
|
||||
return decorated_wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
@ -97,6 +97,32 @@ class InputIterationTest(test.TestCase, parameterized.TestCase):
|
||||
results.append(output)
|
||||
self._validate_outputs(results)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
distribution=strategy_combinations.strategies_minus_tpu +
|
||||
[strategy_combinations.tpu_strategy_one_step],
|
||||
mode=["eager"]
|
||||
))
|
||||
def testRunInFunctionAutoGraphApplication(self, distribution):
|
||||
dataset = self._get_dataset()
|
||||
|
||||
def train_step(data):
|
||||
if math_ops.reduce_sum(data) < 0:
|
||||
return -data
|
||||
return data
|
||||
|
||||
@def_function.function
|
||||
def f_train_step(input_data):
|
||||
return distribution.experimental_local_results(
|
||||
distribution.experimental_run_v2(train_step, args=(input_data,)))
|
||||
|
||||
dist_dataset = distribution.experimental_distribute_dataset(dataset)
|
||||
results = []
|
||||
for x in dist_dataset:
|
||||
output = f_train_step(x)
|
||||
results.append(output)
|
||||
self._validate_outputs(results)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
distribution=strategy_combinations.strategies_minus_tpu,
|
||||
|
@ -101,6 +101,8 @@ import threading
|
||||
import weakref
|
||||
import six
|
||||
|
||||
from tensorflow.python.autograph.core import ag_ctx
|
||||
from tensorflow.python.autograph.impl import api as autograph
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.distribute import device_util
|
||||
from tensorflow.python.distribute import distribution_strategy_context
|
||||
@ -716,6 +718,7 @@ class Strategy(object):
|
||||
(for example, if running on a single replica).
|
||||
"""
|
||||
with self.scope():
|
||||
fn = autograph.tf_convert(fn, ag_ctx.control_status_ctx())
|
||||
return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
|
||||
|
||||
def reduce(self, reduce_op, value, axis):
|
||||
|
@ -25,6 +25,8 @@ import weakref
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.autograph.core import ag_ctx
|
||||
from tensorflow.python.autograph.impl import api as autograph
|
||||
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
|
||||
from tensorflow.python.distribute import device_util
|
||||
from tensorflow.python.distribute import distribute_lib
|
||||
@ -160,10 +162,10 @@ class TPUStrategy(distribute_lib.Strategy):
|
||||
# This implementation runs a single step. It does not use infeed or outfeed.
|
||||
def experimental_run_v2(self, fn, args=(), kwargs=None):
|
||||
"""See base class."""
|
||||
fn = autograph.tf_convert(fn, ag_ctx.control_status_ctx())
|
||||
return self.extended.tpu_run(fn, args, kwargs)
|
||||
|
||||
|
||||
|
||||
@tf_export(v1=["distribute.experimental.TPUStrategy"])
|
||||
class TPUStrategyV1(distribute_lib.StrategyV1):
|
||||
"""TPU distribution strategy implementation."""
|
||||
@ -199,6 +201,7 @@ class TPUStrategyV1(distribute_lib.StrategyV1):
|
||||
# This implementation runs a single step. It does not use infeed or outfeed.
|
||||
def experimental_run_v2(self, fn, args=(), kwargs=None):
|
||||
"""See base class."""
|
||||
fn = autograph.tf_convert(fn, ag_ctx.control_status_ctx())
|
||||
return self.extended.tpu_run(fn, args, kwargs)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user