Make DistributedValues an exported TF symbol.
A DistributedValues subclass is returned by multiple methods in distribution strategy. PiperOrigin-RevId: 296966376 Change-Id: Id9674ef903e34354141c222b8e2196f5b46e4f42
This commit is contained in:
parent
411185a8fe
commit
ead15c719f
@ -43,6 +43,7 @@ from tensorflow.python.training.saving import saveable_object
|
||||
from tensorflow.python.training.saving import saveable_object_util
|
||||
from tensorflow.python.training.tracking import base as trackable
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
def _get_current_replica_id_as_int():
|
||||
@ -57,10 +58,73 @@ def _get_current_replica_id_as_int():
|
||||
return replica_id
|
||||
|
||||
|
||||
@tf_export("distribute.DistributedValues", v1=[])
|
||||
class DistributedValues(object):
|
||||
"""Holds a map from replica to values. Either PerReplica or Mirrored."""
|
||||
"""Base class for representing distributed values.
|
||||
|
||||
A subclass instance of DistributedValues is created when creating variables
|
||||
within a distribution strategy, iterating a `tf.Dataset` or through
|
||||
`strategy.experimental_run_v2`. This base class should never be instantiated
|
||||
directly. DistributedValues contains a value per replica. Depending on
|
||||
the subclass, the values could either be synced on update, synced on demand,
|
||||
or never synced.
|
||||
|
||||
DistributedValues can be reduced to obtain single value across replicas,
|
||||
as input into `experimental_run_v2` or the per replica values inspected
|
||||
using `experimental_local_results`.
|
||||
|
||||
Example usage:
|
||||
|
||||
1. Created from Dataset:
|
||||
|
||||
>>> strategy = tf.distribute.MirroredStrategy()
|
||||
>>> dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
|
||||
>>> dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
|
||||
>>> distributed_values = next(dataset_iterator)
|
||||
|
||||
2. Returned by `experimental_run_v2`:
|
||||
|
||||
>>> strategy = tf.distribute.MirroredStrategy()
|
||||
>>> @tf.function
|
||||
... def run():
|
||||
... ctx = tf.distribute.get_replica_context()
|
||||
... return ctx.replica_id_in_sync_group
|
||||
>>> distributed_values = strategy.experimental_run_v2(run)
|
||||
|
||||
3. As input into `experimental_run_v2`:
|
||||
>>> strategy = tf.distribute.MirroredStrategy()
|
||||
>>> dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
|
||||
>>> dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
|
||||
>>> distributed_values = next(dataset_iterator)
|
||||
>>> @tf.function
|
||||
... def run(input):
|
||||
... return input + 1.0
|
||||
>>> updated_value = strategy.experimental_run_v2(run,
|
||||
... args=(distributed_values,))
|
||||
|
||||
4. Reduce value
|
||||
>>> strategy = tf.distribute.MirroredStrategy()
|
||||
>>> dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
|
||||
>>> dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
|
||||
>>> distributed_values = next(dataset_iterator)
|
||||
>>> reduced_value = strategy.reduce(tf.distribute.ReduceOp.SUM,
|
||||
... distributed_values,
|
||||
... axis = 0)
|
||||
|
||||
5. Inspect per replica values.
|
||||
>>> strategy = tf.distribute.MirroredStrategy()
|
||||
>>> dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
|
||||
>>> dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
|
||||
>>> per_replica_values = strategy.experimental_local_results(
|
||||
... distributed_values)
|
||||
>>> per_replica_values
|
||||
(<tf.Tensor: shape=(2,), dtype=float32,
|
||||
numpy=array([5., 6.], dtype=float32)>,)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, values):
|
||||
"""Should only be called by subclass __init__."""
|
||||
self._values = tuple(values)
|
||||
|
||||
def _get(self):
|
||||
|
@ -0,0 +1,9 @@
|
||||
path: "tensorflow.distribute.DistributedValues"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.distribute.values.DistributedValues\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'values\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
@ -4,6 +4,10 @@ tf_module {
|
||||
name: "CrossDeviceOps"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "DistributedValues"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "HierarchicalCopyAllReduce"
|
||||
mtype: "<type \'type\'>"
|
||||
|
Loading…
Reference in New Issue
Block a user