Graduate TPUStrategy from experimental.
RELNOTES=Make TPUStrategy symbol non-experimental.
PiperOrigin-RevId: 317482072
Change-Id: I8bf596729699cb02fa275dfb63855c2dc68c1d42
This commit is contained in:
parent
e647a3b425
commit
39a2286a0e
@ -422,7 +422,7 @@ def enable_check_numerics(stack_height_limit=30,
|
|||||||
tf.debugging.enable_check_numerics()
|
tf.debugging.enable_check_numerics()
|
||||||
|
|
||||||
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
|
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
|
||||||
strategy = tf.distribute.experimental.TPUStrategy(resolver)
|
strategy = tf.distribute.TPUStrategy(resolver)
|
||||||
with strategy.scope():
|
with strategy.scope():
|
||||||
# ...
|
# ...
|
||||||
```
|
```
|
||||||
|
@ -737,7 +737,7 @@ def enable_dump_debug_info(dump_root,
|
|||||||
logdir, tensor_debug_mode="FULL_HEALTH")
|
logdir, tensor_debug_mode="FULL_HEALTH")
|
||||||
|
|
||||||
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
|
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
|
||||||
strategy = tf.distribute.experimental.TPUStrategy(resolver)
|
strategy = tf.distribute.TPUStrategy(resolver)
|
||||||
with strategy.scope():
|
with strategy.scope():
|
||||||
# ...
|
# ...
|
||||||
```
|
```
|
||||||
|
@ -49,7 +49,7 @@ model.evaluate(dataset)
|
|||||||
|
|
||||||
```python
|
```python
|
||||||
# Create the strategy instance.
|
# Create the strategy instance.
|
||||||
tpu_strategy = tf.distribute.experimental.TPUStrategy(resolver)
|
tpu_strategy = tf.distribute.TPUStrategy(resolver)
|
||||||
|
|
||||||
|
|
||||||
# Create the keras model under strategy.scope()
|
# Create the keras model under strategy.scope()
|
||||||
|
@ -618,7 +618,7 @@ class InputOptions(
|
|||||||
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
|
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
|
||||||
tf.config.experimental_connect_to_cluster(resolver)
|
tf.config.experimental_connect_to_cluster(resolver)
|
||||||
tf.tpu.experimental.initialize_tpu_system(resolver)
|
tf.tpu.experimental.initialize_tpu_system(resolver)
|
||||||
strategy = tf.distribute.experimental.TPUStrategy(resolver)
|
strategy = tf.distribute.TPUStrategy(resolver)
|
||||||
|
|
||||||
dataset = tf.data.Dataset.range(16)
|
dataset = tf.data.Dataset.range(16)
|
||||||
distributed_dataset_on_host = (
|
distributed_dataset_on_host = (
|
||||||
@ -1462,17 +1462,17 @@ class Strategy(StrategyBase):
|
|||||||
topology = tf.tpu.experimental.initialize_tpu_system(resolver)
|
topology = tf.tpu.experimental.initialize_tpu_system(resolver)
|
||||||
device_assignment = tf.tpu.experimental.DeviceAssignment.build(
|
device_assignment = tf.tpu.experimental.DeviceAssignment.build(
|
||||||
topology,
|
topology,
|
||||||
computation_shape=[1, 1, 2],
|
computation_shape=[1, 1, 1, 2],
|
||||||
num_replicas=4)
|
num_replicas=4)
|
||||||
strategy = tf.distribute.experimental.TPUStrategy(
|
strategy = tf.distribute.TPUStrategy(
|
||||||
resolver, device_assignment=device_assignment)
|
resolver, experimental_device_assignment=device_assignment)
|
||||||
iterator = iter(inputs)
|
iterator = iter(inputs)
|
||||||
|
|
||||||
@tf.function()
|
@tf.function()
|
||||||
def step_fn(inputs):
|
def step_fn(inputs):
|
||||||
output = tf.add(inputs, inputs)
|
output = tf.add(inputs, inputs)
|
||||||
|
|
||||||
// Add operation will be executed on logical device 0.
|
# Add operation will be executed on logical device 0.
|
||||||
output = strategy.experimental_assign_to_logical_device(output, 0)
|
output = strategy.experimental_assign_to_logical_device(output, 0)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
@ -1517,10 +1517,10 @@ class Strategy(StrategyBase):
|
|||||||
topology = tf.tpu.experimental.initialize_tpu_system(resolver)
|
topology = tf.tpu.experimental.initialize_tpu_system(resolver)
|
||||||
device_assignment = tf.tpu.experimental.DeviceAssignment.build(
|
device_assignment = tf.tpu.experimental.DeviceAssignment.build(
|
||||||
topology,
|
topology,
|
||||||
computation_shape=[2, 2, 2],
|
computation_shape=[1, 2, 2, 2],
|
||||||
num_replicas=1)
|
num_replicas=1)
|
||||||
strategy = tf.distribute.experimental.TPUStrategy(
|
strategy = tf.distribute.TPUStrategy(
|
||||||
resolver, device_assignment=device_assignment)
|
resolver, experimental_device_assignment=device_assignment)
|
||||||
|
|
||||||
iterator = iter(inputs)
|
iterator = iter(inputs)
|
||||||
|
|
||||||
@ -1529,8 +1529,8 @@ class Strategy(StrategyBase):
|
|||||||
inputs = strategy.experimental_split_to_logical_devices(
|
inputs = strategy.experimental_split_to_logical_devices(
|
||||||
inputs, [1, 2, 4, 1])
|
inputs, [1, 2, 4, 1])
|
||||||
|
|
||||||
// model() function will be executed on 8 logical devices with `inputs`
|
# model() function will be executed on 8 logical devices with `inputs`
|
||||||
// split 2 * 4 ways.
|
# split 2 * 4 ways.
|
||||||
output = model(inputs)
|
output = model(inputs)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
@ -1571,10 +1571,10 @@ class Strategy(StrategyBase):
|
|||||||
topology = tf.tpu.experimental.initialize_tpu_system(resolver)
|
topology = tf.tpu.experimental.initialize_tpu_system(resolver)
|
||||||
device_assignment = tf.tpu.experimental.DeviceAssignment.build(
|
device_assignment = tf.tpu.experimental.DeviceAssignment.build(
|
||||||
topology,
|
topology,
|
||||||
computation_shape=[1, 1, 2],
|
computation_shape=[1, 1, 1, 2],
|
||||||
num_replicas=4)
|
num_replicas=4)
|
||||||
strategy = tf.distribute.experimental.TPUStrategy(
|
strategy = tf.distribute.TPUStrategy(
|
||||||
resolver, device_assignment=device_assignment)
|
resolver, experimental_device_assignment=device_assignment)
|
||||||
|
|
||||||
iterator = iter(inputs)
|
iterator = iter(inputs)
|
||||||
|
|
||||||
@ -1584,12 +1584,12 @@ class Strategy(StrategyBase):
|
|||||||
images = strategy.experimental_split_to_logical_devices(
|
images = strategy.experimental_split_to_logical_devices(
|
||||||
inputs, [1, 2, 4, 1])
|
inputs, [1, 2, 4, 1])
|
||||||
|
|
||||||
// model() function will be executed on 8 logical devices with `inputs`
|
# model() function will be executed on 8 logical devices with `inputs`
|
||||||
// split 2 * 4 ways.
|
# split 2 * 4 ways.
|
||||||
output = model(inputs)
|
output = model(inputs)
|
||||||
|
|
||||||
// For loss calculation, all logical devices share the same logits
|
# For loss calculation, all logical devices share the same logits
|
||||||
// and labels.
|
# and labels.
|
||||||
labels = strategy.experimental_replicate_to_logical_devices(labels)
|
labels = strategy.experimental_replicate_to_logical_devices(labels)
|
||||||
output = strategy.experimental_replicate_to_logical_devices(output)
|
output = strategy.experimental_replicate_to_logical_devices(output)
|
||||||
loss = loss_fn(labels, output)
|
loss = loss_fn(labels, output)
|
||||||
|
@ -190,9 +190,8 @@ class MirroredStrategy(distribute_lib.Strategy):
|
|||||||
|
|
||||||
This strategy is typically used for training on one
|
This strategy is typically used for training on one
|
||||||
machine with multiple GPUs. For TPUs, use
|
machine with multiple GPUs. For TPUs, use
|
||||||
`tf.distribute.experimental.TPUStrategy`. To use `MirroredStrategy` with
|
`tf.distribute.TPUStrategy`. To use `MirroredStrategy` with multiple workers,
|
||||||
multiple workers, please refer to
|
please refer to `tf.distribute.experimental.MultiWorkerMirroredStrategy`.
|
||||||
`tf.distribute.experimental.MultiWorkerMirroredStrategy`.
|
|
||||||
|
|
||||||
For example, a variable created under a `MirroredStrategy` is a
|
For example, a variable created under a `MirroredStrategy` is a
|
||||||
`MirroredVariable`. If no devices are specified in the constructor argument of
|
`MirroredVariable`. If no devices are specified in the constructor argument of
|
||||||
|
@ -23,6 +23,7 @@ import collections
|
|||||||
import contextlib
|
import contextlib
|
||||||
import copy
|
import copy
|
||||||
import weakref
|
import weakref
|
||||||
|
from absl import logging
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@ -57,6 +58,7 @@ from tensorflow.python.tpu import tpu
|
|||||||
from tensorflow.python.tpu import tpu_strategy_util
|
from tensorflow.python.tpu import tpu_strategy_util
|
||||||
from tensorflow.python.tpu import training_loop
|
from tensorflow.python.tpu import training_loop
|
||||||
from tensorflow.python.tpu.ops import tpu_ops
|
from tensorflow.python.tpu.ops import tpu_ops
|
||||||
|
from tensorflow.python.util import deprecation
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
@ -97,9 +99,188 @@ def validate_run_function(fn):
|
|||||||
"eager behavior is enabled.")
|
"eager behavior is enabled.")
|
||||||
|
|
||||||
|
|
||||||
|
@tf_export("distribute.TPUStrategy", v1=[])
|
||||||
|
class TPUStrategyV2(distribute_lib.Strategy):
|
||||||
|
"""Synchronous training on TPUs and TPU Pods.
|
||||||
|
|
||||||
|
To construct a TPUStrategy object, you need to run the
|
||||||
|
initialization code as below:
|
||||||
|
|
||||||
|
>>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
|
||||||
|
>>> tf.config.experimental_connect_to_cluster(resolver)
|
||||||
|
>>> tf.tpu.experimental.initialize_tpu_system(resolver)
|
||||||
|
>>> strategy = tf.distribute.TPUStrategy(resolver)
|
||||||
|
|
||||||
|
While using distribution strategies, the variables created within the
|
||||||
|
strategy's scope will be replicated across all the replicas and can be kept in
|
||||||
|
sync using all-reduce algorithms.
|
||||||
|
|
||||||
|
To run TF2 programs on TPUs, you can either use `.compile` and
|
||||||
|
`.fit` APIs in `tf.keras` with TPUStrategy, or write your own customized
|
||||||
|
training loop by calling `strategy.run` directly. Note that
|
||||||
|
TPUStrategy doesn't support pure eager execution, so please make sure the
|
||||||
|
function passed into `strategy.run` is a `tf.function` or
|
||||||
|
`strategy.run` is called inside a `tf.function` if eager
|
||||||
|
behavior is enabled. See more details in https://www.tensorflow.org/guide/tpu.
|
||||||
|
|
||||||
|
`experimental_distribute_datasets_from_function` and
|
||||||
|
`experimental_distribute_dataset` APIs can be used to distribute the dataset
|
||||||
|
across the TPU workers when writing your own training loop. If you are using
|
||||||
|
`fit` and `compile` methods available in `tf.keras.Model`, then Keras will
|
||||||
|
handle the distribution for you.
|
||||||
|
|
||||||
|
An example of writing customized training loop on TPUs:
|
||||||
|
|
||||||
|
>>> with strategy.scope():
|
||||||
|
... model = tf.keras.Sequential([
|
||||||
|
... tf.keras.layers.Dense(2, input_shape=(5,)),
|
||||||
|
... ])
|
||||||
|
... optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
|
||||||
|
|
||||||
|
>>> def dataset_fn(ctx):
|
||||||
|
... x = np.random.random((2, 5)).astype(np.float32)
|
||||||
|
... y = np.random.randint(2, size=(2, 1))
|
||||||
|
... dataset = tf.data.Dataset.from_tensor_slices((x, y))
|
||||||
|
... return dataset.repeat().batch(1, drop_remainder=True)
|
||||||
|
>>> dist_dataset = strategy.experimental_distribute_datasets_from_function(
|
||||||
|
... dataset_fn)
|
||||||
|
>>> iterator = iter(dist_dataset)
|
||||||
|
|
||||||
|
>>> @tf.function()
|
||||||
|
... def train_step(iterator):
|
||||||
|
...
|
||||||
|
... def step_fn(inputs):
|
||||||
|
... features, labels = inputs
|
||||||
|
... with tf.GradientTape() as tape:
|
||||||
|
... logits = model(features, training=True)
|
||||||
|
... loss = tf.keras.losses.sparse_categorical_crossentropy(
|
||||||
|
... labels, logits)
|
||||||
|
...
|
||||||
|
... grads = tape.gradient(loss, model.trainable_variables)
|
||||||
|
... optimizer.apply_gradients(zip(grads, model.trainable_variables))
|
||||||
|
...
|
||||||
|
... strategy.run(step_fn, args=(next(iterator),))
|
||||||
|
|
||||||
|
>>> train_step(iterator)
|
||||||
|
|
||||||
|
For the advanced use cases like model parallelism, you can set
|
||||||
|
`experimental_device_assignment` argument when creating TPUStrategy to specify
|
||||||
|
number of replicas and number of logical devices. Below is an example to
|
||||||
|
initialize TPU system with 2 logical devices and 1 replica.
|
||||||
|
|
||||||
|
>>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
|
||||||
|
>>> tf.config.experimental_connect_to_cluster(resolver)
|
||||||
|
>>> topology = tf.tpu.experimental.initialize_tpu_system(resolver)
|
||||||
|
>>> device_assignment = tf.tpu.experimental.DeviceAssignment.build(
|
||||||
|
... topology,
|
||||||
|
... computation_shape=[1, 1, 1, 2],
|
||||||
|
... num_replicas=1)
|
||||||
|
>>> strategy = tf.distribute.TPUStrategy(
|
||||||
|
... resolver, experimental_device_assignment=device_assignment)
|
||||||
|
|
||||||
|
Then you can run a `tf.add` operation only on logical device 0.
|
||||||
|
|
||||||
|
>>> @tf.function()
|
||||||
|
... def step_fn(inputs):
|
||||||
|
... features, _ = inputs
|
||||||
|
... output = tf.add(features, features)
|
||||||
|
...
|
||||||
|
... # Add operation will be executed on logical device 0.
|
||||||
|
... output = strategy.experimental_assign_to_logical_device(output, 0)
|
||||||
|
... return output
|
||||||
|
>>> dist_dataset = strategy.experimental_distribute_datasets_from_function(
|
||||||
|
... dataset_fn)
|
||||||
|
>>> iterator = iter(dist_dataset)
|
||||||
|
>>> strategy.run(step_fn, args=(next(iterator),))
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
tpu_cluster_resolver=None,
|
||||||
|
experimental_device_assignment=None):
|
||||||
|
"""Synchronous training in TPU donuts or Pods.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tpu_cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
|
||||||
|
which provides information about the TPU cluster. If None, it will
|
||||||
|
assume running on a local TPU worker.
|
||||||
|
experimental_device_assignment: Optional
|
||||||
|
`tf.tpu.experimental.DeviceAssignment` to specify the placement of
|
||||||
|
replicas on the TPU cluster.
|
||||||
|
"""
|
||||||
|
super(TPUStrategyV2, self).__init__(TPUExtended(
|
||||||
|
self, tpu_cluster_resolver,
|
||||||
|
device_assignment=experimental_device_assignment))
|
||||||
|
distribute_lib.distribution_strategy_gauge.get_cell("V2").set("TPUStrategy")
|
||||||
|
distribute_lib.distribution_strategy_replica_gauge.get_cell(
|
||||||
|
"num_workers").set(self.extended.num_hosts)
|
||||||
|
distribute_lib.distribution_strategy_replica_gauge.get_cell(
|
||||||
|
"num_replicas_per_worker").set(self.extended.num_replicas_per_host)
|
||||||
|
# Packed variable is used to reduce the overhead of function execution.
|
||||||
|
# For a DistributedVariable, only one variable handle is captured into a
|
||||||
|
# function graph. It's only supported in eager mode.
|
||||||
|
self._enable_packed_variable_in_eager_mode = False
|
||||||
|
|
||||||
|
def run(self, fn, args=(), kwargs=None, options=None):
|
||||||
|
"""Run the computation defined by `fn` on each TPU replica.
|
||||||
|
|
||||||
|
Executes ops specified by `fn` on each replica. If `args` or `kwargs` have
|
||||||
|
`tf.distribute.DistributedValues`, such as those produced by a
|
||||||
|
`tf.distribute.DistributedDataset` from
|
||||||
|
`tf.distribute.Strategy.experimental_distribute_dataset` or
|
||||||
|
`tf.distribute.Strategy.experimental_distribute_datasets_from_function`,
|
||||||
|
when `fn` is executed on a particular replica, it will be executed with the
|
||||||
|
component of `tf.distribute.DistributedValues` that correspond to that
|
||||||
|
replica.
|
||||||
|
|
||||||
|
`fn` may call `tf.distribute.get_replica_context()` to access members such
|
||||||
|
as `all_reduce`.
|
||||||
|
|
||||||
|
All arguments in `args` or `kwargs` should either be a nest of tensors or
|
||||||
|
`tf.distribute.DistributedValues` containing tensors or composite tensors.
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
|
||||||
|
>>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
|
||||||
|
>>> tf.config.experimental_connect_to_cluster(resolver)
|
||||||
|
>>> tf.tpu.experimental.initialize_tpu_system(resolver)
|
||||||
|
>>> strategy = tf.distribute.TPUStrategy(resolver)
|
||||||
|
>>> @tf.function
|
||||||
|
... def run():
|
||||||
|
... def value_fn(value_context):
|
||||||
|
... return value_context.num_replicas_in_sync
|
||||||
|
... distributed_values = (
|
||||||
|
... strategy.experimental_distribute_values_from_function(value_fn))
|
||||||
|
... def replica_fn(input):
|
||||||
|
... return input * 2
|
||||||
|
... return strategy.run(replica_fn, args=(distributed_values,))
|
||||||
|
>>> result = run()
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fn: The function to run. The output must be a `tf.nest` of `Tensor`s.
|
||||||
|
args: (Optional) Positional arguments to `fn`.
|
||||||
|
kwargs: (Optional) Keyword arguments to `fn`.
|
||||||
|
options: (Optional) An instance of `tf.distribute.RunOptions` specifying
|
||||||
|
the options to run `fn`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Merged return value of `fn` across replicas. The structure of the return
|
||||||
|
value is the same as the return value from `fn`. Each element in the
|
||||||
|
structure can either be `tf.distribute.DistributedValues`, `Tensor`
|
||||||
|
objects, or `Tensor`s (for example, if running on a single replica).
|
||||||
|
"""
|
||||||
|
validate_run_function(fn)
|
||||||
|
|
||||||
|
# Note: the target function is converted to graph even when in Eager mode,
|
||||||
|
# so autograph is on by default here.
|
||||||
|
fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx())
|
||||||
|
options = options or distribute_lib.RunOptions()
|
||||||
|
return self.extended.tpu_run(fn, args, kwargs, options)
|
||||||
|
|
||||||
|
|
||||||
@tf_export("distribute.experimental.TPUStrategy", v1=[])
|
@tf_export("distribute.experimental.TPUStrategy", v1=[])
|
||||||
|
@deprecation.deprecated_endpoints("distribute.experimental.TPUStrategy")
|
||||||
class TPUStrategy(distribute_lib.Strategy):
|
class TPUStrategy(distribute_lib.Strategy):
|
||||||
"""TPU distribution strategy implementation.
|
"""Synchronous training on TPUs and TPU Pods.
|
||||||
|
|
||||||
To construct a TPUStrategy object, you need to run the
|
To construct a TPUStrategy object, you need to run the
|
||||||
initialization code as below:
|
initialization code as below:
|
||||||
@ -109,9 +290,9 @@ class TPUStrategy(distribute_lib.Strategy):
|
|||||||
>>> tf.tpu.experimental.initialize_tpu_system(resolver)
|
>>> tf.tpu.experimental.initialize_tpu_system(resolver)
|
||||||
>>> strategy = tf.distribute.experimental.TPUStrategy(resolver)
|
>>> strategy = tf.distribute.experimental.TPUStrategy(resolver)
|
||||||
|
|
||||||
While using distribution strategies, the variables created within strategy's
|
While using distribution strategies, the variables created within the
|
||||||
scope will be replicated across all the replicas and can be kept in sync
|
strategy's scope will be replicated across all the replicas and can be kept in
|
||||||
using all-reduce algorithms.
|
sync using all-reduce algorithms.
|
||||||
|
|
||||||
To run TF2 programs on TPUs, you can either use `.compile` and
|
To run TF2 programs on TPUs, you can either use `.compile` and
|
||||||
`.fit` APIs in `tf.keras` with TPUStrategy, or write your own customized
|
`.fit` APIs in `tf.keras` with TPUStrategy, or write your own customized
|
||||||
@ -131,9 +312,12 @@ class TPUStrategy(distribute_lib.Strategy):
|
|||||||
tpu_cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
|
tpu_cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
|
||||||
which provides information about the TPU cluster.
|
which provides information about the TPU cluster.
|
||||||
device_assignment: Optional `tf.tpu.experimental.DeviceAssignment` to
|
device_assignment: Optional `tf.tpu.experimental.DeviceAssignment` to
|
||||||
specify the placement of replicas on the TPU cluster. Currently only
|
specify the placement of replicas on the TPU cluster.
|
||||||
supports the usecase of using a single core within a TPU cluster.
|
|
||||||
"""
|
"""
|
||||||
|
logging.warning(
|
||||||
|
"`tf.distribute.experimental.TPUStrategy` is deprecated, please use "
|
||||||
|
" the non experimental symbol `tf.distribute.TPUStrategy` instead.")
|
||||||
|
|
||||||
super(TPUStrategy, self).__init__(TPUExtended(
|
super(TPUStrategy, self).__init__(TPUExtended(
|
||||||
self, tpu_cluster_resolver, device_assignment=device_assignment))
|
self, tpu_cluster_resolver, device_assignment=device_assignment))
|
||||||
distribute_lib.distribution_strategy_gauge.get_cell("V2").set("TPUStrategy")
|
distribute_lib.distribution_strategy_gauge.get_cell("V2").set("TPUStrategy")
|
||||||
|
@ -132,7 +132,7 @@ class TPUEmbedding(tracking.AutoTrackable):
|
|||||||
First let's look at the `TPUStrategy` mode. Initial setup looks like:
|
First let's look at the `TPUStrategy` mode. Initial setup looks like:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
strategy = tf.distribute.experimental.TPUStrategy(...)
|
strategy = tf.distribute.TPUStrategy(...)
|
||||||
with strategy.scope():
|
with strategy.scope():
|
||||||
embedding = tf.tpu.experimental.embedding.TPUEmbedding(
|
embedding = tf.tpu.experimental.embedding.TPUEmbedding(
|
||||||
feature_config=feature_config,
|
feature_config=feature_config,
|
||||||
@ -234,7 +234,7 @@ class TPUEmbedding(tracking.AutoTrackable):
|
|||||||
"""Creates the TPUEmbedding mid level API object.
|
"""Creates the TPUEmbedding mid level API object.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
strategy = tf.distribute.experimental.TPUStrategy(...)
|
strategy = tf.distribute.TPUStrategy(...)
|
||||||
with strategy.scope():
|
with strategy.scope():
|
||||||
embedding = tf.tpu.experimental.embedding.TPUEmbedding(
|
embedding = tf.tpu.experimental.embedding.TPUEmbedding(
|
||||||
feature_config=tf.tpu.experimental.embedding.FeatureConfig(
|
feature_config=tf.tpu.experimental.embedding.FeatureConfig(
|
||||||
@ -512,7 +512,7 @@ class TPUEmbedding(tracking.AutoTrackable):
|
|||||||
ensure you understand the effect of applying a zero gradient.
|
ensure you understand the effect of applying a zero gradient.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
strategy = tf.distribute.experimental.TPUStrategy(...)
|
strategy = tf.distribute.TPUStrategy(...)
|
||||||
with strategy.scope():
|
with strategy.scope():
|
||||||
embedding = tf.tpu.experimental.embedding.TPUEmbedding(...)
|
embedding = tf.tpu.experimental.embedding.TPUEmbedding(...)
|
||||||
|
|
||||||
@ -603,7 +603,7 @@ class TPUEmbedding(tracking.AutoTrackable):
|
|||||||
`(batch_size, max_sequence_length, dim)` instead.
|
`(batch_size, max_sequence_length, dim)` instead.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
strategy = tf.distribute.experimental.TPUStrategy(...)
|
strategy = tf.distribute.TPUStrategy(...)
|
||||||
with strategy.scope():
|
with strategy.scope():
|
||||||
embedding = tf.tpu.experimental.embedding.TPUEmbedding(...)
|
embedding = tf.tpu.experimental.embedding.TPUEmbedding(...)
|
||||||
|
|
||||||
@ -1054,13 +1054,13 @@ class TPUEmbedding(tracking.AutoTrackable):
|
|||||||
embedding tables. We expect that the batch size of each of the tensors in
|
embedding tables. We expect that the batch size of each of the tensors in
|
||||||
features matches the per core batch size. This will automatically happen if
|
features matches the per core batch size. This will automatically happen if
|
||||||
your input dataset is batched to the global batch size and you use
|
your input dataset is batched to the global batch size and you use
|
||||||
`tf.distribute.experimental.TPUStrategy`'s `experimental_distribute_dataset`
|
`tf.distribute.TPUStrategy`'s `experimental_distribute_dataset`
|
||||||
or if you use `experimental_distribute_datasets_from_function` and batch
|
or if you use `experimental_distribute_datasets_from_function` and batch
|
||||||
to the per core batch size computed by the context passed to your input
|
to the per core batch size computed by the context passed to your input
|
||||||
function.
|
function.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
strategy = tf.distribute.experimental.TPUStrategy(...)
|
strategy = tf.distribute.TPUStrategy(...)
|
||||||
with strategy.scope():
|
with strategy.scope():
|
||||||
embedding = tf.tpu.experimental.embedding.TPUEmbedding(...)
|
embedding = tf.tpu.experimental.embedding.TPUEmbedding(...)
|
||||||
|
|
||||||
|
@ -0,0 +1,99 @@
|
|||||||
|
path: "tensorflow.distribute.TPUStrategy"
|
||||||
|
tf_class {
|
||||||
|
is_instance: "<class \'tensorflow.python.distribute.tpu_strategy.TPUStrategyV2\'>"
|
||||||
|
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.Strategy\'>"
|
||||||
|
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.StrategyBase\'>"
|
||||||
|
is_instance: "<type \'object\'>"
|
||||||
|
member {
|
||||||
|
name: "extended"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "num_replicas_in_sync"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "__init__"
|
||||||
|
argspec: "args=[\'self\', \'tpu_cluster_resolver\', \'experimental_device_assignment\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "colocate_vars_with"
|
||||||
|
argspec: "args=[\'self\', \'colocate_with_variable\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "configure"
|
||||||
|
argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "experimental_assign_to_logical_device"
|
||||||
|
argspec: "args=[\'self\', \'tensor\', \'logical_device_id\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "experimental_distribute_dataset"
|
||||||
|
argspec: "args=[\'self\', \'dataset\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "experimental_distribute_datasets_from_function"
|
||||||
|
argspec: "args=[\'self\', \'dataset_fn\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "experimental_distribute_values_from_function"
|
||||||
|
argspec: "args=[\'self\', \'value_fn\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "experimental_local_results"
|
||||||
|
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "experimental_make_numpy_dataset"
|
||||||
|
argspec: "args=[\'self\', \'numpy_input\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "experimental_replicate_to_logical_devices"
|
||||||
|
argspec: "args=[\'self\', \'tensor\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "experimental_run"
|
||||||
|
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "experimental_run_v2"
|
||||||
|
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "experimental_split_to_logical_devices"
|
||||||
|
argspec: "args=[\'self\', \'tensor\', \'partition_dimensions\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "group"
|
||||||
|
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "make_dataset_iterator"
|
||||||
|
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "make_input_fn_iterator"
|
||||||
|
argspec: "args=[\'self\', \'input_fn\', \'replication_mode\'], varargs=None, keywords=None, defaults=[\'InputReplicationMode.PER_WORKER\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "reduce"
|
||||||
|
argspec: "args=[\'self\', \'reduce_op\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "run"
|
||||||
|
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "scope"
|
||||||
|
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "unwrap"
|
||||||
|
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "update_config_proto"
|
||||||
|
argspec: "args=[\'self\', \'config_proto\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
|
}
|
@ -72,6 +72,10 @@ tf_module {
|
|||||||
name: "StrategyExtended"
|
name: "StrategyExtended"
|
||||||
mtype: "<type \'type\'>"
|
mtype: "<type \'type\'>"
|
||||||
}
|
}
|
||||||
|
member {
|
||||||
|
name: "TPUStrategy"
|
||||||
|
mtype: "<type \'type\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "cluster_resolver"
|
name: "cluster_resolver"
|
||||||
mtype: "<type \'module\'>"
|
mtype: "<type \'module\'>"
|
||||||
|
@ -883,7 +883,7 @@ class TFAPIChangeSpec(ast_edits.NoUpdateSpec):
|
|||||||
contrib_tpu_strategy_warning = (
|
contrib_tpu_strategy_warning = (
|
||||||
ast_edits.ERROR,
|
ast_edits.ERROR,
|
||||||
"(Manual edit required) tf.contrib.distribute.TPUStrategy has "
|
"(Manual edit required) tf.contrib.distribute.TPUStrategy has "
|
||||||
"been migrated to tf.distribute.experimental.TPUStrategy. Note the "
|
"been migrated to tf.distribute.TPUStrategy. Note the "
|
||||||
"slight changes in constructor. " + distribute_strategy_api_changes)
|
"slight changes in constructor. " + distribute_strategy_api_changes)
|
||||||
|
|
||||||
contrib_collective_strategy_warning = (
|
contrib_collective_strategy_warning = (
|
||||||
|
@ -2155,7 +2155,7 @@ def _log_prob(self, x):
|
|||||||
expected = "tf.contrib.distribute.TPUStrategy"
|
expected = "tf.contrib.distribute.TPUStrategy"
|
||||||
_, _, errors, new_text = self._upgrade(text)
|
_, _, errors, new_text = self._upgrade(text)
|
||||||
self.assertEqual(expected, new_text)
|
self.assertEqual(expected, new_text)
|
||||||
self.assertIn("migrated to tf.distribute.experimental.TPUStrategy",
|
self.assertIn("migrated to tf.distribute.TPUStrategy",
|
||||||
errors[0])
|
errors[0])
|
||||||
|
|
||||||
text = "tf.contrib.distribute.foo"
|
text = "tf.contrib.distribute.foo"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user