tf.distribute clean ups.
PiperOrigin-RevId: 255533911
This commit is contained in:
parent
a6a6e38f45
commit
e0d9dfd54b
@ -23,11 +23,9 @@ from tensorflow.python.distribute import input_lib
|
||||
from tensorflow.python.distribute import mirrored_strategy
|
||||
|
||||
|
||||
# pylint: disable=protected-access,invalid-name
|
||||
all_local_devices = mirrored_strategy.all_local_devices
|
||||
CoreMirroredStrategy = mirrored_strategy.MirroredStrategy
|
||||
CoreMirroredExtended = mirrored_strategy.MirroredExtended
|
||||
# pylint: enable=protected-access,invalid-name
|
||||
|
||||
|
||||
class MirroredStrategy(distribute_lib.StrategyV1):
|
||||
|
@ -282,9 +282,9 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
|
||||
self._communication)
|
||||
|
||||
def _get_variable_creator_initial_value(self,
|
||||
replica_id=0,
|
||||
device=None,
|
||||
primary_var=None,
|
||||
replica_id,
|
||||
device,
|
||||
primary_var,
|
||||
**kwargs):
|
||||
if replica_id == 0: # First replica on each worker.
|
||||
assert device is not None
|
||||
|
@ -442,9 +442,9 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
|
||||
self._inferred_cross_device_ops = cross_device_ops_lib.NcclAllReduce()
|
||||
|
||||
def _get_variable_creator_initial_value(self,
|
||||
replica_id=0,
|
||||
device=None,
|
||||
primary_var=None,
|
||||
replica_id,
|
||||
device,
|
||||
primary_var,
|
||||
**kwargs):
|
||||
"""Return the initial value for variables on a replica."""
|
||||
if replica_id == 0:
|
||||
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Multi-GPU tests for MirroredStrategy."""
|
||||
"""Tests for MirroredStrategy."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
|
@ -298,7 +298,7 @@ class WorkerDeviceMap(DeviceMap):
|
||||
|
||||
|
||||
class DistributedValues(object):
|
||||
"""Holds a map from device to values. Either PerReplica or Mirrored."""
|
||||
"""Holds a map from replica to values. Either PerReplica or Mirrored."""
|
||||
|
||||
def __init__(self, device_map, values, logical_device=None):
|
||||
assert isinstance(device_map, DeviceMap)
|
||||
@ -463,7 +463,7 @@ class DistributedDelegate(DistributedValues):
|
||||
|
||||
|
||||
class PerReplica(DistributedValues, composite_tensor.CompositeTensor):
|
||||
"""Holds a map from device to unsynchronized values."""
|
||||
"""Holds a map from replica to unsynchronized values."""
|
||||
|
||||
@property
|
||||
def _type_spec(self):
|
||||
@ -536,7 +536,7 @@ class PerReplicaSpec(type_spec.TypeSpec):
|
||||
# DistributedDelegate and so can be used directly in cross-replica mode.
|
||||
# TODO(tomhennigan) Should this extend CompositeTensor?
|
||||
class Mirrored(DistributedDelegate):
|
||||
"""Holds a map from device to values which are kept in sync."""
|
||||
"""Holds a map from replica to values which are kept in sync."""
|
||||
|
||||
def _get_cross_replica(self):
|
||||
device = device_util.canonicalize(device_util.current())
|
||||
@ -595,7 +595,7 @@ DistributedVarOp = collections.namedtuple(
|
||||
|
||||
|
||||
class DistributedVariable(DistributedDelegate, variables_lib.AbstractVariable):
|
||||
"""Holds a map from device to variables."""
|
||||
"""Holds a map from replica to variables."""
|
||||
# TODO(josh11b): Support changing the set of variables if e.g. if new
|
||||
# devices are joining or a device is to leave.
|
||||
|
||||
@ -968,7 +968,7 @@ class _MirroredSaveable(saver.BaseSaverBuilder.ResourceVariableSaveable):
|
||||
|
||||
|
||||
class MirroredVariable(DistributedVariable, Mirrored):
|
||||
"""Holds a map from device to variables whose values are kept in sync."""
|
||||
"""Holds a map from replica to variables whose values are kept in sync."""
|
||||
|
||||
def __init__(
|
||||
self, strategy, device_map, values, aggregation, logical_device=None):
|
||||
@ -1094,7 +1094,7 @@ def is_distributed_variable(v):
|
||||
|
||||
|
||||
class TPUMirroredVariable(TPUVariableMixin, MirroredVariable):
|
||||
"""Holds a map from device to TPU variables whose values are kept in sync."""
|
||||
"""Holds a map from replica to TPU variables whose values are kept in sync."""
|
||||
|
||||
def _assign_func(self, *args, **kwargs):
|
||||
with _enter_or_assert_strategy(self._distribute_strategy):
|
||||
@ -1158,7 +1158,7 @@ def _assert_replica_context(strategy):
|
||||
|
||||
|
||||
class SyncOnReadVariable(DistributedVariable, PerReplica):
|
||||
"""Holds a map from device to variables whose values are reduced on save."""
|
||||
"""Holds a map from replica to variables whose values are reduced on save."""
|
||||
|
||||
def __init__(
|
||||
self, strategy, device_map, values, aggregation, logical_device=None):
|
||||
@ -1255,7 +1255,7 @@ ops.register_tensor_conversion_function(SyncOnReadVariable,
|
||||
|
||||
|
||||
class TPUSyncOnReadVariable(TPUVariableMixin, SyncOnReadVariable):
|
||||
"""Holds a map from device to variables whose values are reduced on save."""
|
||||
"""Holds a map from replica to variables whose values are reduced on save."""
|
||||
|
||||
def assign_sub(self, *args, **kwargs):
|
||||
if _enclosing_tpu_context() is None:
|
||||
|
Loading…
Reference in New Issue
Block a user