Throw an explicit error if a user calls TPUStrategy experimental_run_v2 in eager mode with a plain Python function.

PiperOrigin-RevId: 283453288
Change-Id: I381a61afbaf6cb74ccd1ad1f556d8e5cf3f962f2
A. Unique TensorFlower 2019-12-02 17:34:52 -08:00 committed by TensorFlower Gardener
parent 9bc6a80e2b
commit 41a576f505
3 changed files with 21 additions and 123 deletions
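
For context, the behaviour described in the commit message looks roughly like the sketch below. This is illustrative only and not part of the diff: it assumes a reachable TPU worker, the TF 2.1-era `tf.distribute.experimental.TPUStrategy` API, and hypothetical `plain_step`/`tf_step` step functions.

```python
import tensorflow as tf

# Standard TPUStrategy setup (mirrors the snippet quoted in the docstring
# below); the empty tpu="" address is an assumption for illustration.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="")
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.experimental.TPUStrategy(resolver)


def plain_step():
  # A plain Python function: with the validation in place, passing this to
  # experimental_run_v2 in eager mode raises NotImplementedError.
  return tf.constant(1.0)


@tf.function
def tf_step():
  # A tf.function is accepted; alternatively, call experimental_run_v2 from
  # inside an enclosing tf.function.
  return tf.constant(1.0)


try:
  strategy.experimental_run_v2(plain_step)
except NotImplementedError as err:
  print(err)

strategy.experimental_run_v2(tf_step)
```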


@@ -36,8 +36,9 @@ class InputIterationTest(test.TestCase, parameterized.TestCase):
   @combinations.generate(
       combinations.combine(
-          distribution=strategy_combinations.strategies_minus_tpu,
-          mode=["eager"]))
+          distribution=strategy_combinations.all_strategies,
+          mode=["eager"]
+      ))
   def testFullEager(self, distribution):
     dataset = self._get_dataset()

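The hunk above touches the test's parameterization. For readers unfamiliar with the harness, here is a minimal sketch of how `combinations.generate`/`combinations.combine` expand one test over strategies and modes; the class name and test body are hypothetical.

```python
from absl.testing import parameterized

from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.platform import test


class ExampleTest(test.TestCase, parameterized.TestCase):

  # One test instance is generated per (distribution, mode) combination.
  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]))
  def testRunsUnderEveryStrategy(self, distribution):
    with distribution.scope():
      self.assertGreaterEqual(distribution.num_replicas_in_sync, 1)


if __name__ == "__main__":
  test.main()
```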

@@ -37,7 +37,6 @@ from tensorflow.python.distribute import values
 from tensorflow.python.distribute.cluster_resolver import TPUClusterResolver
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
-from tensorflow.python.eager import function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import device_spec
 from tensorflow.python.framework import dtypes
@@ -83,29 +82,6 @@ def maybe_init_scope():
     yield
 
 
-def validate_experimental_run_function(fn):
-  """Validate the function passed into strategy.experimental_run_v2."""
-
-  # We allow three types of functions/objects passed into TPUStrategy
-  # experimental_run_v2 in eager mode:
-  #   1. a user annotated tf.function
-  #   2. a ConcreteFunction, this is mostly what you get from loading a saved
-  #      model.
-  #   3. a callable object and the `__call__` method itself is a tf.function.
-  #
-  # Otherwise we return an error, because we don't support eagerly running
-  # experimental_run_v2 in TPUStrategy.
-
-  if context.executing_eagerly() and not isinstance(
-      fn, def_function.Function) and not isinstance(
-          fn, function.ConcreteFunction) and not (callable(fn) and isinstance(
-              fn.__call__, def_function.Function)):
-    raise NotImplementedError(
-        "TPUStrategy.experimental_run_v2(fn, ...) does not support eager "
-        "execution. Either convert `fn` into a tf.function or consider "
-        "calling strategy.experimental_run_v2 inside a tf.function.")
-
-
 @tf_export("distribute.experimental.TPUStrategy", v1=[])
 class TPUStrategy(distribute_lib.Strategy):
   """TPU distribution strategy implementation."""
@@ -113,36 +89,14 @@ class TPUStrategy(distribute_lib.Strategy):
   def __init__(self,
                tpu_cluster_resolver=None,
                device_assignment=None):
-    """Synchronous training in TPU donuts or Pods.
-
-    To construct a TPUStrategy object, you need to run the
-    initialization code as below:
-
-    ```python
-    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=FLAGS.tpu)
-    tf.config.experimental_connect_to_cluster(resolver)
-    tf.tpu.experimental.initialize_tpu_system(resolver)
-    strategy = tf.distribute.experimental.TPUStrategy(resolver)
-    ```
-
-    While using distribution strategies, the variables created within strategy's
-    scope will be replicated across all the replicas and can be kept in sync
-    using all-reduce algorithms.
-
-    To run TF2 programs on TPUs, you can either use `.compile` and
-    `.fit` APIs in `tf.keras` with TPUStrategy, or write your own customized
-    training loop by calling `strategy.experimental_run_v2` directly. Note that
-    TPUStrategy doesn't support pure eager execution, so please make sure the
-    function passed into `strategy.experimental_run_v2` is a `tf.function` or
-    `strategy.experimental_run_v2` is called inside a `tf.function` if running
-    in eager mode.
+    """Initializes the TPUStrategy object.
 
     Args:
       tpu_cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
-        which provides information about the TPU cluster.
+          which provides information about the TPU cluster.
       device_assignment: Optional `tf.tpu.experimental.DeviceAssignment` to
-        specify the placement of replicas on the TPU cluster. Currently only
-        supports the usecase of using a single core within a TPU cluster.
+          specify the placement of replicas on the TPU cluster. Currently only
+          supports the usecase of using a single core within a TPU cluster.
     """
     super(TPUStrategy, self).__init__(TPUExtended(
         self, tpu_cluster_resolver, device_assignment=device_assignment))
@@ -157,8 +111,6 @@ class TPUStrategy(distribute_lib.Strategy):
   # This implementation runs a single step. It does not use infeed or outfeed.
   def experimental_run_v2(self, fn, args=(), kwargs=None):
     """See base class."""
-    validate_experimental_run_function(fn)
-
     # Note: the target function is converted to graph even when in Eager mode,
     # so autograph is on by default here.
     fn = autograph.tf_convert(fn, ag_ctx.control_status_ctx())
@@ -205,8 +157,6 @@ class TPUStrategyV1(distribute_lib.StrategyV1):
   # This implementation runs a single step. It does not use infeed or outfeed.
   def experimental_run_v2(self, fn, args=(), kwargs=None):
     """See base class."""
-    validate_experimental_run_function(fn)
-
     fn = autograph.tf_convert(fn, ag_ctx.control_status_ctx())
     return self.extended.tpu_run(fn, args, kwargs)
@@ -749,7 +699,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
     ]
 
     # Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
-    if result[0] is None or isinstance(result[0], ops.Operation):
+    if result[0] is None:
       replicate_outputs = [None] * len(replicate_outputs)
     else:
       replicate_outputs = [

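The validation helper shown in the hunk above accepts three kinds of callables: a `tf.function`, a `ConcreteFunction`, or an object whose `__call__` is itself a `tf.function`. A hedged sketch of what each form looks like from the caller's side; `strategy` is assumed to be a `TPUStrategy` built as in the earlier sketch.

```python
import tensorflow as tf


# 1. A user-annotated tf.function.
@tf.function
def step_fn():
  return tf.constant(1.0)


# 2. A ConcreteFunction, e.g. the callable returned by get_concrete_function()
#    or recovered from a loaded SavedModel signature.
concrete_step = step_fn.get_concrete_function()


# 3. A callable object whose __call__ method is itself a tf.function.
class StepRunner(object):

  @tf.function
  def __call__(self):
    return tf.constant(1.0)


runner = StepRunner()

# Assuming the check quoted in the hunk above, each of these forms passes,
# while a plain Python function raises NotImplementedError:
#   strategy.experimental_run_v2(step_fn)
#   strategy.experimental_run_v2(concrete_step)
#   strategy.experimental_run_v2(runner)
```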

@@ -818,31 +818,13 @@ class SyncOnReadVariablePropertiesTest(test.TestCase):
     self.assertEqual(2., self.evaluate(add1(replica_local)))
 
 
-def mirrored_and_tpu_strategy_combinations():
-  return combinations.combine(
-      distribution=[
-          strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
-          strategy_combinations.tpu_strategy,
-      ],
-      mode=["graph", "eager"])
-
-
-def strategy_and_run_tf_function_combinations():
-  # Test the combination of different strategies and whether a tf.function
-  # is passed into strategy.experimental_run_v2."""
-  return combinations.combine(
-      distribution=[
-          strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
-      ],
-      mode=["graph", "eager"],
-      experimental_run_tf_function=[True, False]) + combinations.combine(
-          distribution=[
-              strategy_combinations.tpu_strategy,
-          ],
-          mode=["graph", "eager"],
-          experimental_run_tf_function=[True])
-
-
+@combinations.generate(
+    combinations.combine(
+        distribution=[
+            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+            strategy_combinations.tpu_strategy,
+        ],
+        mode=["graph", "eager"]))
 class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
 
   def _assign_replica_local(self, v, new):
@@ -860,7 +842,6 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
       save_path, _ = self._save_return_saver(sess, var)
     return save_path
 
-  @combinations.generate(mirrored_and_tpu_strategy_combinations())
   def testSaveAndRestoreReplicaLocalSumOneGraph(self, distribution):
     with self.cached_session() as sess:
       v, replica_local = _make_replica_local(
@@ -881,7 +862,6 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
       saver.restore(sess, save_path)
       self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]]))
 
-  @combinations.generate(mirrored_and_tpu_strategy_combinations())
   def testSaveAndRestoreReplicaLocalMeanOneGraph(self, distribution):
     if context.num_gpus() < 1 and context.executing_eagerly():
       self.skipTest("A GPU is not available for this test in eager mode.")
@@ -998,46 +978,36 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
       saver.restore(sess, save_path)
       self.assertEqual([1.75, 1.75], self.evaluate([v[0], v[1]]))
 
-  @combinations.generate(mirrored_and_tpu_strategy_combinations())
   def testSaveReplicaLocalRestoreReplicaLocalMean(self, distribution):
     save_path = self._save_replica_local_mean(distribution)
     self._restore_replica_local_mean(save_path, distribution)
 
-  @combinations.generate(mirrored_and_tpu_strategy_combinations())
   def testSaveReplicaLocalRestoreReplicaLocalSum(self, distribution):
     save_path = self._save_replica_local_sum(distribution)
     self._restore_replica_local_sum(save_path, distribution)
 
-  @combinations.generate(mirrored_and_tpu_strategy_combinations())
   def testSaveReplicaLocalMeanRestoreNormal(self, distribution):
     save_path = self._save_replica_local_mean(distribution)
     self._restore_normal(save_path)
 
-  @combinations.generate(mirrored_and_tpu_strategy_combinations())
   def testSaveReplicaLocalSumRestoreNormal(self, distribution):
     save_path = self._save_replica_local_sum(distribution)
     self._restore_normal(save_path)
 
-  @combinations.generate(mirrored_and_tpu_strategy_combinations())
   def testSaveNormalRestoreReplicaLocalMean(self, distribution):
     save_path = self._save_normal()
     self._restore_replica_local_mean(save_path, distribution)
 
-  @combinations.generate(mirrored_and_tpu_strategy_combinations())
   def testSaveNormalRestoreReplicaLocalSum(self, distribution):
     save_path = self._save_normal()
     self._restore_replica_local_sum(save_path, distribution)
 
-  @combinations.generate(strategy_and_run_tf_function_combinations())
-  def testAssign(self, distribution, experimental_run_tf_function):
+  def testAssign(self, distribution):
 
     def assign(fn, v, update_value, cross_replica):
       update_fn = lambda: getattr(v, fn)(update_value)
       if cross_replica:
         return update_fn()
       else:
-        if experimental_run_tf_function:
-          update_fn = def_function.function(update_fn)
         return distribution.experimental_local_results(
             distribution.experimental_run_v2(update_fn))
 
     updates = [("assign", 1.), ("assign_add", 1.), ("assign_sub", -1.)]
@@ -1063,17 +1033,12 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
         self.assertAllEqual(self.evaluate(component.read_value()),
                             self.evaluate(array_ops.ones_like(component)))
 
-  @combinations.generate(strategy_and_run_tf_function_combinations())
-  def testAssignDtypeConversion(self, distribution,
-                                experimental_run_tf_function):
+  def testAssignDtypeConversion(self, distribution):
 
     def assign(fn, v, update_value, cross_replica):
       update_fn = lambda: getattr(v, fn)(update_value)
       if cross_replica:
         return update_fn()
       else:
-        if experimental_run_tf_function:
-          update_fn = def_function.function(update_fn)
         return distribution.experimental_local_results(
             distribution.experimental_run_v2(update_fn))
 
     updates = [("assign", 1), ("assign_add", 1), ("assign_sub", -1)]
@@ -1099,7 +1064,6 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
         self.assertAllEqual(self.evaluate(component.read_value()),
                             self.evaluate(array_ops.ones_like(component)))
 
-  @combinations.generate(mirrored_and_tpu_strategy_combinations())
   def testAssignWithAggregationSum(self, distribution):
     with distribution.scope():
       v = variable_scope.variable(
@@ -1112,7 +1076,6 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
         self.assertAllEqual(self.evaluate(component.read_value()),
                             self.evaluate(array_ops.ones_like(component)))
 
-  @combinations.generate(mirrored_and_tpu_strategy_combinations())
   def testAssignAddSubWithAggregationSum(self, distribution):
     with distribution.scope():
       v = variable_scope.variable(
@@ -1127,9 +1090,7 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
         ValueError, "SyncOnReadVariable does not support "):
       self.evaluate(v.assign_sub(1.))
 
-  @combinations.generate(strategy_and_run_tf_function_combinations())
-  def testReadValueInReplicaContext(self, distribution,
-                                    experimental_run_tf_function):
+  def testReadValueInReplicaContext(self, distribution):
     aggregations = [
         variables_lib.VariableAggregation.NONE,
         variables_lib.VariableAggregation.SUM,
@@ -1143,19 +1104,12 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
             synchronization=variables_lib.VariableSynchronization.ON_READ,
             aggregation=aggregation)
       self.evaluate(variables_lib.global_variables_initializer())
-      if experimental_run_tf_function:
-        read_var_fn = def_function.function(v.read_value)
-      else:
-        read_var_fn = v.read_value
-      results = self.evaluate(
-          distribution.experimental_local_results(
-              distribution.experimental_run_v2(read_var_fn)))
+      results = self.evaluate(distribution.experimental_local_results(
+          distribution.experimental_run_v2(v.read_value)))
       for component, value in zip(v._values, results):
         self.assertAllEqual(self.evaluate(component.read_value()), value)
 
-  @combinations.generate(strategy_and_run_tf_function_combinations())
-  def testReadValueInCrossReplicaContext(self, distribution,
-                                         experimental_run_tf_function):
+  def testReadValueInCrossReplicaContext(self, distribution):
     aggregations = [
         variables_lib.VariableAggregation.SUM,
         variables_lib.VariableAggregation.MEAN,
@@ -1171,15 +1125,10 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
            synchronization=variables_lib.VariableSynchronization.ON_READ,
            aggregation=aggregation)
      self.evaluate(variables_lib.global_variables_initializer())

      def assign(v=v):
        ctx = distribution_strategy_context.get_replica_context()
        replica_id = ctx.replica_id_in_sync_group
        return v.assign(math_ops.cast(replica_id, dtypes.float32))

-      if experimental_run_tf_function:
-        assign = def_function.function(assign)
-
      self.evaluate(distribution.experimental_local_results(
          distribution.experimental_run_v2(assign)))
      result = self.evaluate(v.read_value())
@@ -1193,7 +1142,6 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
         expected = 0
       self.assertEqual(expected, result, aggregation)
 
-  @combinations.generate(mirrored_and_tpu_strategy_combinations())
   def testReadValueWithAggregationNoneInCrossReplicaContext(self, distribution):
     with distribution.scope():
       v = variable_scope.variable(
@@ -1205,7 +1153,6 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
         ValueError, "Could not convert from .* VariableAggregation\\.NONE"):
       self.evaluate(v.read_value())
 
-  @combinations.generate(mirrored_and_tpu_strategy_combinations())
   def testInitializedToSameValueInsideEagerRun(self, distribution):
     if not context.executing_eagerly(): self.skipTest("eager only")
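
The values_test.py hunks above fold the `experimental_run_tf_function` parameter out of the SyncOnReadVariable tests. The pattern being changed amounts to the sketch below; the `run_update` helper name is hypothetical, and `distribution` and `v` are assumed to come from the combinations harness and a strategy-scoped variable respectively.

```python
from tensorflow.python.eager import def_function


def run_update(distribution, v, update_value, experimental_run_tf_function):
  """Applies an update to `v` on every replica, optionally via tf.function."""
  update_fn = lambda: v.assign_add(update_value)
  if experimental_run_tf_function:
    # TPU does not run plain Python functions eagerly, so the TPU test
    # combination only exercised the tf.function path.
    update_fn = def_function.function(update_fn)
  return distribution.experimental_local_results(
      distribution.experimental_run_v2(update_fn))
```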