Graduate TPUStrategy from experimental.

RELNOTES=Make TPUStrategy symbol non-experimental.
PiperOrigin-RevId: 317482072
Change-Id: I8bf596729699cb02fa275dfb63855c2dc68c1d42
Ruoxin Sang 2020-06-20 13:06:21 -07:00 committed by TensorFlower Gardener
parent e647a3b425
commit 39a2286a0e
11 changed files with 323 additions and 37 deletions
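
For readers skimming the diff, the user-visible change is just the symbol path: `tf.distribute.experimental.TPUStrategy` graduates to `tf.distribute.TPUStrategy`. A minimal migration sketch, assuming a reachable TPU worker:

```python
import tensorflow as tf

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)

# Before this change (still available, but now deprecated and logs a warning):
# strategy = tf.distribute.experimental.TPUStrategy(resolver)

# After this change:
strategy = tf.distribute.TPUStrategy(resolver)
```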


@@ -422,7 +422,7 @@ def enable_check_numerics(stack_height_limit=30,
tf.debugging.enable_check_numerics()
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
strategy = tf.distribute.experimental.TPUStrategy(resolver)
strategy = tf.distribute.TPUStrategy(resolver)
with strategy.scope():
# ...
```


@@ -737,7 +737,7 @@ def enable_dump_debug_info(dump_root,
logdir, tensor_debug_mode="FULL_HEALTH")
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
strategy = tf.distribute.experimental.TPUStrategy(resolver)
strategy = tf.distribute.TPUStrategy(resolver)
with strategy.scope():
# ...
```


@@ -49,7 +49,7 @@ model.evaluate(dataset)
```python
# Create the strategy instance.
tpu_strategy = tf.distribute.experimental.TPUStrategy(resolver)
tpu_strategy = tf.distribute.TPUStrategy(resolver)
# Create the keras model under strategy.scope()


@@ -618,7 +618,7 @@ class InputOptions(
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.experimental.TPUStrategy(resolver)
strategy = tf.distribute.TPUStrategy(resolver)
dataset = tf.data.Dataset.range(16)
distributed_dataset_on_host = (
@@ -1462,17 +1462,17 @@ class Strategy(StrategyBase):
topology = tf.tpu.experimental.initialize_tpu_system(resolver)
device_assignment = tf.tpu.experimental.DeviceAssignment.build(
topology,
computation_shape=[1, 1, 2],
computation_shape=[1, 1, 1, 2],
num_replicas=4)
strategy = tf.distribute.experimental.TPUStrategy(
resolver, device_assignment=device_assignment)
strategy = tf.distribute.TPUStrategy(
resolver, experimental_device_assignment=device_assignment)
iterator = iter(inputs)
@tf.function()
def step_fn(inputs):
output = tf.add(inputs, inputs)
// Add operation will be executed on logical device 0.
# Add operation will be executed on logical device 0.
output = strategy.experimental_assign_to_logical_device(output, 0)
return output
@@ -1517,10 +1517,10 @@ class Strategy(StrategyBase):
topology = tf.tpu.experimental.initialize_tpu_system(resolver)
device_assignment = tf.tpu.experimental.DeviceAssignment.build(
topology,
computation_shape=[2, 2, 2],
computation_shape=[1, 2, 2, 2],
num_replicas=1)
strategy = tf.distribute.experimental.TPUStrategy(
resolver, device_assignment=device_assignment)
strategy = tf.distribute.TPUStrategy(
resolver, experimental_device_assignment=device_assignment)
iterator = iter(inputs)
@@ -1529,8 +1529,8 @@ class Strategy(StrategyBase):
inputs = strategy.experimental_split_to_logical_devices(
inputs, [1, 2, 4, 1])
// model() function will be executed on 8 logical devices with `inputs`
// split 2 * 4 ways.
# model() function will be executed on 8 logical devices with `inputs`
# split 2 * 4 ways.
output = model(inputs)
return output
@@ -1571,10 +1571,10 @@ class Strategy(StrategyBase):
topology = tf.tpu.experimental.initialize_tpu_system(resolver)
device_assignment = tf.tpu.experimental.DeviceAssignment.build(
topology,
computation_shape=[1, 1, 2],
computation_shape=[1, 1, 1, 2],
num_replicas=4)
strategy = tf.distribute.experimental.TPUStrategy(
resolver, device_assignment=device_assignment)
strategy = tf.distribute.TPUStrategy(
resolver, experimental_device_assignment=device_assignment)
iterator = iter(inputs)
@@ -1584,12 +1584,12 @@ class Strategy(StrategyBase):
images = strategy.experimental_split_to_logical_devices(
inputs, [1, 2, 4, 1])
// model() function will be executed on 8 logical devices with `inputs`
// split 2 * 4 ways.
# model() function will be executed on 8 logical devices with `inputs`
# split 2 * 4 ways.
output = model(inputs)
// For loss calculation, all logical devices share the same logits
// and labels.
# For loss calculation, all logical devices share the same logits
# and labels.
labels = strategy.experimental_replicate_to_logical_devices(labels)
output = strategy.experimental_replicate_to_logical_devices(output)
loss = loss_fn(labels, output)
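
The `computation_shape` edits in these hunks track the TPU topology's device coordinates having four axes, so the per-replica shape now carries one entry per axis, with the last axis counting cores per chip. A quick sanity check of the core accounting, offered as an illustrative sketch rather than part of the change:

```python
import numpy as np

# computation_shape has one entry per topology axis; its product is the
# number of TPU cores occupied by a single model replica.
computation_shape = [1, 1, 1, 2]
num_replicas = 4

cores_per_replica = int(np.prod(computation_shape))    # 2 cores per replica
total_cores_needed = cores_per_replica * num_replicas  # 8 cores, e.g. one v3-8

print(cores_per_replica, total_cores_needed)
```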


@@ -190,9 +190,8 @@ class MirroredStrategy(distribute_lib.Strategy):
This strategy is typically used for training on one
machine with multiple GPUs. For TPUs, use
`tf.distribute.experimental.TPUStrategy`. To use `MirroredStrategy` with
multiple workers, please refer to
`tf.distribute.experimental.MultiWorkerMirroredStrategy`.
`tf.distribute.TPUStrategy`. To use `MirroredStrategy` with multiple workers,
please refer to `tf.distribute.experimental.MultiWorkerMirroredStrategy`.
For example, a variable created under a `MirroredStrategy` is a
`MirroredVariable`. If no devices are specified in the constructor argument of
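
For orientation, the three strategies this docstring points readers at can be sketched side by side (illustrative only; the TPU lines are commented out because they need a reachable TPU worker):

```python
import tensorflow as tf

# Single host, one or more GPUs (falls back to CPU if no GPU is visible).
mirrored = tf.distribute.MirroredStrategy()

# Multiple workers, synchronous all-reduce across them.
multi_worker = tf.distribute.experimental.MultiWorkerMirroredStrategy()

# TPUs, using the newly graduated symbol:
# resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
# tf.config.experimental_connect_to_cluster(resolver)
# tf.tpu.experimental.initialize_tpu_system(resolver)
# tpu_strategy = tf.distribute.TPUStrategy(resolver)
```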


@@ -23,6 +23,7 @@ import collections
import contextlib
import copy
import weakref
from absl import logging
import numpy as np
@@ -57,6 +58,7 @@ from tensorflow.python.tpu import tpu
from tensorflow.python.tpu import tpu_strategy_util
from tensorflow.python.tpu import training_loop
from tensorflow.python.tpu.ops import tpu_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
@@ -97,9 +99,188 @@ def validate_run_function(fn):
"eager behavior is enabled.")
@tf_export("distribute.TPUStrategy", v1=[])
class TPUStrategyV2(distribute_lib.Strategy):
"""Synchronous training on TPUs and TPU Pods.
To construct a TPUStrategy object, you need to run the
initialization code as below:
>>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
>>> tf.config.experimental_connect_to_cluster(resolver)
>>> tf.tpu.experimental.initialize_tpu_system(resolver)
>>> strategy = tf.distribute.TPUStrategy(resolver)
While using distribution strategies, the variables created within the
strategy's scope will be replicated across all the replicas and can be kept in
sync using all-reduce algorithms.
To run TF2 programs on TPUs, you can either use `.compile` and
`.fit` APIs in `tf.keras` with TPUStrategy, or write your own customized
training loop by calling `strategy.run` directly. Note that
TPUStrategy doesn't support pure eager execution, so please make sure the
function passed into `strategy.run` is a `tf.function` or
`strategy.run` is called inside a `tf.function` if eager
behavior is enabled. See more details in https://www.tensorflow.org/guide/tpu.
`experimental_distribute_datasets_from_function` and
`experimental_distribute_dataset` APIs can be used to distribute the dataset
across the TPU workers when writing your own training loop. If you are using
`fit` and `compile` methods available in `tf.keras.Model`, then Keras will
handle the distribution for you.
An example of writing customized training loop on TPUs:
>>> with strategy.scope():
... model = tf.keras.Sequential([
... tf.keras.layers.Dense(2, input_shape=(5,)),
... ])
... optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
>>> def dataset_fn(ctx):
... x = np.random.random((2, 5)).astype(np.float32)
... y = np.random.randint(2, size=(2, 1))
... dataset = tf.data.Dataset.from_tensor_slices((x, y))
... return dataset.repeat().batch(1, drop_remainder=True)
>>> dist_dataset = strategy.experimental_distribute_datasets_from_function(
... dataset_fn)
>>> iterator = iter(dist_dataset)
>>> @tf.function()
... def train_step(iterator):
...
... def step_fn(inputs):
... features, labels = inputs
... with tf.GradientTape() as tape:
... logits = model(features, training=True)
... loss = tf.keras.losses.sparse_categorical_crossentropy(
... labels, logits)
...
... grads = tape.gradient(loss, model.trainable_variables)
... optimizer.apply_gradients(zip(grads, model.trainable_variables))
...
... strategy.run(step_fn, args=(next(iterator),))
>>> train_step(iterator)
For advanced use cases like model parallelism, you can set the
`experimental_device_assignment` argument when creating TPUStrategy to specify
the number of replicas and the number of logical devices. Below is an example
that initializes the TPU system with 2 logical devices and 1 replica.
>>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
>>> tf.config.experimental_connect_to_cluster(resolver)
>>> topology = tf.tpu.experimental.initialize_tpu_system(resolver)
>>> device_assignment = tf.tpu.experimental.DeviceAssignment.build(
... topology,
... computation_shape=[1, 1, 1, 2],
... num_replicas=1)
>>> strategy = tf.distribute.TPUStrategy(
... resolver, experimental_device_assignment=device_assignment)
Then you can run a `tf.add` operation only on logical device 0.
>>> @tf.function()
... def step_fn(inputs):
... features, _ = inputs
... output = tf.add(features, features)
...
... # Add operation will be executed on logical device 0.
... output = strategy.experimental_assign_to_logical_device(output, 0)
... return output
>>> dist_dataset = strategy.experimental_distribute_datasets_from_function(
... dataset_fn)
>>> iterator = iter(dist_dataset)
>>> strategy.run(step_fn, args=(next(iterator),))
"""
def __init__(self,
tpu_cluster_resolver=None,
experimental_device_assignment=None):
"""Synchronous training in TPU donuts or Pods.
Args:
tpu_cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
which provides information about the TPU cluster. If None, it will
assume running on a local TPU worker.
experimental_device_assignment: Optional
`tf.tpu.experimental.DeviceAssignment` to specify the placement of
replicas on the TPU cluster.
"""
super(TPUStrategyV2, self).__init__(TPUExtended(
self, tpu_cluster_resolver,
device_assignment=experimental_device_assignment))
distribute_lib.distribution_strategy_gauge.get_cell("V2").set("TPUStrategy")
distribute_lib.distribution_strategy_replica_gauge.get_cell(
"num_workers").set(self.extended.num_hosts)
distribute_lib.distribution_strategy_replica_gauge.get_cell(
"num_replicas_per_worker").set(self.extended.num_replicas_per_host)
# Packed variable is used to reduce the overhead of function execution.
# For a DistributedVariable, only one variable handle is captured into a
# function graph. It's only supported in eager mode.
self._enable_packed_variable_in_eager_mode = False
def run(self, fn, args=(), kwargs=None, options=None):
"""Run the computation defined by `fn` on each TPU replica.
Executes ops specified by `fn` on each replica. If `args` or `kwargs` have
`tf.distribute.DistributedValues`, such as those produced by a
`tf.distribute.DistributedDataset` from
`tf.distribute.Strategy.experimental_distribute_dataset` or
`tf.distribute.Strategy.experimental_distribute_datasets_from_function`,
then when `fn` is executed on a particular replica, it will be executed with
the component of `tf.distribute.DistributedValues` that corresponds to that
replica.
`fn` may call `tf.distribute.get_replica_context()` to access members such
as `all_reduce`.
All arguments in `args` or `kwargs` should either be a nest of tensors or
`tf.distribute.DistributedValues` containing tensors or composite tensors.
Example usage:
>>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
>>> tf.config.experimental_connect_to_cluster(resolver)
>>> tf.tpu.experimental.initialize_tpu_system(resolver)
>>> strategy = tf.distribute.TPUStrategy(resolver)
>>> @tf.function
... def run():
... def value_fn(value_context):
... return value_context.num_replicas_in_sync
... distributed_values = (
... strategy.experimental_distribute_values_from_function(value_fn))
... def replica_fn(input):
... return input * 2
... return strategy.run(replica_fn, args=(distributed_values,))
>>> result = run()
Args:
fn: The function to run. The output must be a `tf.nest` of `Tensor`s.
args: (Optional) Positional arguments to `fn`.
kwargs: (Optional) Keyword arguments to `fn`.
options: (Optional) An instance of `tf.distribute.RunOptions` specifying
the options to run `fn`.
Returns:
Merged return value of `fn` across replicas. The structure of the return
value is the same as the return value from `fn`. Each element in the
structure can either be `tf.distribute.DistributedValues`, `Tensor`
objects, or `Tensor`s (for example, if running on a single replica).
"""
validate_run_function(fn)
# Note: the target function is converted to graph even when in Eager mode,
# so autograph is on by default here.
fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx())
options = options or distribute_lib.RunOptions()
return self.extended.tpu_run(fn, args, kwargs, options)
@tf_export("distribute.experimental.TPUStrategy", v1=[])
@deprecation.deprecated_endpoints("distribute.experimental.TPUStrategy")
class TPUStrategy(distribute_lib.Strategy):
"""TPU distribution strategy implementation.
"""Synchronous training on TPUs and TPU Pods.
To construct a TPUStrategy object, you need to run the
initialization code as below:
@@ -109,9 +290,9 @@ class TPUStrategy(distribute_lib.Strategy):
>>> tf.tpu.experimental.initialize_tpu_system(resolver)
>>> strategy = tf.distribute.experimental.TPUStrategy(resolver)
While using distribution strategies, the variables created within strategy's
scope will be replicated across all the replicas and can be kept in sync
using all-reduce algorithms.
While using distribution strategies, the variables created within the
strategy's scope will be replicated across all the replicas and can be kept in
sync using all-reduce algorithms.
To run TF2 programs on TPUs, you can either use `.compile` and
`.fit` APIs in `tf.keras` with TPUStrategy, or write your own customized
@@ -131,9 +312,12 @@ class TPUStrategy(distribute_lib.Strategy):
tpu_cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
which provides information about the TPU cluster.
device_assignment: Optional `tf.tpu.experimental.DeviceAssignment` to
specify the placement of replicas on the TPU cluster. Currently only
supports the usecase of using a single core within a TPU cluster.
specify the placement of replicas on the TPU cluster.
"""
logging.warning(
"`tf.distribute.experimental.TPUStrategy` is deprecated, please use "
" the non experimental symbol `tf.distribute.TPUStrategy` instead.")
super(TPUStrategy, self).__init__(TPUExtended(
self, tpu_cluster_resolver, device_assignment=device_assignment))
distribute_lib.distribution_strategy_gauge.get_cell("V2").set("TPUStrategy")
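
A detail from the hunks above that is easy to miss: the graduated constructor spells the placement argument `experimental_device_assignment`, whereas the deprecated class takes `device_assignment`. A hedged before/after sketch, assuming a reachable TPU worker:

```python
import tensorflow as tf

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(resolver)
topology = tf.tpu.experimental.initialize_tpu_system(resolver)
device_assignment = tf.tpu.experimental.DeviceAssignment.build(
    topology, computation_shape=[1, 1, 1, 2], num_replicas=1)

# Deprecated class and keyword (now logs a warning at construction time):
# strategy = tf.distribute.experimental.TPUStrategy(
#     resolver, device_assignment=device_assignment)

# Graduated class and keyword:
strategy = tf.distribute.TPUStrategy(
    resolver, experimental_device_assignment=device_assignment)
```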


@@ -132,7 +132,7 @@ class TPUEmbedding(tracking.AutoTrackable):
First, let's look at the `TPUStrategy` mode. Initial setup looks like:
```python
strategy = tf.distribute.experimental.TPUStrategy(...)
strategy = tf.distribute.TPUStrategy(...)
with strategy.scope():
embedding = tf.tpu.experimental.embedding.TPUEmbedding(
feature_config=feature_config,
@@ -234,7 +234,7 @@ class TPUEmbedding(tracking.AutoTrackable):
"""Creates the TPUEmbedding mid level API object.
```python
strategy = tf.distribute.experimental.TPUStrategy(...)
strategy = tf.distribute.TPUStrategy(...)
with strategy.scope():
embedding = tf.tpu.experimental.embedding.TPUEmbedding(
feature_config=tf.tpu.experimental.embedding.FeatureConfig(
@@ -512,7 +512,7 @@ class TPUEmbedding(tracking.AutoTrackable):
ensure you understand the effect of applying a zero gradient.
```python
strategy = tf.distribute.experimental.TPUStrategy(...)
strategy = tf.distribute.TPUStrategy(...)
with strategy.scope():
embedding = tf.tpu.experimental.embedding.TPUEmbedding(...)
@@ -603,7 +603,7 @@ class TPUEmbedding(tracking.AutoTrackable):
`(batch_size, max_sequence_length, dim)` instead.
```python
strategy = tf.distribute.experimental.TPUStrategy(...)
strategy = tf.distribute.TPUStrategy(...)
with strategy.scope():
embedding = tf.tpu.experimental.embedding.TPUEmbedding(...)
@@ -1054,13 +1054,13 @@ class TPUEmbedding(tracking.AutoTrackable):
embedding tables. We expect that the batch size of each of the tensors in
features matches the per core batch size. This will automatically happen if
your input dataset is batched to the global batch size and you use
`tf.distribute.experimental.TPUStrategy`'s `experimental_distribute_dataset`
`tf.distribute.TPUStrategy`'s `experimental_distribute_dataset`
or if you use `experimental_distribute_datasets_from_function` and batch
to the per core batch size computed by the context passed to your input
function.
```python
strategy = tf.distribute.experimental.TPUStrategy(...)
strategy = tf.distribute.TPUStrategy(...)
with strategy.scope():
embedding = tf.tpu.experimental.embedding.TPUEmbedding(...)
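
The batching requirement described above is the usual reason to pair `TPUEmbedding` with `experimental_distribute_datasets_from_function`; below is a sketch of batching to the per-core size via `tf.distribute.InputContext`. The feature shapes and dataset contents are made up for illustration, and `strategy` is assumed to be the TPUStrategy built earlier:

```python
import tensorflow as tf

GLOBAL_BATCH_SIZE = 64

def dataset_fn(input_context):
  # Batch to the per-core (per-replica) size so the tensors each replica
  # feeds into the embedding match what TPUEmbedding expects.
  per_core_batch = input_context.get_per_replica_batch_size(GLOBAL_BATCH_SIZE)
  ids = tf.random.uniform([1024, 10], maxval=100, dtype=tf.int64)
  ds = tf.data.Dataset.from_tensor_slices(ids)
  ds = ds.shard(input_context.num_input_pipelines,
                input_context.input_pipeline_id)
  return ds.batch(per_core_batch, drop_remainder=True)

dist_dataset = strategy.experimental_distribute_datasets_from_function(dataset_fn)
```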


@@ -0,0 +1,99 @@
path: "tensorflow.distribute.TPUStrategy"
tf_class {
is_instance: "<class \'tensorflow.python.distribute.tpu_strategy.TPUStrategyV2\'>"
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.Strategy\'>"
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.StrategyBase\'>"
is_instance: "<type \'object\'>"
member {
name: "extended"
mtype: "<type \'property\'>"
}
member {
name: "num_replicas_in_sync"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'tpu_cluster_resolver\', \'experimental_device_assignment\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "colocate_vars_with"
argspec: "args=[\'self\', \'colocate_with_variable\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "configure"
argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "experimental_assign_to_logical_device"
argspec: "args=[\'self\', \'tensor\', \'logical_device_id\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "experimental_distribute_dataset"
argspec: "args=[\'self\', \'dataset\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "experimental_distribute_datasets_from_function"
argspec: "args=[\'self\', \'dataset_fn\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "experimental_distribute_values_from_function"
argspec: "args=[\'self\', \'value_fn\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "experimental_local_results"
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "experimental_make_numpy_dataset"
argspec: "args=[\'self\', \'numpy_input\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "experimental_replicate_to_logical_devices"
argspec: "args=[\'self\', \'tensor\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "experimental_run"
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "experimental_run_v2"
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
}
member_method {
name: "experimental_split_to_logical_devices"
argspec: "args=[\'self\', \'tensor\', \'partition_dimensions\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "group"
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "make_dataset_iterator"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "make_input_fn_iterator"
argspec: "args=[\'self\', \'input_fn\', \'replication_mode\'], varargs=None, keywords=None, defaults=[\'InputReplicationMode.PER_WORKER\'], "
}
member_method {
name: "reduce"
argspec: "args=[\'self\', \'reduce_op\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "run"
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
}
member_method {
name: "scope"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "unwrap"
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "update_config_proto"
argspec: "args=[\'self\', \'config_proto\'], varargs=None, keywords=None, defaults=None"
}
}


@@ -72,6 +72,10 @@ tf_module {
name: "StrategyExtended"
mtype: "<type \'type\'>"
}
member {
name: "TPUStrategy"
mtype: "<type \'type\'>"
}
member {
name: "cluster_resolver"
mtype: "<type \'module\'>"


@@ -883,7 +883,7 @@ class TFAPIChangeSpec(ast_edits.NoUpdateSpec):
contrib_tpu_strategy_warning = (
ast_edits.ERROR,
"(Manual edit required) tf.contrib.distribute.TPUStrategy has "
"been migrated to tf.distribute.experimental.TPUStrategy. Note the "
"been migrated to tf.distribute.TPUStrategy. Note the "
"slight changes in constructor. " + distribute_strategy_api_changes)
contrib_collective_strategy_warning = (


@@ -2155,7 +2155,7 @@ def _log_prob(self, x):
expected = "tf.contrib.distribute.TPUStrategy"
_, _, errors, new_text = self._upgrade(text)
self.assertEqual(expected, new_text)
self.assertIn("migrated to tf.distribute.experimental.TPUStrategy",
self.assertIn("migrated to tf.distribute.TPUStrategy",
errors[0])
text = "tf.contrib.distribute.foo"