tf.distribute clean ups.

PiperOrigin-RevId: 255533911
Authored by A. Unique TensorFlower on 2019-06-27 21:01:27 -07:00; committed by TensorFlower Gardener
parent a6a6e38f45
commit e0d9dfd54b
5 changed files with 15 additions and 17 deletions


@@ -23,11 +23,9 @@ from tensorflow.python.distribute import input_lib
 from tensorflow.python.distribute import mirrored_strategy
 
 
 # pylint: disable=protected-access,invalid-name
 all_local_devices = mirrored_strategy.all_local_devices
-CoreMirroredStrategy = mirrored_strategy.MirroredStrategy
-CoreMirroredExtended = mirrored_strategy.MirroredExtended
 # pylint: enable=protected-access,invalid-name
 
 
 class MirroredStrategy(distribute_lib.StrategyV1):


@@ -282,9 +282,9 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
         self._communication)
 
   def _get_variable_creator_initial_value(self,
-                                          replica_id=0,
-                                          device=None,
-                                          primary_var=None,
+                                          replica_id,
+                                          device,
+                                          primary_var,
                                           **kwargs):
     if replica_id == 0:  # First replica on each worker.
       assert device is not None

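The method edited above supplies the initial value used when variables are created under the strategy's scope ("First replica on each worker"). A minimal usage sketch follows, assuming a TF 2.x runtime where this collective all-reduce strategy is surfaced as tf.distribute.experimental.MultiWorkerMirroredStrategy; without a TF_CONFIG it runs as a single worker, which is enough to show the behavior.

import tensorflow as tf

# A minimal sketch, assuming the TF 2.x public API for the collective strategy.
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()

with strategy.scope():
  # The first replica's initial value is copied to the other replicas, so all
  # components of `w` start out identical even with a random initializer.
  w = tf.Variable(tf.random.normal([2, 2]))

# One component variable per local replica, all holding the same value.
print(strategy.experimental_local_results(w))
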

@@ -442,9 +442,9 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
       self._inferred_cross_device_ops = cross_device_ops_lib.NcclAllReduce()
 
   def _get_variable_creator_initial_value(self,
-                                          replica_id=0,
-                                          device=None,
-                                          primary_var=None,
+                                          replica_id,
+                                          device,
+                                          primary_var,
                                           **kwargs):
     """Return the initial value for variables on a replica."""
     if replica_id == 0:

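For illustration, here is a simplified, hypothetical sketch of the pattern this method implements (the helper below is not TensorFlow's actual code): replica 0 evaluates the real initializer on its device, and every other replica reads the primary component so all mirrors begin with the same value.

import tensorflow as tf

# Hypothetical helper for illustration only -- not TensorFlow's implementation.
def initial_value_for_replica(replica_id, device, primary_var, initial_value_fn):
  if replica_id == 0:
    # The first replica runs the user's initializer on its own device.
    assert device is not None
    with tf.device(device):
      return initial_value_fn()
  # Every other replica copies the primary component's current value.
  assert primary_var is not None
  return primary_var.read_value()
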

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Multi-GPU tests for MirroredStrategy."""
+"""Tests for MirroredStrategy."""
 
 from __future__ import absolute_import
 from __future__ import division


@@ -298,7 +298,7 @@ class WorkerDeviceMap(DeviceMap):
 
 
 class DistributedValues(object):
-  """Holds a map from device to values. Either PerReplica or Mirrored."""
+  """Holds a map from replica to values. Either PerReplica or Mirrored."""
 
   def __init__(self, device_map, values, logical_device=None):
     assert isinstance(device_map, DeviceMap)
@@ -463,7 +463,7 @@ class DistributedDelegate(DistributedValues):
 
 
 class PerReplica(DistributedValues, composite_tensor.CompositeTensor):
-  """Holds a map from device to unsynchronized values."""
+  """Holds a map from replica to unsynchronized values."""
 
   @property
   def _type_spec(self):
@@ -536,7 +536,7 @@ class PerReplicaSpec(type_spec.TypeSpec):
 # DistributedDelegate and so can be used directly in cross-replica mode.
 # TODO(tomhennigan) Should this extend CompositeTensor?
 class Mirrored(DistributedDelegate):
-  """Holds a map from device to values which are kept in sync."""
+  """Holds a map from replica to values which are kept in sync."""
 
   def _get_cross_replica(self):
     device = device_util.canonicalize(device_util.current())
@@ -595,7 +595,7 @@ DistributedVarOp = collections.namedtuple(
 
 
 class DistributedVariable(DistributedDelegate, variables_lib.AbstractVariable):
-  """Holds a map from device to variables."""
+  """Holds a map from replica to variables."""
 
   # TODO(josh11b): Support changing the set of variables if e.g. if new
   # devices are joining or a device is to leave.
@@ -968,7 +968,7 @@ class _MirroredSaveable(saver.BaseSaverBuilder.ResourceVariableSaveable):
 
 
 class MirroredVariable(DistributedVariable, Mirrored):
-  """Holds a map from device to variables whose values are kept in sync."""
+  """Holds a map from replica to variables whose values are kept in sync."""
 
   def __init__(
       self, strategy, device_map, values, aggregation, logical_device=None):
@@ -1094,7 +1094,7 @@ def is_distributed_variable(v):
 
 
 class TPUMirroredVariable(TPUVariableMixin, MirroredVariable):
-  """Holds a map from device to TPU variables whose values are kept in sync."""
+  """Holds a map from replica to TPU variables whose values are kept in sync."""
 
   def _assign_func(self, *args, **kwargs):
     with _enter_or_assert_strategy(self._distribute_strategy):
@@ -1158,7 +1158,7 @@ def _assert_replica_context(strategy):
 
 
 class SyncOnReadVariable(DistributedVariable, PerReplica):
-  """Holds a map from device to variables whose values are reduced on save."""
+  """Holds a map from replica to variables whose values are reduced on save."""
 
   def __init__(
       self, strategy, device_map, values, aggregation, logical_device=None):
@@ -1255,7 +1255,7 @@ ops.register_tensor_conversion_function(SyncOnReadVariable,
 
 
 class TPUSyncOnReadVariable(TPUVariableMixin, SyncOnReadVariable):
-  """Holds a map from device to variables whose values are reduced on save."""
+  """Holds a map from replica to variables whose values are reduced on save."""
 
   def assign_sub(self, *args, **kwargs):
     if _enclosing_tpu_context() is None:
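
The docstring edits above consistently re-describe these containers as keyed by replica rather than by device. From the user side, a minimal sketch (assuming TF 2.2+, where Strategy.run is available): per-replica computations come back as PerReplica values, while variables created under the scope are mirrored and kept in sync.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
  v = tf.Variable(1.0)  # mirrored: one synchronized copy per replica

def step():
  # Each replica computes a different tensor, so the combined result is a
  # per-replica value rather than a single tensor.
  replica_id = tf.distribute.get_replica_context().replica_id_in_sync_group
  return v + tf.cast(replica_id, tf.float32)

per_replica = strategy.run(step)
print(strategy.experimental_local_results(per_replica))  # one tensor per replica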