Extend AutoGraph to functions called by Distribution Strategy's experimental_run_v2.
PiperOrigin-RevId: 253722908
This commit is contained in:
parent
06a1ca8390
commit
e906cdde1b
@ -201,7 +201,6 @@ def convert(recursive=False, optional_features=None, force_conversion=True):
|
|||||||
def decorator(f):
|
def decorator(f):
|
||||||
"""Decorator implementation."""
|
"""Decorator implementation."""
|
||||||
|
|
||||||
@functools.wraps(f)
|
|
||||||
def wrapper(*args, **kwargs):
|
def wrapper(*args, **kwargs):
|
||||||
"""Wrapper that calls the converted version of f."""
|
"""Wrapper that calls the converted version of f."""
|
||||||
with ag_ctx.ControlStatusCtx(
|
with ag_ctx.ControlStatusCtx(
|
||||||
@ -220,12 +219,15 @@ def convert(recursive=False, optional_features=None, force_conversion=True):
|
|||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
wrapper = tf_decorator.make_decorator(f, wrapper)
|
if inspect.isfunction(f) or inspect.ismethod(f):
|
||||||
|
wrapper = functools.update_wrapper(wrapper, f)
|
||||||
|
|
||||||
|
decorated_wrapper = tf_decorator.make_decorator(f, wrapper)
|
||||||
|
|
||||||
# Sometimes the decorator is just desugared, making it impossible to detect.
|
# Sometimes the decorator is just desugared, making it impossible to detect.
|
||||||
# This attribute makes detection easier.
|
# This attribute makes detection easier.
|
||||||
setattr(wrapper, '__ag_compiled', True)
|
setattr(decorated_wrapper, '__ag_compiled', True)
|
||||||
return wrapper
|
return decorated_wrapper
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
@ -97,6 +97,32 @@ class InputIterationTest(test.TestCase, parameterized.TestCase):
|
|||||||
results.append(output)
|
results.append(output)
|
||||||
self._validate_outputs(results)
|
self._validate_outputs(results)
|
||||||
|
|
||||||
|
@combinations.generate(
|
||||||
|
combinations.combine(
|
||||||
|
distribution=strategy_combinations.strategies_minus_tpu +
|
||||||
|
[strategy_combinations.tpu_strategy_one_step],
|
||||||
|
mode=["eager"]
|
||||||
|
))
|
||||||
|
def testRunInFunctionAutoGraphApplication(self, distribution):
|
||||||
|
dataset = self._get_dataset()
|
||||||
|
|
||||||
|
def train_step(data):
|
||||||
|
if math_ops.reduce_sum(data) < 0:
|
||||||
|
return -data
|
||||||
|
return data
|
||||||
|
|
||||||
|
@def_function.function
|
||||||
|
def f_train_step(input_data):
|
||||||
|
return distribution.experimental_local_results(
|
||||||
|
distribution.experimental_run_v2(train_step, args=(input_data,)))
|
||||||
|
|
||||||
|
dist_dataset = distribution.experimental_distribute_dataset(dataset)
|
||||||
|
results = []
|
||||||
|
for x in dist_dataset:
|
||||||
|
output = f_train_step(x)
|
||||||
|
results.append(output)
|
||||||
|
self._validate_outputs(results)
|
||||||
|
|
||||||
@combinations.generate(
|
@combinations.generate(
|
||||||
combinations.combine(
|
combinations.combine(
|
||||||
distribution=strategy_combinations.strategies_minus_tpu,
|
distribution=strategy_combinations.strategies_minus_tpu,
|
||||||
|
@ -101,6 +101,8 @@ import threading
|
|||||||
import weakref
|
import weakref
|
||||||
import six
|
import six
|
||||||
|
|
||||||
|
from tensorflow.python.autograph.core import ag_ctx
|
||||||
|
from tensorflow.python.autograph.impl import api as autograph
|
||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
from tensorflow.python.distribute import device_util
|
from tensorflow.python.distribute import device_util
|
||||||
from tensorflow.python.distribute import distribution_strategy_context
|
from tensorflow.python.distribute import distribution_strategy_context
|
||||||
@ -716,6 +718,7 @@ class Strategy(object):
|
|||||||
(for example, if running on a single replica).
|
(for example, if running on a single replica).
|
||||||
"""
|
"""
|
||||||
with self.scope():
|
with self.scope():
|
||||||
|
fn = autograph.tf_convert(fn, ag_ctx.control_status_ctx())
|
||||||
return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
|
return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
|
||||||
|
|
||||||
def reduce(self, reduce_op, value, axis):
|
def reduce(self, reduce_op, value, axis):
|
||||||
|
@ -25,6 +25,8 @@ import weakref
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from tensorflow.python.autograph.core import ag_ctx
|
||||||
|
from tensorflow.python.autograph.impl import api as autograph
|
||||||
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
|
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
|
||||||
from tensorflow.python.distribute import device_util
|
from tensorflow.python.distribute import device_util
|
||||||
from tensorflow.python.distribute import distribute_lib
|
from tensorflow.python.distribute import distribute_lib
|
||||||
@ -160,10 +162,10 @@ class TPUStrategy(distribute_lib.Strategy):
|
|||||||
# This implementation runs a single step. It does not use infeed or outfeed.
|
# This implementation runs a single step. It does not use infeed or outfeed.
|
||||||
def experimental_run_v2(self, fn, args=(), kwargs=None):
|
def experimental_run_v2(self, fn, args=(), kwargs=None):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
|
fn = autograph.tf_convert(fn, ag_ctx.control_status_ctx())
|
||||||
return self.extended.tpu_run(fn, args, kwargs)
|
return self.extended.tpu_run(fn, args, kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@tf_export(v1=["distribute.experimental.TPUStrategy"])
|
@tf_export(v1=["distribute.experimental.TPUStrategy"])
|
||||||
class TPUStrategyV1(distribute_lib.StrategyV1):
|
class TPUStrategyV1(distribute_lib.StrategyV1):
|
||||||
"""TPU distribution strategy implementation."""
|
"""TPU distribution strategy implementation."""
|
||||||
@ -199,6 +201,7 @@ class TPUStrategyV1(distribute_lib.StrategyV1):
|
|||||||
# This implementation runs a single step. It does not use infeed or outfeed.
|
# This implementation runs a single step. It does not use infeed or outfeed.
|
||||||
def experimental_run_v2(self, fn, args=(), kwargs=None):
|
def experimental_run_v2(self, fn, args=(), kwargs=None):
|
||||||
"""See base class."""
|
"""See base class."""
|
||||||
|
fn = autograph.tf_convert(fn, ag_ctx.control_status_ctx())
|
||||||
return self.extended.tpu_run(fn, args, kwargs)
|
return self.extended.tpu_run(fn, args, kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user