tf.distribute clean ups.

PiperOrigin-RevId: 255533911

Commit: e0d9dfd54b
Parent: a6a6e38f45
@@ -23,11 +23,9 @@ from tensorflow.python.distribute import input_lib
 from tensorflow.python.distribute import mirrored_strategy
 
 
-# pylint: disable=protected-access,invalid-name
 all_local_devices = mirrored_strategy.all_local_devices
 CoreMirroredStrategy = mirrored_strategy.MirroredStrategy
 CoreMirroredExtended = mirrored_strategy.MirroredExtended
-# pylint: enable=protected-access,invalid-name
 
 
 class MirroredStrategy(distribute_lib.StrategyV1):
@@ -282,9 +282,9 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
         self._communication)
 
   def _get_variable_creator_initial_value(self,
-                                          replica_id=0,
-                                          device=None,
-                                          primary_var=None,
+                                          replica_id,
+                                          device,
+                                          primary_var,
                                           **kwargs):
     if replica_id == 0:  # First replica on each worker.
       assert device is not None
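Dropping the =0/=None defaults makes all three parameters required, so call sites can no longer fall back to replica 0 with no device by accident. A framework-free sketch of the contract this enforces; the function name, the dict-based primary variable, and the explicit initializer callable below are illustrative assumptions, not TensorFlow's implementation:

def get_variable_creator_initial_value(replica_id, device, primary_var,
                                       initial_value_fn):
  """Returns a callable producing one replica's initial variable value."""
  if replica_id == 0:
    # First replica on each worker: run the user-supplied initializer,
    # which needs to know the target device.
    assert device is not None
    return initial_value_fn
  # Every other replica starts from the primary copy so all replicas agree.
  assert primary_var is not None
  return lambda: primary_var["value"]

primary = {"value": 0.5}  # toy stand-in for the replica-0 variable
init_fn = get_variable_creator_initial_value(1, "/gpu:1", primary, lambda: 0.5)
assert init_fn() == 0.5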
@@ -442,9 +442,9 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
       self._inferred_cross_device_ops = cross_device_ops_lib.NcclAllReduce()
 
   def _get_variable_creator_initial_value(self,
-                                          replica_id=0,
-                                          device=None,
-                                          primary_var=None,
+                                          replica_id,
+                                          device,
+                                          primary_var,
                                           **kwargs):
     """Return the initial value for variables on a replica."""
     if replica_id == 0:
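The same tightening applies in MirroredExtended. The practical effect is that a bare call now fails loudly instead of silently defaulting; a toy comparison in plain Python:

def old_sig(replica_id=0, device=None, primary_var=None, **kwargs):
  return replica_id, device

def new_sig(replica_id, device, primary_var, **kwargs):
  return replica_id, device

old_sig()    # quietly returns (0, None)
try:
  new_sig()  # raises: missing 3 required positional arguments
except TypeError as e:
  print(e)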
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Multi-GPU tests for MirroredStrategy."""
+"""Tests for MirroredStrategy."""
 
 from __future__ import absolute_import
 from __future__ import division
@@ -298,7 +298,7 @@ class WorkerDeviceMap(DeviceMap):
 
 
 class DistributedValues(object):
-  """Holds a map from device to values. Either PerReplica or Mirrored."""
+  """Holds a map from replica to values. Either PerReplica or Mirrored."""
 
   def __init__(self, device_map, values, logical_device=None):
     assert isinstance(device_map, DeviceMap)
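This and the remaining hunks are a terminology fix: the maps these classes hold are keyed by replica, not by device, and the docstrings now say so. A toy sketch of that indexing model (my illustration, not the real class); replica id stays meaningful even when several replicas share one device:

class ToyDistributedValues(object):
  """Toy stand-in: one value per replica, indexed by replica id."""

  def __init__(self, values):
    self._values = tuple(values)

  def get(self, replica_id):
    return self._values[replica_id]

vals = ToyDistributedValues(["loss@replica0", "loss@replica1"])
assert vals.get(1) == "loss@replica1"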
@@ -463,7 +463,7 @@ class DistributedDelegate(DistributedValues):
 
 
 class PerReplica(DistributedValues, composite_tensor.CompositeTensor):
-  """Holds a map from device to unsynchronized values."""
+  """Holds a map from replica to unsynchronized values."""
 
   @property
   def _type_spec(self):
@@ -536,7 +536,7 @@ class PerReplicaSpec(type_spec.TypeSpec):
 # DistributedDelegate and so can be used directly in cross-replica mode.
 # TODO(tomhennigan) Should this extend CompositeTensor?
 class Mirrored(DistributedDelegate):
-  """Holds a map from device to values which are kept in sync."""
+  """Holds a map from replica to values which are kept in sync."""
 
   def _get_cross_replica(self):
     device = device_util.canonicalize(device_util.current())
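The two renamed docstrings above encode the key difference in invariants: PerReplica entries may all differ, while Mirrored entries are identical by construction, which is what makes a Mirrored value safe to read directly in cross-replica mode from whichever copy is local. A toy contrast under those assumed semantics:

per_replica = {0: 3.0, 1: 5.0}  # unsynchronized; entries may differ
mirrored = {0: 7.0, 1: 7.0}     # kept in sync; entries always equal

def read_cross_replica(mirrored_value):
  # Safe only because every copy of a mirrored value is identical.
  return next(iter(mirrored_value.values()))

assert read_cross_replica(mirrored) == 7.0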
@@ -595,7 +595,7 @@ DistributedVarOp = collections.namedtuple(
 
 
 class DistributedVariable(DistributedDelegate, variables_lib.AbstractVariable):
-  """Holds a map from device to variables."""
+  """Holds a map from replica to variables."""
   # TODO(josh11b): Support changing the set of variables if e.g. if new
   # devices are joining or a device is to leave.
 
@@ -968,7 +968,7 @@ class _MirroredSaveable(saver.BaseSaverBuilder.ResourceVariableSaveable):
 
 
 class MirroredVariable(DistributedVariable, Mirrored):
-  """Holds a map from device to variables whose values are kept in sync."""
+  """Holds a map from replica to variables whose values are kept in sync."""
 
   def __init__(
       self, strategy, device_map, values, aggregation, logical_device=None):
@@ -1094,7 +1094,7 @@ def is_distributed_variable(v):
 
 
 class TPUMirroredVariable(TPUVariableMixin, MirroredVariable):
-  """Holds a map from device to TPU variables whose values are kept in sync."""
+  """Holds a map from replica to TPU variables whose values are kept in sync."""
 
   def _assign_func(self, *args, **kwargs):
     with _enter_or_assert_strategy(self._distribute_strategy):
@@ -1158,7 +1158,7 @@ def _assert_replica_context(strategy):
 
 
 class SyncOnReadVariable(DistributedVariable, PerReplica):
-  """Holds a map from device to variables whose values are reduced on save."""
+  """Holds a map from replica to variables whose values are reduced on save."""
 
   def __init__(
       self, strategy, device_map, values, aggregation, logical_device=None):
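"Reduced on save" here means each replica updates its own copy cheaply during training, and the copies are only combined, using the variable's aggregation, when the value is read across replicas or checkpointed. A framework-free sketch of that reduction (the aggregation names are assumptions, not TensorFlow's enum):

def reduce_on_read(replica_values, aggregation):
  """Combines per-replica copies into the single saved value."""
  if aggregation == "sum":
    return sum(replica_values)
  if aggregation == "mean":
    return sum(replica_values) / len(replica_values)
  raise ValueError("unsupported aggregation: %r" % (aggregation,))

assert reduce_on_read([1.0, 2.0, 3.0], "sum") == 6.0
assert reduce_on_read([1.0, 2.0, 3.0], "mean") == 2.0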
@@ -1255,7 +1255,7 @@ ops.register_tensor_conversion_function(SyncOnReadVariable,
 
 
 class TPUSyncOnReadVariable(TPUVariableMixin, SyncOnReadVariable):
-  """Holds a map from device to variables whose values are reduced on save."""
+  """Holds a map from replica to variables whose values are reduced on save."""
 
   def assign_sub(self, *args, **kwargs):
     if _enclosing_tpu_context() is None: