Make DistributedValues an exported TF symbol.
A DistributedValues subclass is returned by multiple methods in distribution strategy. PiperOrigin-RevId: 296966376 Change-Id: Id9674ef903e34354141c222b8e2196f5b46e4f42
This commit is contained in:
parent
411185a8fe
commit
ead15c719f
@ -43,6 +43,7 @@ from tensorflow.python.training.saving import saveable_object
|
|||||||
from tensorflow.python.training.saving import saveable_object_util
|
from tensorflow.python.training.saving import saveable_object_util
|
||||||
from tensorflow.python.training.tracking import base as trackable
|
from tensorflow.python.training.tracking import base as trackable
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
|
|
||||||
def _get_current_replica_id_as_int():
|
def _get_current_replica_id_as_int():
|
||||||
@ -57,10 +58,73 @@ def _get_current_replica_id_as_int():
|
|||||||
return replica_id
|
return replica_id
|
||||||
|
|
||||||
|
|
||||||
|
@tf_export("distribute.DistributedValues", v1=[])
|
||||||
class DistributedValues(object):
|
class DistributedValues(object):
|
||||||
"""Holds a map from replica to values. Either PerReplica or Mirrored."""
|
"""Base class for representing distributed values.
|
||||||
|
|
||||||
|
A subclass instance of DistributedValues is created when creating variables
|
||||||
|
within a distribution strategy, iterating a `tf.Dataset` or through
|
||||||
|
`strategy.experimental_run_v2`. This base class should never be instantiated
|
||||||
|
directly. DistributedValues contains a value per replica. Depending on
|
||||||
|
the subclass, the values could either be synced on update, synced on demand,
|
||||||
|
or never synced.
|
||||||
|
|
||||||
|
DistributedValues can be reduced to obtain single value across replicas,
|
||||||
|
as input into `experimental_run_v2` or the per replica values inspected
|
||||||
|
using `experimental_local_results`.
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
|
||||||
|
1. Created from Dataset:
|
||||||
|
|
||||||
|
>>> strategy = tf.distribute.MirroredStrategy()
|
||||||
|
>>> dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
|
||||||
|
>>> dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
|
||||||
|
>>> distributed_values = next(dataset_iterator)
|
||||||
|
|
||||||
|
2. Returned by `experimental_run_v2`:
|
||||||
|
|
||||||
|
>>> strategy = tf.distribute.MirroredStrategy()
|
||||||
|
>>> @tf.function
|
||||||
|
... def run():
|
||||||
|
... ctx = tf.distribute.get_replica_context()
|
||||||
|
... return ctx.replica_id_in_sync_group
|
||||||
|
>>> distributed_values = strategy.experimental_run_v2(run)
|
||||||
|
|
||||||
|
3. As input into `experimental_run_v2`:
|
||||||
|
>>> strategy = tf.distribute.MirroredStrategy()
|
||||||
|
>>> dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
|
||||||
|
>>> dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
|
||||||
|
>>> distributed_values = next(dataset_iterator)
|
||||||
|
>>> @tf.function
|
||||||
|
... def run(input):
|
||||||
|
... return input + 1.0
|
||||||
|
>>> updated_value = strategy.experimental_run_v2(run,
|
||||||
|
... args=(distributed_values,))
|
||||||
|
|
||||||
|
4. Reduce value
|
||||||
|
>>> strategy = tf.distribute.MirroredStrategy()
|
||||||
|
>>> dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
|
||||||
|
>>> dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
|
||||||
|
>>> distributed_values = next(dataset_iterator)
|
||||||
|
>>> reduced_value = strategy.reduce(tf.distribute.ReduceOp.SUM,
|
||||||
|
... distributed_values,
|
||||||
|
... axis = 0)
|
||||||
|
|
||||||
|
5. Inspect per replica values.
|
||||||
|
>>> strategy = tf.distribute.MirroredStrategy()
|
||||||
|
>>> dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
|
||||||
|
>>> dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
|
||||||
|
>>> per_replica_values = strategy.experimental_local_results(
|
||||||
|
... distributed_values)
|
||||||
|
>>> per_replica_values
|
||||||
|
(<tf.Tensor: shape=(2,), dtype=float32,
|
||||||
|
numpy=array([5., 6.], dtype=float32)>,)
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, values):
|
def __init__(self, values):
|
||||||
|
"""Should only be called by subclass __init__."""
|
||||||
self._values = tuple(values)
|
self._values = tuple(values)
|
||||||
|
|
||||||
def _get(self):
|
def _get(self):
|
||||||
|
@ -0,0 +1,9 @@
|
|||||||
|
path: "tensorflow.distribute.DistributedValues"
|
||||||
|
tf_class {
|
||||||
|
is_instance: "<class \'tensorflow.python.distribute.values.DistributedValues\'>"
|
||||||
|
is_instance: "<type \'object\'>"
|
||||||
|
member_method {
|
||||||
|
name: "__init__"
|
||||||
|
argspec: "args=[\'self\', \'values\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
|
}
|
@ -4,6 +4,10 @@ tf_module {
|
|||||||
name: "CrossDeviceOps"
|
name: "CrossDeviceOps"
|
||||||
mtype: "<type \'type\'>"
|
mtype: "<type \'type\'>"
|
||||||
}
|
}
|
||||||
|
member {
|
||||||
|
name: "DistributedValues"
|
||||||
|
mtype: "<type \'type\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "HierarchicalCopyAllReduce"
|
name: "HierarchicalCopyAllReduce"
|
||||||
mtype: "<type \'type\'>"
|
mtype: "<type \'type\'>"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user