Make DistributedValues an exported TF symbol.

A DistributedValues subclass is returned by several distribution strategy methods.

PiperOrigin-RevId: 296966376
Change-Id: Id9674ef903e34354141c222b8e2196f5b46e4f42
Ken Franko 2020-02-24 13:53:44 -08:00 committed by TensorFlower Gardener
parent 411185a8fe
commit ead15c719f
3 changed files with 78 additions and 1 deletions
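Not part of the diff itself, but as a rough sketch of what the new export enables on a TF 2.x build that includes this change (the `MirroredStrategy` and dataset setup simply mirror the docstring examples added below; the `isinstance` check is an illustrative assumption about typical user code, not something added by this commit):

    import tensorflow as tf

    strategy = tf.distribute.MirroredStrategy()
    dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
    distributed_values = next(iter(strategy.experimental_distribute_dataset(dataset)))

    # With DistributedValues exported, user code can reference the base class
    # directly, e.g. for type checks in input-handling utilities.
    print(isinstance(distributed_values, tf.distribute.DistributedValues))

With more than one replica the iterator output is a DistributedValues subclass such as PerReplica; on a single replica the exact wrapping may differ, so the check above is illustrative rather than a guarantee.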


@@ -43,6 +43,7 @@ from tensorflow.python.training.saving import saveable_object
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
def _get_current_replica_id_as_int():
@@ -57,10 +58,73 @@ def _get_current_replica_id_as_int():
return replica_id
@tf_export("distribute.DistributedValues", v1=[])
class DistributedValues(object):
"""Holds a map from replica to values. Either PerReplica or Mirrored."""
"""Base class for representing distributed values.
A subclass instance of DistributedValues is created when creating variables
within a distribution strategy, iterating a `tf.Dataset` or through
`strategy.experimental_run_v2`. This base class should never be instantiated
directly. DistributedValues contains a value per replica. Depending on
the subclass, the values could either be synced on update, synced on demand,
or never synced.
DistributedValues can be reduced to obtain single value across replicas,
as input into `experimental_run_v2` or the per replica values inspected
using `experimental_local_results`.
Example usage:
1. Created from Dataset:
>>> strategy = tf.distribute.MirroredStrategy()
>>> dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
>>> dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
>>> distributed_values = next(dataset_iterator)

2. Returned by `experimental_run_v2`:
>>> strategy = tf.distribute.MirroredStrategy()
>>> @tf.function
... def run():
... ctx = tf.distribute.get_replica_context()
... return ctx.replica_id_in_sync_group
>>> distributed_values = strategy.experimental_run_v2(run)

3. As input into `experimental_run_v2`:
>>> strategy = tf.distribute.MirroredStrategy()
>>> dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
>>> dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
>>> distributed_values = next(dataset_iterator)
>>> @tf.function
... def run(input):
... return input + 1.0
>>> updated_value = strategy.experimental_run_v2(run,
... args=(distributed_values,))

4. Reduce value:
>>> strategy = tf.distribute.MirroredStrategy()
>>> dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
>>> dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
>>> distributed_values = next(dataset_iterator)
>>> reduced_value = strategy.reduce(tf.distribute.ReduceOp.SUM,
... distributed_values,
...                          axis=0)

5. Inspect per replica values:
>>> strategy = tf.distribute.MirroredStrategy()
>>> dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
>>> dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
>>> distributed_values = next(dataset_iterator)
>>> per_replica_values = strategy.experimental_local_results(
...     distributed_values)
>>> per_replica_values
(<tf.Tensor: shape=(2,), dtype=float32, numpy=array([5., 6.], dtype=float32)>,)
"""
def __init__(self, values):
"""Should only be called by subclass __init__."""
self._values = tuple(values)
def _get(self):


@@ -0,0 +1,9 @@
path: "tensorflow.distribute.DistributedValues"
tf_class {
is_instance: "<class \'tensorflow.python.distribute.values.DistributedValues\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'values\'], varargs=None, keywords=None, defaults=None"
}
}
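This new golden file pins `tf.distribute.DistributedValues` in the public TF 2 API surface. Because the export uses `v1=[]`, the symbol appears only in the v2 namespace; a minimal sanity check (an illustrative sketch, not part of this commit) would be:

    import tensorflow as tf

    # True on a TF 2.x build that includes this change.
    print(hasattr(tf.distribute, "DistributedValues"))
    # Not exported for TF1 (v1=[]), so this is expected to be False.
    print(hasattr(tf.compat.v1.distribute, "DistributedValues"))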


@@ -4,6 +4,10 @@ tf_module {
name: "CrossDeviceOps"
mtype: "<type \'type\'>"
}
member {
name: "DistributedValues"
mtype: "<type \'type\'>"
}
member {
name: "HierarchicalCopyAllReduce"
mtype: "<type \'type\'>"