Throw an explicit error if a user calls TPUStrategy's `experimental_run_v2` in eager mode with a plain Python function.
PiperOrigin-RevId: 283453288
Change-Id: I381a61afbaf6cb74ccd1ad1f556d8e5cf3f962f2
commit 41a576f505
parent 9bc6a80e2b
@@ -36,8 +36,9 @@ class InputIterationTest(test.TestCase, parameterized.TestCase):
 
   @combinations.generate(
       combinations.combine(
-          distribution=strategy_combinations.strategies_minus_tpu,
-          mode=["eager"]))
+          distribution=strategy_combinations.all_strategies,
+          mode=["eager"]
+      ))
   def testFullEager(self, distribution):
     dataset = self._get_dataset()
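For readers unfamiliar with the test harness: `combinations.generate` expands each entry produced by `combinations.combine` into a separate parameterized test case, so the hunk above widens `testFullEager` from the CPU/GPU strategies to every strategy, TPU included. A minimal sketch of the pattern, assuming the usual distribute test imports (`InputIterationSketchTest` and `step` are illustrative names, not part of the change):

```python
from absl.testing import parameterized

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import def_function
from tensorflow.python.platform import test


class InputIterationSketchTest(test.TestCase, parameterized.TestCase):

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]))
  def testFullEagerSketch(self, distribution):
    # One generated case per strategy; the loop below runs eagerly,
    # pulling one distributed batch at a time.
    dataset = dataset_ops.Dataset.range(8).batch(4)
    dist_dataset = distribution.experimental_distribute_dataset(dataset)

    @def_function.function
    def step(x):
      return x + 1

    for batch in dist_dataset:
      distribution.experimental_run_v2(step, args=(batch,))


if __name__ == "__main__":
  test.main()
```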
@@ -37,7 +37,6 @@ from tensorflow.python.distribute import values
 from tensorflow.python.distribute.cluster_resolver import TPUClusterResolver
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
-from tensorflow.python.eager import function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import device_spec
 from tensorflow.python.framework import dtypes
@@ -83,29 +82,6 @@ def maybe_init_scope():
     yield
 
 
-def validate_experimental_run_function(fn):
-  """Validate the function passed into strategy.experimental_run_v2."""
-
-  # We allow three types of functions/objects passed into TPUStrategy
-  # experimental_run_v2 in eager mode:
-  #   1. a user-annotated tf.function
-  #   2. a ConcreteFunction, which is mostly what you get from loading a saved
-  #      model.
-  #   3. a callable object whose `__call__` method is a tf.function.
-  #
-  # Otherwise we raise an error, because we don't support eagerly running
-  # experimental_run_v2 in TPUStrategy.
-
-  if context.executing_eagerly() and not isinstance(
-      fn, def_function.Function) and not isinstance(
-          fn, function.ConcreteFunction) and not (callable(fn) and isinstance(
-              fn.__call__, def_function.Function)):
-    raise NotImplementedError(
-        "TPUStrategy.experimental_run_v2(fn, ...) does not support eager "
-        "execution. Either convert `fn` into a tf.function or consider "
-        "calling strategy.experimental_run_v2 inside a tf.function.")
-
-
 @tf_export("distribute.experimental.TPUStrategy", v1=[])
 class TPUStrategy(distribute_lib.Strategy):
   """TPU distribution strategy implementation."""
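To make the check concrete: under it, only a `tf.function`, a `ConcreteFunction`, or an object whose `__call__` is a `tf.function` may be handed to `TPUStrategy.experimental_run_v2` while executing eagerly. A hedged sketch using public APIs (the TPU name is a placeholder, and actually running this requires a TPU worker):

```python
import tensorflow as tf

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="my-tpu")  # placeholder name
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.experimental.TPUStrategy(resolver)

@tf.function
def double(x):
  return x * 2.0

# 1. A tf.function is accepted in eager mode.
strategy.experimental_run_v2(double, args=(tf.constant(1.0),))

# 2. A ConcreteFunction (e.g. obtained from a loaded SavedModel) is accepted.
concrete = double.get_concrete_function(tf.TensorSpec([], tf.float32))
strategy.experimental_run_v2(concrete, args=(tf.constant(1.0),))

# 3. A callable object whose __call__ is itself a tf.function is accepted.

# A plain Python callable, by contrast, triggers the NotImplementedError:
strategy.experimental_run_v2(lambda x: x * 2.0, args=(tf.constant(1.0),))
```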
@@ -113,36 +89,14 @@ class TPUStrategy(distribute_lib.Strategy):
   def __init__(self,
                tpu_cluster_resolver=None,
                device_assignment=None):
-    """Synchronous training in TPU donuts or Pods.
-
-    To construct a TPUStrategy object, you need to run the
-    initialization code below:
-
-    ```python
-    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=FLAGS.tpu)
-    tf.config.experimental_connect_to_cluster(resolver)
-    tf.tpu.experimental.initialize_tpu_system(resolver)
-    strategy = tf.distribute.experimental.TPUStrategy(resolver)
-    ```
-
-    While using distribution strategies, the variables created within the
-    strategy's scope will be replicated across all of the replicas and can be
-    kept in sync using all-reduce algorithms.
-
-    To run TF2 programs on TPUs, you can either use `.compile` and
-    `.fit` APIs in `tf.keras` with TPUStrategy, or write your own customized
-    training loop by calling `strategy.experimental_run_v2` directly. Note that
-    TPUStrategy doesn't support pure eager execution, so make sure the function
-    passed into `strategy.experimental_run_v2` is a `tf.function`, or that
-    `strategy.experimental_run_v2` is called inside a `tf.function`, if running
-    in eager mode.
+    """Initializes the TPUStrategy object.
 
     Args:
       tpu_cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
-        which provides information about the TPU cluster.
+          which provides information about the TPU cluster.
       device_assignment: Optional `tf.tpu.experimental.DeviceAssignment` to
-        specify the placement of replicas on the TPU cluster. Currently only
-        supports the usecase of using a single core within a TPU cluster.
+          specify the placement of replicas on the TPU cluster. Currently only
+          supports the usecase of using a single core within a TPU cluster.
     """
     super(TPUStrategy, self).__init__(TPUExtended(
         self, tpu_cluster_resolver, device_assignment=device_assignment))
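The guidance in the docstring above is easiest to see end to end. A hedged sketch of a custom training loop that keeps `experimental_run_v2` inside a `tf.function` (the variable, optimizer, and data are illustrative):

```python
import tensorflow as tf

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="my-tpu")  # placeholder name
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.experimental.TPUStrategy(resolver)

with strategy.scope():
  # Variables created under the scope are replicated across replicas.
  w = tf.Variable(1.0)
  opt = tf.keras.optimizers.SGD(0.1)

@tf.function
def train_step(x):
  def step_fn(x):
    with tf.GradientTape() as tape:
      loss = tf.reduce_mean((w * x - 1.0) ** 2)
    grads = tape.gradient(loss, [w])
    # Gradients are aggregated across replicas with all-reduce.
    opt.apply_gradients(zip(grads, [w]))
    return loss
  return strategy.experimental_run_v2(step_fn, args=(x,))

train_step(tf.constant([1.0, 2.0]))
```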
@@ -157,8 +111,6 @@ class TPUStrategy(distribute_lib.Strategy):
   # This implementation runs a single step. It does not use infeed or outfeed.
   def experimental_run_v2(self, fn, args=(), kwargs=None):
     """See base class."""
-    validate_experimental_run_function(fn)
-
     # Note: the target function is converted to graph even when in eager mode,
     # so autograph is on by default here.
     fn = autograph.tf_convert(fn, ag_ctx.control_status_ctx())
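The autograph note above means that plain Python control flow inside `fn` is rewritten into graph ops when the function is staged. A hedged illustration with the public `tf.autograph` API (the internal `tf_convert` helper works on the same principle):

```python
import tensorflow as tf

def step_fn(x):
  # Plain Python control flow; autograph rewrites it into tf.cond
  # when the function is staged into a graph.
  if x > 0:
    return x * 2.0
  return x

# Inspect the generated, graph-compatible source.
print(tf.autograph.to_code(step_fn))
```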
@@ -205,8 +157,6 @@ class TPUStrategyV1(distribute_lib.StrategyV1):
   # This implementation runs a single step. It does not use infeed or outfeed.
   def experimental_run_v2(self, fn, args=(), kwargs=None):
     """See base class."""
-    validate_experimental_run_function(fn)
-
     fn = autograph.tf_convert(fn, ag_ctx.control_status_ctx())
     return self.extended.tpu_run(fn, args, kwargs)
@@ -749,7 +699,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
       ]
 
       # Workaround for `tpu.replicate` behaviour when a single `Tensor` is returned.
-      if result[0] is None or isinstance(result[0], ops.Operation):
+      if result[0] is None:
         replicate_outputs = [None] * len(replicate_outputs)
       else:
         replicate_outputs = [
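For context on the branch above: when the replicated step function returns no tensor values, `tpu.replicate` hands back per-replica `tf.Operation`s (or `None`) instead of value lists, so the outputs are normalized to one `None` per replica. A pure-Python sketch of that normalization (`normalize_replicate_outputs` is a hypothetical name for illustration):

```python
from tensorflow.python.framework import ops


def normalize_replicate_outputs(result, replicate_outputs):
  """Mirrors the guard above: collapses value-less replica outputs."""
  if result[0] is None or isinstance(result[0], ops.Operation):
    # The step function produced no tensor values: report one None
    # per replica instead of raw Operations.
    return [None] * len(replicate_outputs)
  return replicate_outputs
```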
@@ -818,31 +818,13 @@ class SyncOnReadVariablePropertiesTest(test.TestCase):
     self.assertEqual(2., self.evaluate(add1(replica_local)))
 
 
-def mirrored_and_tpu_strategy_combinations():
-  return combinations.combine(
-      distribution=[
-          strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
-          strategy_combinations.tpu_strategy,
-      ],
-      mode=["graph", "eager"])
-
-
-def strategy_and_run_tf_function_combinations():
-  # Test the combination of different strategies and whether a tf.function
-  # is passed into strategy.experimental_run_v2.
-  return combinations.combine(
-      distribution=[
-          strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
-      ],
-      mode=["graph", "eager"],
-      experimental_run_tf_function=[True, False]) + combinations.combine(
-          distribution=[
-              strategy_combinations.tpu_strategy,
-          ],
-          mode=["graph", "eager"],
-          experimental_run_tf_function=[True])
-
-
+@combinations.generate(
+    combinations.combine(
+        distribution=[
+            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+            strategy_combinations.tpu_strategy,
+        ],
+        mode=["graph", "eager"]))
 class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
 
   def _assign_replica_local(self, v, new):
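For reference, the removed `strategy_and_run_tf_function_combinations` helper parameterized each test over both the strategy and an `experimental_run_tf_function` flag, which arrives as an ordinary test argument. A hedged sketch of the pattern (`CombinationSketchTest` is an illustrative name):

```python
from absl.testing import parameterized

from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import def_function
from tensorflow.python.platform import test


class CombinationSketchTest(test.TestCase, parameterized.TestCase):

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          ],
          mode=["eager"],
          experimental_run_tf_function=[True, False]))
  def testSketch(self, distribution, experimental_run_tf_function):
    # Every keyword passed to combinations.combine becomes a test
    # argument; here the flag toggles tf.function wrapping.
    update_fn = lambda: 1.0
    if experimental_run_tf_function:
      update_fn = def_function.function(update_fn)
    distribution.experimental_run_v2(update_fn)
```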
@@ -860,7 +842,6 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
       save_path, _ = self._save_return_saver(sess, var)
       return save_path
 
-  @combinations.generate(mirrored_and_tpu_strategy_combinations())
   def testSaveAndRestoreReplicaLocalSumOneGraph(self, distribution):
     with self.cached_session() as sess:
       v, replica_local = _make_replica_local(
@@ -881,7 +862,6 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
       saver.restore(sess, save_path)
       self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]]))
 
-  @combinations.generate(mirrored_and_tpu_strategy_combinations())
   def testSaveAndRestoreReplicaLocalMeanOneGraph(self, distribution):
     if context.num_gpus() < 1 and context.executing_eagerly():
       self.skipTest("A GPU is not available for this test in eager mode.")
@@ -998,46 +978,36 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
       saver.restore(sess, save_path)
       self.assertEqual([1.75, 1.75], self.evaluate([v[0], v[1]]))
 
-  @combinations.generate(mirrored_and_tpu_strategy_combinations())
   def testSaveReplicaLocalRestoreReplicaLocalMean(self, distribution):
     save_path = self._save_replica_local_mean(distribution)
     self._restore_replica_local_mean(save_path, distribution)
 
-  @combinations.generate(mirrored_and_tpu_strategy_combinations())
   def testSaveReplicaLocalRestoreReplicaLocalSum(self, distribution):
     save_path = self._save_replica_local_sum(distribution)
     self._restore_replica_local_sum(save_path, distribution)
 
-  @combinations.generate(mirrored_and_tpu_strategy_combinations())
   def testSaveReplicaLocalMeanRestoreNormal(self, distribution):
     save_path = self._save_replica_local_mean(distribution)
     self._restore_normal(save_path)
 
-  @combinations.generate(mirrored_and_tpu_strategy_combinations())
   def testSaveReplicaLocalSumRestoreNormal(self, distribution):
     save_path = self._save_replica_local_sum(distribution)
     self._restore_normal(save_path)
 
-  @combinations.generate(mirrored_and_tpu_strategy_combinations())
   def testSaveNormalRestoreReplicaLocalMean(self, distribution):
     save_path = self._save_normal()
     self._restore_replica_local_mean(save_path, distribution)
 
-  @combinations.generate(mirrored_and_tpu_strategy_combinations())
   def testSaveNormalRestoreReplicaLocalSum(self, distribution):
     save_path = self._save_normal()
     self._restore_replica_local_sum(save_path, distribution)
 
-  @combinations.generate(strategy_and_run_tf_function_combinations())
-  def testAssign(self, distribution, experimental_run_tf_function):
-
+  def testAssign(self, distribution):
     def assign(fn, v, update_value, cross_replica):
       update_fn = lambda: getattr(v, fn)(update_value)
       if cross_replica:
         return update_fn()
       else:
-        if experimental_run_tf_function:
-          update_fn = def_function.function(update_fn)
         return distribution.experimental_local_results(
             distribution.experimental_run_v2(update_fn))
     updates = [("assign", 1.), ("assign_add", 1.), ("assign_sub", -1.)]
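For readers new to the variables under test: a synchronize-on-read variable keeps an independent copy per replica, applies writes locally, and only aggregates when read across replicas. A hedged minimal sketch using public APIs:

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # one replica per visible device

with strategy.scope():
  v = tf.Variable(
      0.,
      synchronization=tf.VariableSynchronization.ON_READ,
      aggregation=tf.VariableAggregation.SUM)

@tf.function
def update():
  # Runs in replica context; each replica updates only its own copy.
  return v.assign_add(1.)

strategy.experimental_run_v2(update)

# Reading in cross-replica context applies the aggregation: with SUM,
# the result is num_replicas_in_sync * 1.0 after one update.
print(v.read_value())
```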
@@ -1063,17 +1033,12 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
         self.assertAllEqual(self.evaluate(component.read_value()),
                             self.evaluate(array_ops.ones_like(component)))
 
-  @combinations.generate(strategy_and_run_tf_function_combinations())
-  def testAssignDtypeConversion(self, distribution,
-                                experimental_run_tf_function):
-
+  def testAssignDtypeConversion(self, distribution):
     def assign(fn, v, update_value, cross_replica):
       update_fn = lambda: getattr(v, fn)(update_value)
       if cross_replica:
         return update_fn()
       else:
-        if experimental_run_tf_function:
-          update_fn = def_function.function(update_fn)
         return distribution.experimental_local_results(
             distribution.experimental_run_v2(update_fn))
     updates = [("assign", 1), ("assign_add", 1), ("assign_sub", -1)]
@@ -1099,7 +1064,6 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
         self.assertAllEqual(self.evaluate(component.read_value()),
                             self.evaluate(array_ops.ones_like(component)))
 
-  @combinations.generate(mirrored_and_tpu_strategy_combinations())
   def testAssignWithAggregationSum(self, distribution):
     with distribution.scope():
       v = variable_scope.variable(
@@ -1112,7 +1076,6 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
         self.assertAllEqual(self.evaluate(component.read_value()),
                             self.evaluate(array_ops.ones_like(component)))
 
-  @combinations.generate(mirrored_and_tpu_strategy_combinations())
   def testAssignAddSubWithAggregationSum(self, distribution):
     with distribution.scope():
       v = variable_scope.variable(
@@ -1127,9 +1090,7 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
         ValueError, "SyncOnReadVariable does not support "):
       self.evaluate(v.assign_sub(1.))
 
-  @combinations.generate(strategy_and_run_tf_function_combinations())
-  def testReadValueInReplicaContext(self, distribution,
-                                    experimental_run_tf_function):
+  def testReadValueInReplicaContext(self, distribution):
     aggregations = [
         variables_lib.VariableAggregation.NONE,
         variables_lib.VariableAggregation.SUM,
@@ -1143,19 +1104,12 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
             synchronization=variables_lib.VariableSynchronization.ON_READ,
             aggregation=aggregation)
         self.evaluate(variables_lib.global_variables_initializer())
-        if experimental_run_tf_function:
-          read_var_fn = def_function.function(v.read_value)
-        else:
-          read_var_fn = v.read_value
-        results = self.evaluate(
-            distribution.experimental_local_results(
-                distribution.experimental_run_v2(read_var_fn)))
+        results = self.evaluate(distribution.experimental_local_results(
+            distribution.experimental_run_v2(v.read_value)))
         for component, value in zip(v._values, results):
           self.assertAllEqual(self.evaluate(component.read_value()), value)
 
-  @combinations.generate(strategy_and_run_tf_function_combinations())
-  def testReadValueInCrossReplicaContext(self, distribution,
-                                         experimental_run_tf_function):
+  def testReadValueInCrossReplicaContext(self, distribution):
     aggregations = [
         variables_lib.VariableAggregation.SUM,
         variables_lib.VariableAggregation.MEAN,
@@ -1171,15 +1125,10 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
             synchronization=variables_lib.VariableSynchronization.ON_READ,
             aggregation=aggregation)
         self.evaluate(variables_lib.global_variables_initializer())
 
         def assign(v=v):
           ctx = distribution_strategy_context.get_replica_context()
           replica_id = ctx.replica_id_in_sync_group
           return v.assign(math_ops.cast(replica_id, dtypes.float32))
 
-        if experimental_run_tf_function:
-          assign = def_function.function(assign)
-
         self.evaluate(distribution.experimental_local_results(
             distribution.experimental_run_v2(assign)))
         result = self.evaluate(v.read_value())
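The expected values checked just below follow mechanically from the aggregation applied at cross-replica read time. A sketch of the arithmetic, assuming two replicas that each assigned their replica id:

```python
# Per-replica values written by the test's assign(): replica ids 0. and 1.
per_replica = [0., 1.]

agg_sum = sum(per_replica)                      # SUM                -> 1.0
agg_mean = sum(per_replica) / len(per_replica)  # MEAN               -> 0.5
agg_first = per_replica[0]                      # ONLY_FIRST_REPLICA -> 0.0
print(agg_sum, agg_mean, agg_first)
```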
@@ -1193,7 +1142,6 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
         expected = 0
       self.assertEqual(expected, result, aggregation)
 
-  @combinations.generate(mirrored_and_tpu_strategy_combinations())
   def testReadValueWithAggregationNoneInCrossReplicaContext(self, distribution):
     with distribution.scope():
       v = variable_scope.variable(
@@ -1205,7 +1153,6 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
         ValueError, "Could not convert from .* VariableAggregation\\.NONE"):
       self.evaluate(v.read_value())
 
-  @combinations.generate(mirrored_and_tpu_strategy_combinations())
   def testInitializedToSameValueInsideEagerRun(self, distribution):
     if not context.executing_eagerly(): self.skipTest("eager only")