From e0d9dfd54b0f2f4a2a6637268aa176d1701ee8e9 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Thu, 27 Jun 2019 21:01:27 -0700
Subject: [PATCH] tf.distribute clean ups.

PiperOrigin-RevId: 255533911
---
 .../distribute/python/mirrored_strategy.py        |  2 --
 .../distribute/collective_all_reduce_strategy.py  |  6 +++---
 .../python/distribute/mirrored_strategy.py        |  6 +++---
 .../python/distribute/mirrored_strategy_test.py   |  2 +-
 tensorflow/python/distribute/values.py            | 16 ++++++++--------
 5 files changed, 15 insertions(+), 17 deletions(-)

diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index 82e3919d7d3..8b3c5e546d6 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -23,11 +23,9 @@ from tensorflow.python.distribute import input_lib
 from tensorflow.python.distribute import mirrored_strategy
 
 
-# pylint: disable=protected-access,invalid-name
 all_local_devices = mirrored_strategy.all_local_devices
 CoreMirroredStrategy = mirrored_strategy.MirroredStrategy
 CoreMirroredExtended = mirrored_strategy.MirroredExtended
-# pylint: enable=protected-access,invalid-name
 
 
 class MirroredStrategy(distribute_lib.StrategyV1):
diff --git a/tensorflow/python/distribute/collective_all_reduce_strategy.py b/tensorflow/python/distribute/collective_all_reduce_strategy.py
index f0e86f71c6a..e858b6a57fc 100644
--- a/tensorflow/python/distribute/collective_all_reduce_strategy.py
+++ b/tensorflow/python/distribute/collective_all_reduce_strategy.py
@@ -282,9 +282,9 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
         self._communication)
 
   def _get_variable_creator_initial_value(self,
-                                          replica_id=0,
-                                          device=None,
-                                          primary_var=None,
+                                          replica_id,
+                                          device,
+                                          primary_var,
                                           **kwargs):
     if replica_id == 0:  # First replica on each worker.
       assert device is not None
diff --git a/tensorflow/python/distribute/mirrored_strategy.py b/tensorflow/python/distribute/mirrored_strategy.py
index 722db406f63..b44f73d3a5d 100644
--- a/tensorflow/python/distribute/mirrored_strategy.py
+++ b/tensorflow/python/distribute/mirrored_strategy.py
@@ -442,9 +442,9 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
       self._inferred_cross_device_ops = cross_device_ops_lib.NcclAllReduce()
 
   def _get_variable_creator_initial_value(self,
-                                          replica_id=0,
-                                          device=None,
-                                          primary_var=None,
+                                          replica_id,
+                                          device,
+                                          primary_var,
                                           **kwargs):
     """Return the initial value for variables on a replica."""
     if replica_id == 0:
diff --git a/tensorflow/python/distribute/mirrored_strategy_test.py b/tensorflow/python/distribute/mirrored_strategy_test.py
index df197998396..7e606dbd500 100644
--- a/tensorflow/python/distribute/mirrored_strategy_test.py
+++ b/tensorflow/python/distribute/mirrored_strategy_test.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Multi-GPU tests for MirroredStrategy."""
+"""Tests for MirroredStrategy."""
 
 from __future__ import absolute_import
 from __future__ import division
diff --git a/tensorflow/python/distribute/values.py b/tensorflow/python/distribute/values.py
index 386e0c0e3a0..8b935a345ba 100644
--- a/tensorflow/python/distribute/values.py
+++ b/tensorflow/python/distribute/values.py
@@ -298,7 +298,7 @@ class WorkerDeviceMap(DeviceMap):
 
 
 class DistributedValues(object):
-  """Holds a map from device to values. Either PerReplica or Mirrored."""
+  """Holds a map from replica to values. Either PerReplica or Mirrored."""
 
   def __init__(self, device_map, values, logical_device=None):
     assert isinstance(device_map, DeviceMap)
@@ -463,7 +463,7 @@ class DistributedDelegate(DistributedValues):
 
 
 class PerReplica(DistributedValues, composite_tensor.CompositeTensor):
-  """Holds a map from device to unsynchronized values."""
+  """Holds a map from replica to unsynchronized values."""
 
   @property
   def _type_spec(self):
@@ -536,7 +536,7 @@ class PerReplicaSpec(type_spec.TypeSpec):
 # DistributedDelegate and so can be used directly in cross-replica mode.
 # TODO(tomhennigan) Should this extend CompositeTensor?
 class Mirrored(DistributedDelegate):
-  """Holds a map from device to values which are kept in sync."""
+  """Holds a map from replica to values which are kept in sync."""
 
   def _get_cross_replica(self):
     device = device_util.canonicalize(device_util.current())
@@ -595,7 +595,7 @@ DistributedVarOp = collections.namedtuple(
 
 
 class DistributedVariable(DistributedDelegate, variables_lib.AbstractVariable):
-  """Holds a map from device to variables."""
+  """Holds a map from replica to variables."""
 
   # TODO(josh11b): Support changing the set of variables if e.g. if new
   # devices are joining or a device is to leave.
@@ -968,7 +968,7 @@ class _MirroredSaveable(saver.BaseSaverBuilder.ResourceVariableSaveable):
 
 
 class MirroredVariable(DistributedVariable, Mirrored):
-  """Holds a map from device to variables whose values are kept in sync."""
+  """Holds a map from replica to variables whose values are kept in sync."""
 
   def __init__(
       self, strategy, device_map, values, aggregation, logical_device=None):
@@ -1094,7 +1094,7 @@ def is_distributed_variable(v):
 
 
 class TPUMirroredVariable(TPUVariableMixin, MirroredVariable):
-  """Holds a map from device to TPU variables whose values are kept in sync."""
+  """Holds a map from replica to TPU variables whose values are kept in sync."""
 
   def _assign_func(self, *args, **kwargs):
     with _enter_or_assert_strategy(self._distribute_strategy):
@@ -1158,7 +1158,7 @@ def _assert_replica_context(strategy):
 
 
 class SyncOnReadVariable(DistributedVariable, PerReplica):
-  """Holds a map from device to variables whose values are reduced on save."""
+  """Holds a map from replica to variables whose values are reduced on save."""
 
   def __init__(
      self, strategy, device_map, values, aggregation, logical_device=None):
@@ -1255,7 +1255,7 @@ ops.register_tensor_conversion_function(SyncOnReadVariable,
 
 
 class TPUSyncOnReadVariable(TPUVariableMixin, SyncOnReadVariable):
-  """Holds a map from device to variables whose values are reduced on save."""
+  """Holds a map from replica to variables whose values are reduced on save."""
 
   def assign_sub(self, *args, **kwargs):
     if _enclosing_tpu_context() is None: