Run gpu doctest with a fixed set of devices

This ensures the test runs with a fixed number of devices regardless of how many
physical GPUs are available, and lets us change our examples to use two GPUs,
which is more realistic.

PiperOrigin-RevId: 320116796
Change-Id: I34e486db07b5f3be6f952595428dfab9c2c09aa4
Xiao Yu 2020-07-07 20:51:29 -07:00 committed by TensorFlower Gardener
parent 087d3651ba
commit a7ee6a72ff
6 changed files with 96 additions and 123 deletions
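
For context, a minimal sketch (not part of the commit itself) of the mechanism this change relies on: carving a fixed number of logical GPUs out of the first physical GPU so that examples written for `["GPU:0", "GPU:1"]` run even on a single-GPU machine. The calls and memory limit mirror the `setup_gpu` helper shown further down; the standalone script form is illustrative only.

```python
import tensorflow as tf

# Must run before TensorFlow initializes the GPU devices.
physical_gpus = tf.config.experimental.list_physical_devices('GPU')
if len(physical_gpus) == 1:
  # Split the single physical GPU into two small logical devices so that
  # doctests always see exactly two GPUs.
  tf.config.set_logical_device_configuration(
      physical_gpus[0],
      [tf.config.LogicalDeviceConfiguration(memory_limit=256),
       tf.config.LogicalDeviceConfiguration(memory_limit=256)])
print(len(tf.config.list_logical_devices('GPU')))  # 2 on a one-GPU machine
```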


@ -533,7 +533,7 @@ class ValueContext(object):
2. Passed in by `experimental_distribute_values_from_function`.
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
>>> strategy = tf.distribute.MirroredStrategy()
>>> def value_fn(value_context):
... return value_context.num_replicas_in_sync
>>> distributed_values = (
@ -541,7 +541,7 @@ class ValueContext(object):
... value_fn))
>>> local_result = strategy.experimental_local_results(distributed_values)
>>> local_result
(2, 2)
(1,)
"""
@ -792,14 +792,13 @@ class StrategyBase(object):
This method returns a context manager, and is used as follows:
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
>>> strategy = tf.distribute.MirroredStrategy()
>>> # Variable created inside scope:
>>> with strategy.scope():
... mirrored_variable = tf.Variable(1.)
>>> mirrored_variable
MirroredVariable:{
0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>
1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=1.0>
}
>>> # Variable created outside scope:
>>> regular_variable = tf.Variable(1.)
@ -1158,21 +1157,18 @@ class StrategyBase(object):
1. Constant tensor input.
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
>>> strategy = tf.distribute.MirroredStrategy()
>>> tensor_input = tf.constant(3.0)
>>> @tf.function
... def replica_fn(input):
... return input*2.0
>>> result = strategy.run(replica_fn, args=(tensor_input,))
>>> result
PerReplica:{
0: <tf.Tensor: shape=(), dtype=float32, numpy=6.0>,
1: <tf.Tensor: shape=(), dtype=float32, numpy=6.0>
}
<tf.Tensor: shape=(), dtype=float32, numpy=6.0>
2. DistributedValues input.
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
>>> strategy = tf.distribute.MirroredStrategy()
>>> @tf.function
... def run():
... def value_fn(value_context):
@ -1185,7 +1181,7 @@ class StrategyBase(object):
... return strategy.run(replica_fn2, args=(distributed_values,))
>>> result = run()
>>> result
<tf.Tensor: shape=(), dtype=int32, numpy=4>
<tf.Tensor: shape=(), dtype=int32, numpy=2>
Args:
fn: The function to run. The output must be a `tf.nest` of `Tensor`s.
@ -1222,7 +1218,7 @@ class StrategyBase(object):
def reduce(self, reduce_op, value, axis):
"""Reduce `value` across replicas and return result on current device.
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
>>> strategy = tf.distribute.MirroredStrategy()
>>> def step_fn():
... i = tf.distribute.get_replica_context().replica_id_in_sync_group
... return tf.identity(i)
@ -1230,7 +1226,7 @@ class StrategyBase(object):
>>> per_replica_result = strategy.run(step_fn)
>>> total = strategy.reduce("SUM", per_replica_result, axis=None)
>>> total
<tf.Tensor: shape=(), dtype=int32, numpy=1>
<tf.Tensor: shape=(), dtype=int32, numpy=0>
To see how this would look with multiple replicas, consider the same
example with MirroredStrategy with 2 GPUs:
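The continuation of that example falls outside this hunk; as a hedged sketch (not the docstring's text), with two replicas each `step_fn` returns its replica id (0 and 1), so the SUM reduce yields 1:
```python
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])

def step_fn():
  # Each replica returns its own id: 0 on the first replica, 1 on the second.
  i = tf.distribute.get_replica_context().replica_id_in_sync_group
  return tf.identity(i)

per_replica_result = strategy.run(step_fn)
total = strategy.reduce("SUM", per_replica_result, axis=None)
# total -> <tf.Tensor: shape=(), dtype=int32, numpy=1>
```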
@ -1753,7 +1749,7 @@ class Strategy(StrategyBase):
1. Return constant value per replica:
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
>>> strategy = tf.distribute.MirroredStrategy()
>>> def value_fn(ctx):
... return tf.constant(1.)
>>> distributed_values = (
@ -1761,12 +1757,11 @@ class Strategy(StrategyBase):
... value_fn))
>>> local_result = strategy.experimental_local_results(distributed_values)
>>> local_result
(<tf.Tensor: shape=(), dtype=float32, numpy=1.0>,
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>)
(<tf.Tensor: shape=(), dtype=float32, numpy=1.0>,)
2. Distribute values in array based on replica_id:
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
>>> strategy = tf.distribute.MirroredStrategy()
>>> array_value = np.array([3., 2., 1.])
>>> def value_fn(ctx):
... return array_value[ctx.replica_id_in_sync_group]
@ -1775,11 +1770,11 @@ class Strategy(StrategyBase):
... value_fn))
>>> local_result = strategy.experimental_local_results(distributed_values)
>>> local_result
(3.0, 2.0)
(3.0,)
3. Specify values using num_replicas_in_sync:
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
>>> strategy = tf.distribute.MirroredStrategy()
>>> def value_fn(ctx):
... return ctx.num_replicas_in_sync
>>> distributed_values = (
@ -1787,7 +1782,7 @@ class Strategy(StrategyBase):
... value_fn))
>>> local_result = strategy.experimental_local_results(distributed_values)
>>> local_result
(2, 2)
(1,)
4. Place values on devices and distribute:


@ -165,7 +165,7 @@ class DistributedIteratorInterface(collections.Iterator,
Example use:
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
>>> strategy = tf.distribute.MirroredStrategy()
>>> dataset = tf.data.Dataset.range(100).batch(2)
>>> dist_dataset = strategy.experimental_distribute_dataset(dataset)
>>> dist_dataset_iterator = iter(dist_dataset)
@ -176,8 +176,18 @@ class DistributedIteratorInterface(collections.Iterator,
>>> for _ in range(step_num):
... strategy.run(one_step, args=(dist_dataset_iterator.get_next(),))
>>> strategy.experimental_local_results(dist_dataset_iterator.get_next())
(<tf.Tensor: shape=(2,), dtype=int64, numpy=array([10, 11])>,)
The above example corresponds to the case where you have only one device. If
you have two devices, for example:
```python
strategy = tf.distribute.MirroredStrategy(['/gpu:0', '/gpu:1'])
```
Then the final line will print out:
```python
(<tf.Tensor: shape=(1,), dtype=int64, numpy=array([10])>,
<tf.Tensor: shape=(1,), dtype=int64, numpy=array([11])>)
```
Returns:
A single `tf.Tensor` or a `tf.distribute.DistributedValues` which contains
@ -197,14 +207,25 @@ class DistributedIteratorInterface(collections.Iterator,
Example usage:
>>> global_batch_size = 16
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
>>> strategy = tf.distribute.MirroredStrategy()
>>> dataset = tf.data.Dataset.from_tensors(([1.],[2])).repeat(100).batch(global_batch_size)
>>> distributed_iterator = iter(strategy.experimental_distribute_dataset(dataset))
>>> distributed_iterator.element_spec
(TensorSpec(shape=(None, 1), dtype=tf.float32, name=None),
TensorSpec(shape=(None, 1), dtype=tf.int32, name=None))
The above example corresponds to the case where you have only one device. If
you have two devices, for example:
```python
strategy = tf.distribute.MirroredStrategy(['/gpu:0', '/gpu:1'])
```
Then the final line will print out:
```python
(PerReplicaSpec(TensorSpec(shape=(None, 1), dtype=tf.float32, name=None),
TensorSpec(shape=(None, 1), dtype=tf.float32, name=None)),
PerReplicaSpec(TensorSpec(shape=(None, 1), dtype=tf.int32, name=None),
TensorSpec(shape=(None, 1), dtype=tf.int32, name=None)))
```
Returns:
A nested structure of `tf.TypeSpec` objects matching the structure of an
@ -223,7 +244,7 @@ class DistributedIteratorInterface(collections.Iterator,
Example usage:
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
>>> strategy = tf.distribute.MirroredStrategy()
>>> global_batch_size = 2
>>> steps_per_loop = 2
>>> dataset = tf.data.Dataset.range(10).batch(global_batch_size)
@ -291,8 +312,8 @@ class DistributedDatasetInterface(collections.Iterable,
* use a pythonic for-loop construct:
>>> global_batch_size = 4
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
>>> global_batch_size = 2
>>> strategy = tf.distribute.MirroredStrategy()
>>> dataset = tf.data.Dataset.from_tensors(([1.],[1.])).repeat(4).batch(global_batch_size)
>>> dist_dataset = strategy.experimental_distribute_dataset(dataset)
>>> @tf.function
@ -303,14 +324,12 @@ class DistributedDatasetInterface(collections.Iterable,
... # train_step trains the model using the dataset elements
... loss = strategy.run(train_step, args=(x,))
... print("Loss is", loss)
Loss is PerReplica:{
0: tf.Tensor(
[[0.7]
[0.7]], shape=(2, 1), dtype=float32),
1: tf.Tensor(
Loss is tf.Tensor(
[[0.7]
[0.7]], shape=(2, 1), dtype=float32)
Loss is tf.Tensor(
[[0.7]
[0.7]], shape=(2, 1), dtype=float32)
}
Placing the loop inside a `tf.function` will give a performance boost.
However `break` and `return` are currently not supported if the loop is
@ -323,7 +342,7 @@ class DistributedDatasetInterface(collections.Iterable,
`tf.distribute.DistributedIterator`
>>> global_batch_size = 4
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
>>> strategy = tf.distribute.MirroredStrategy()
>>> train_dataset = tf.data.Dataset.from_tensors(([1.],[1.])).repeat(50).batch(global_batch_size)
>>> train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
>>> @tf.function
@ -343,10 +362,10 @@ class DistributedDatasetInterface(collections.Iterable,
... total_loss += distributed_train_step(next(dist_dataset_iterator))
... num_batches += 1
... average_train_loss = total_loss / num_batches
... template = ("Epoch {}, Loss: {:.4f}")
... template = ("Epoch {}, Loss: {}")
... print (template.format(epoch+1, average_train_loss))
Epoch 1, Loss: 0.2000
Epoch 2, Loss: 0.2000
Epoch 1, Loss: 0.10000000894069672
Epoch 2, Loss: 0.10000000894069672
To achieve a performance improvement, you can also wrap the `strategy.run`
@ -370,10 +389,10 @@ class DistributedDatasetInterface(collections.Iterable,
For example:
>>> global_batch_size = 4
>>> global_batch_size = 2
>>> epochs = 1
>>> steps_per_epoch = 1
>>> mirrored_strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
>>> mirrored_strategy = tf.distribute.MirroredStrategy()
>>> dataset = tf.data.Dataset.from_tensors(([2.])).repeat(100).batch(global_batch_size)
>>> dist_dataset = mirrored_strategy.experimental_distribute_dataset(dataset)
>>> @tf.function(input_signature=[dist_dataset.element_spec])
@ -386,14 +405,9 @@ class DistributedDatasetInterface(collections.Iterable,
... for _ in range(steps_per_epoch):
... output = train_step(next(iterator))
... print(output)
PerReplica:{
0: tf.Tensor(
[[4.]
[4.]], shape=(2, 1), dtype=float32),
1: tf.Tensor(
tf.Tensor(
[[4.]
[4.]], shape=(2, 1), dtype=float32)
}
Visit the [tutorial](https://www.tensorflow.org/tutorials/distribute/input)
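Because the diff splits this doctest across hunks, here is a hedged, self-contained sketch of the same pattern (wrapping the step in a `tf.function` with `input_signature=[dist_dataset.element_spec]`). The body of `train_step` is an assumption: it simply doubles its input, which is consistent with the printed `[[4.]]` values above.
```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # or MirroredStrategy(["GPU:0", "GPU:1"])
global_batch_size = 2
dataset = tf.data.Dataset.from_tensors([2.]).repeat(100).batch(global_batch_size)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

@tf.function(input_signature=[dist_dataset.element_spec])
def train_step(per_replica_inputs):
  # Assumed step: each replica doubles its local slice of the global batch.
  return strategy.run(lambda x: x * 2.0, args=(per_replica_inputs,))

iterator = iter(dist_dataset)
print(train_step(next(iterator)))
# One replica:  tf.Tensor([[4.] [4.]], shape=(2, 1), dtype=float32)
# Two replicas: a PerReplica of two (1, 1) tensors, each [[4.]]
```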
@ -408,14 +422,25 @@ class DistributedDatasetInterface(collections.Iterable,
Example usage:
>>> global_batch_size = 4
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
>>> strategy = tf.distribute.MirroredStrategy()
>>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4]).repeat().batch(global_batch_size)
>>> distributed_iterator = iter(strategy.experimental_distribute_dataset(dataset))
>>> print(next(distributed_iterator))
tf.Tensor([1 2 3 4], shape=(4,), dtype=int32)
The above example corresponds to the case where you have only one device. If
you have two devices, for example:
```python
strategy = tf.distribute.MirroredStrategy(['/gpu:0', '/gpu:1'])
```
Then the final line will print out:
```python
PerReplica:{
0: tf.Tensor([1 2], shape=(2,), dtype=int32),
1: tf.Tensor([3 4], shape=(2,), dtype=int32)
}
```
Returns:
An `tf.distribute.DistributedIterator` instance for the given
@ -431,14 +456,25 @@ class DistributedDatasetInterface(collections.Iterable,
Example usage:
>>> global_batch_size = 16
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
>>> strategy = tf.distribute.MirroredStrategy()
>>> dataset = tf.data.Dataset.from_tensors(([1.],[2])).repeat(100).batch(global_batch_size)
>>> dist_dataset = strategy.experimental_distribute_dataset(dataset)
>>> dist_dataset.element_spec
(TensorSpec(shape=(None, 1), dtype=tf.float32, name=None),
TensorSpec(shape=(None, 1), dtype=tf.int32, name=None))
The above example corresponds to the case where you have only one device. If
you have two devices, for example:
```python
strategy = tf.distribute.MirroredStrategy(['/gpu:0', '/gpu:1'])
```
Then the final line will print out:
```python
(PerReplicaSpec(TensorSpec(shape=(None, 1), dtype=tf.float32, name=None),
TensorSpec(shape=(None, 1), dtype=tf.float32, name=None)),
PerReplicaSpec(TensorSpec(shape=(None, 1), dtype=tf.int32, name=None),
TensorSpec(shape=(None, 1), dtype=tf.int32, name=None)))
```
Returns:
A nested structure of `tf.TypeSpec` objects matching the structure of an


@ -199,14 +199,13 @@ class MirroredStrategy(distribute_lib.Strategy):
will use the available CPUs. Note that TensorFlow treats all CPUs on a
machine as a single device, and uses threads internally for parallelism.
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
>>> strategy = tf.distribute.MirroredStrategy()
>>> with strategy.scope():
... x = tf.Variable(1.)
>>> x
MirroredVariable:{
0: <tf.Variable ... shape=() dtype=float32, numpy=1.0>,
1: <tf.Variable ... shape=() dtype=float32, numpy=1.0>
}
0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>
}
While using distribution strategies, all the variable creation should be done
within the strategy's scope. This will replicate the variables across all the
@ -220,15 +219,13 @@ class MirroredStrategy(distribute_lib.Strategy):
... def create_variable():
... if not x:
... x.append(tf.Variable(1.))
... return x[0]
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
>>> strategy = tf.distribute.MirroredStrategy()
>>> with strategy.scope():
... _ = create_variable()
... print(x[0])
... create_variable()
... print (x[0])
MirroredVariable:{
0: <tf.Variable ... shape=() dtype=float32, numpy=1.0>,
1: <tf.Variable ... shape=() dtype=float32, numpy=1.0>
}
0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>
}
`experimental_distribute_dataset` can be used to distribute the dataset across
the replicas when writing your own training loop. If you are using `.fit` and


@ -93,14 +93,14 @@ class DistributedValues(object):
1. Created from a `tf.distribute.DistributedDataset`:
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
>>> strategy = tf.distribute.MirroredStrategy()
>>> dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
>>> dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
>>> distributed_values = next(dataset_iterator)
2. Returned by `run`:
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
>>> strategy = tf.distribute.MirroredStrategy()
>>> @tf.function
... def run():
... ctx = tf.distribute.get_replica_context()
@ -109,7 +109,7 @@ class DistributedValues(object):
3. As input into `run`:
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
>>> strategy = tf.distribute.MirroredStrategy()
>>> dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
>>> dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
>>> distributed_values = next(dataset_iterator)
@ -120,7 +120,7 @@ class DistributedValues(object):
4. Reduce value:
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
>>> strategy = tf.distribute.MirroredStrategy()
>>> dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
>>> dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
>>> distributed_values = next(dataset_iterator)
@ -128,16 +128,16 @@ class DistributedValues(object):
... distributed_values,
... axis = 0)
5. Inspect local replica values:
5. Inspect per replica values:
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
>>> strategy = tf.distribute.MirroredStrategy()
>>> dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
>>> dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
>>> per_replica_values = strategy.experimental_local_results(
... distributed_values)
>>> per_replica_values
(<tf.Tensor: shape=(1,), dtype=float32, numpy=array([5.], dtype=float32)>,
<tf.Tensor: shape=(1,), dtype=float32, numpy=array([6.], dtype=float32)>)
(<tf.Tensor: shape=(2,), dtype=float32,
numpy=array([5., 6.], dtype=float32)>,)
"""


@ -11,18 +11,7 @@ package(
exports_files(["LICENSE"])
tpu_module = [
"tpu.",
"distribute.tpu_strategy",
"distribute.cluster_resolver.tpu",
"distribute.cluster_resolver.tpu_oss",
]
# tf.distribute docstrings often use GPUs, so they're only covered in
# tf_doctest_gpu.
distribute_module = [
"distribute.",
]
tpu_module = "tpu.,distribute.tpu_strategy,distribute.cluster_resolver.tpu,distribute.cluster_resolver.tpu_oss"
py_library(
name = "tf_doctest_lib",
@ -36,7 +25,7 @@ py_library(
py_test(
name = "tf_doctest",
srcs = ["tf_doctest.py"],
args = ["--module_prefix_skip=" + ",".join(tpu_module + distribute_module)],
args = ["--module_prefix_skip=" + tpu_module],
python_version = "PY3",
tags = [
"no_oss_py2",
@ -57,7 +46,7 @@ py_test(
tpu_py_test(
name = "tf_doctest_tpu",
srcs = ["tf_doctest.py"],
args = ["--module=" + ",".join(tpu_module)],
args = ["--module=" + tpu_module],
disable_experimental = True,
disable_v3 = True,
main = "tf_doctest.py",
@ -81,13 +70,11 @@ py_test(
srcs = ["tf_doctest.py"],
args = [
"--module=distribute.",
"--module_prefix_skip=" + ",".join(tpu_module),
"--required_gpus=2",
"--module_prefix_skip=" + tpu_module,
],
main = "tf_doctest.py",
python_version = "PY3",
tags = [
"gpu",
"no_oss_py2",
"no_pip",
"no_rocm",

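For reference, a hedged illustration (derived from the lists above, not part of the BUILD file) of what the `--module_prefix_skip` argument expands to under the list form:
```python
tpu_module = [
    "tpu.",
    "distribute.tpu_strategy",
    "distribute.cluster_resolver.tpu",
    "distribute.cluster_resolver.tpu_oss",
]
distribute_module = ["distribute."]
print("--module_prefix_skip=" + ",".join(tpu_module + distribute_module))
# --module_prefix_skip=tpu.,distribute.tpu_strategy,distribute.cluster_resolver.tpu,distribute.cluster_resolver.tpu_oss,distribute.
```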

@ -46,8 +46,6 @@ flags.DEFINE_list('module_prefix_skip', [],
flags.DEFINE_boolean('list', None,
'List all the modules in the core package imported.')
flags.DEFINE_string('file', None, 'A specific file to run doctest on.')
flags.DEFINE_integer('required_gpus', 0,
'The number of GPUs required for the tests.')
flags.mark_flags_as_mutual_exclusive(['module', 'file'])
flags.mark_flags_as_mutual_exclusive(['list', 'file'])
@ -130,38 +128,6 @@ def get_module_and_inject_docstring(file_path):
return [file_module]
def setup_gpu(required_gpus):
"""Sets up the GPU devices.
If there are more available GPUs than needed, it hides the additional ones. If
there are fewer, it creates logical devices. This makes sure the tests see a
fixed number of GPUs regardless of the environment.
Args:
required_gpus: an integer. The number of GPUs required.
Raises:
ValueError: if required_gpus is larger than zero but no GPU is available.
"""
if required_gpus == 0:
return
available_gpus = tf.config.experimental.list_physical_devices('GPU')
if not available_gpus:
raise ValueError('requires at least one physical GPU')
if len(available_gpus) >= required_gpus:
tf.config.set_visible_devices(available_gpus[:required_gpus])
else:
# Create logical GPUs out of one physical GPU for simplicity. Note that the
# other physical GPUs are still available and correspond to one logical GPU
# each.
num_logical_gpus = required_gpus - len(available_gpus) + 1
logical_gpus = [
tf.config.LogicalDeviceConfiguration(memory_limit=256)
for _ in range(num_logical_gpus)
]
tf.config.set_logical_device_configuration(available_gpus[0], logical_gpus)
class TfTestCase(tf.test.TestCase):
def set_up(self, test):
@ -212,14 +178,6 @@ def load_tests(unused_loader, tests, unused_ignore):
))
return tests
# We can only create logical devices before initializing TensorFlow. This is
# called by unittest framework before running any test.
# https://docs.python.org/3/library/unittest.html#setupmodule-and-teardownmodule
def setUpModule():
setup_gpu(FLAGS.required_gpus)
if __name__ == '__main__':
recursive_import(tf_root)
absltest.main()
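
For illustration, a hedged usage sketch (not part of the file) of what the setup_gpu/setUpModule pair above achieves: after requesting two GPUs, the doctests see exactly two logical GPU devices, whether the machine has one physical GPU or several.
```python
# Must run before TensorFlow initializes its GPU devices; setUpModule
# takes care of that ordering in the real test.
setup_gpu(required_gpus=2)
print(len(tf.config.list_logical_devices('GPU')))  # 2, as long as at least one physical GPU exists
```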