Split ParameterServerStrategy into multi-worker and local version.

PiperOrigin-RevId: 243347187
Author: Yuefeng Zhou (committed by TensorFlower Gardener)
Date: 2019-04-12 15:46:08 -07:00
Parent: f47745d60a
Commit: fec0d5fd20
20 changed files with 362 additions and 111 deletions

View File

@@ -56,10 +56,12 @@ cuda_py_test(
     name = "parameter_server_strategy_test",
     srcs = ["parameter_server_strategy_test.py"],
     additional_deps = [
+        ":parameter_server_strategy",
+        "//tensorflow/python/distribute:central_storage_strategy",
         "//tensorflow/python/distribute:combinations",
+        "//tensorflow/python/distribute:parameter_server_strategy",
         "//tensorflow/python/distribute:strategy_combinations",
         "//tensorflow/python/distribute:multi_worker_test_base",
-        ":parameter_server_strategy",
         "//tensorflow/python/distribute:strategy_test_lib",
         "@absl_py//absl/testing:parameterized",
         "//tensorflow/core:protos_all_py",

View File

@@ -24,9 +24,9 @@ from absl.testing import parameterized
 from tensorflow.contrib.distribute.python import parameter_server_strategy
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.distribute import central_storage_strategy
 from tensorflow.python.distribute import combinations
 from tensorflow.python.distribute import device_util
-from tensorflow.python.distribute import distribute_lib
 from tensorflow.python.distribute import distribution_strategy_context as ds_context
 from tensorflow.python.distribute import multi_worker_test_base
 from tensorflow.python.distribute import multi_worker_util
@@ -52,7 +52,6 @@ from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
 from tensorflow.python.training import training_util
-from tensorflow.python.training.server_lib import ClusterSpec

 CHIEF = run_config.TaskType.CHIEF
 WORKER = run_config.TaskType.WORKER
@@ -66,15 +65,6 @@ def _get_replica_id_integer():
   return replica_id


-class MockCoreParameterServerStrategy(distribute_lib.StrategyV1):
-  """Mock the strategy to allow cluster resolver as an argument."""
-
-  def __init__(self, cluster_resolver):
-    super(MockCoreParameterServerStrategy, self).__init__(
-        core_parameter_server_strategy.ParameterServerStrategyExtended(
-            self, cluster_resolver=cluster_resolver))
-
-
 def create_test_objects(cluster_spec=None,
                         task_type=None,
                         task_id=None,
@@ -91,13 +81,15 @@ def create_test_objects(cluster_spec=None,
          task_type=task_type,
          task_id=task_id,
          num_accelerators={'GPU': num_gpus})
+      distribution = core_parameter_server_strategy.ParameterServerStrategy(
+          cluster_resolver)
       target = 'grpc://' + cluster_spec[WORKER][task_id]
     else:
-      cluster_resolver = SimpleClusterResolver(
-          ClusterSpec({}), num_accelerators={'GPU': num_gpus})
+      distribution = (
+          central_storage_strategy.CentralStorageStrategy._from_num_gpus(
+              num_gpus))
       target = ''
-    distribution = MockCoreParameterServerStrategy(cluster_resolver)

     sess_config = copy.deepcopy(sess_config)
     sess_config = distribution.update_config_proto(sess_config)
   else:
@@ -440,7 +432,8 @@ class ParameterServerStrategyTestBase(
       x, y, z, train_op = d.extended.call_for_each_replica(model_fn)
       train_op = d.group(train_op)

-      if context.num_gpus() < d.extended._num_gpus_per_worker:
+      if context.num_gpus() < sum(
+          1 for d in d.extended.worker_devices if 'GPU' in d.upper()):
        return True

      if task_id == 0:
@@ -536,7 +529,8 @@ class ParameterServerStrategyTestBase(
       before_out, after_out = step()

-      if context.num_gpus() < d.extended._num_gpus_per_worker:
+      if context.num_gpus() < sum(
+          1 for d in d.extended.worker_devices if 'GPU' in d.upper()):
        return True

      if (not task_type or
@@ -778,9 +772,11 @@ class ParameterServerStrategyTest(
       combinations.combine(mode=['graph'], use_core_strategy=[True, False]))
   def testUpdateConfigProtoMultiWorker(self, use_core_strategy):
     strategy, _, _ = create_test_objects(
-        num_gpus=2, use_core_strategy=use_core_strategy)
-    strategy.configure(
-        cluster_spec=self._cluster_spec, task_type='worker', task_id=1)
+        cluster_spec=self._cluster_spec,
+        task_type='worker',
+        task_id=1,
+        num_gpus=2,
+        use_core_strategy=use_core_strategy)

     config_proto = config_pb2.ConfigProto(device_filters=['to_be_overridden'])
@@ -923,8 +919,8 @@ class ParameterServerStrategyWithChiefTest(ParameterServerStrategyTestBase,
       strategy.extended.call_for_each_replica(f)


-class LocalParameterServerStrategyTest(strategy_test_lib.DistributionTestBase,
-                                       parameterized.TestCase):
+class CentralStorageStrategyTest(strategy_test_lib.DistributionTestBase,
+                                 parameterized.TestCase):

   @combinations.generate(combinations.combine(mode=['graph', 'eager'],
                                               use_core_strategy=[True, False],

View File

@@ -276,6 +276,18 @@ py_library(
     ],
 )

+py_library(
+    name = "central_storage_strategy",
+    srcs = ["central_storage_strategy.py"],
+    visibility = ["//tensorflow:internal"],
+    deps = [
+        ":device_util",
+        ":distribute_lib",
+        ":parameter_server_strategy",
+        "//tensorflow/python:util",
+    ],
+)
+
 py_library(
     name = "one_device_strategy",
     srcs = ["one_device_strategy.py"],
@@ -531,11 +543,11 @@ py_library(
     srcs = ["strategy_combinations.py"],
     srcs_version = "PY2AND3",
     deps = [
+        ":central_storage_strategy",
         ":combinations",
         ":distribute_lib",
         ":mirrored_strategy",
         ":one_device_strategy",
-        ":parameter_server_strategy",
         ":tpu_strategy",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:training",

View File

@@ -0,0 +1,66 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Classes implementing a multi-worker ps DistributionStrategy."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import parameter_server_strategy
from tensorflow.python.util.tf_export import tf_export


@tf_export("distribute.experimental.CentralStorageStrategy", v1=[])
class CentralStorageStrategy(distribute_lib.Strategy):
  """A one-machine strategy that puts all variables on a single device.

  Variables are assigned to local CPU or the only GPU. If there is more
  than one GPU, compute operations (other than variable update operations)
  will be replicated across all GPUs.

  Args:
    compute_devices: an optional list of strings for device to replicate models
      on. If this is not provided, all local GPUs will be used; if there is no
      GPU, local CPU will be used.
    parameter_device: an optional device string for which device to put
      variables on. The default one is CPU or GPU if there is only one.
  """

  def __init__(self, compute_devices=None, parameter_device=None):
    extended = parameter_server_strategy.ParameterServerStrategyExtended(
        self,
        compute_devices=compute_devices,
        parameter_device=parameter_device)
    super(CentralStorageStrategy, self).__init__(extended)

  @classmethod
  def _from_num_gpus(cls, num_gpus):
    return cls(device_util.local_devices_from_num_gpus(num_gpus))


@tf_export(v1=["distribute.experimental.CentralStorageStrategy"])
class CentralStorageStrategyV1(distribute_lib.StrategyV1):

  __doc__ = CentralStorageStrategy.__doc__

  def __init__(self, compute_devices=None, parameter_device=None):
    """Initializes this strategy with default TFConfigClusterResolver."""
    super(CentralStorageStrategyV1, self).__init__(
        parameter_server_strategy.ParameterServerStrategyExtended(
            self,
            compute_devices=compute_devices,
            parameter_device=parameter_device))
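For orientation, here is a minimal usage sketch of the new strategy. It is not part of this change, and it assumes a TensorFlow build that already contains this commit; `tf.distribute.experimental.CentralStorageStrategy` is the name exported above.

import tensorflow as tf

# Variables are kept on a single parameter device (CPU, or the only GPU),
# while computation is replicated over the local GPUs.
strategy = tf.distribute.experimental.CentralStorageStrategy()

with strategy.scope():
  # The variable lives on the single parameter device.
  v = tf.Variable(0.0)

def replica_fn():
  # Each replica reads the shared variable and returns a value.
  return v + 1.0

# Replicated computation; results are returned per replica.
per_replica = strategy.experimental_run_v2(replica_fn)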

View File

@@ -108,3 +108,9 @@ def get_host_for_device(device):
   return tf_device.DeviceSpec(
       job=spec.job, replica=spec.replica, task=spec.task,
       device_type="CPU", device_index=0).to_string()
+
+
+def local_devices_from_num_gpus(num_gpus):
+  """Returns device strings for local GPUs or CPU."""
+  return (tuple("/device:GPU:%d" % i for i in range(num_gpus)) or
+          ("/device:CPU:0",))

View File

@@ -13,6 +13,7 @@ py_library(
     ],
     srcs_version = "PY2AND3",
     deps = [
+        "//tensorflow/python/distribute:central_storage_strategy",
         "//tensorflow/python/distribute:collective_all_reduce_strategy",
         "//tensorflow/python/distribute:parameter_server_strategy",
         "//tensorflow/python/distribute:tpu_strategy",

View File

@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function

 # pylint: disable=unused-import
+from tensorflow.python.distribute import central_storage_strategy
 from tensorflow.python.distribute import collective_all_reduce_strategy
 from tensorflow.python.distribute import parameter_server_strategy
 from tensorflow.python.distribute import tpu_strategy

View File

@@ -407,8 +407,7 @@ def _infer_num_gpus_per_worker(devices):
 def all_local_devices(num_gpus=None):
   if num_gpus is None:
     num_gpus = context.num_gpus()
-  return (tuple("/device:GPU:%d" % i for i in range(num_gpus)) or
-          ("/device:CPU:0",))
+  return device_util.local_devices_from_num_gpus(num_gpus)


 def _all_devices():

View File

@@ -43,32 +43,32 @@ from tensorflow.python.util import nest
 from tensorflow.python.util.tf_export import tf_export

 _LOCAL_CPU = "/device:CPU:0"
-_LOCAL_GPU_0 = "/device:GPU:0"


 # TODO(yuefengz): maybe cache variables on local CPU.
 @tf_export("distribute.experimental.ParameterServerStrategy", v1=[])
 class ParameterServerStrategy(distribute_lib.Strategy):
-  """A parameter server DistributionStrategy.
+  """An asynchronous multi-worker parameter server DistributionStrategy.

-  This strategy class works for both local training and between-graph replicated
-  training for multiple workers. It uses `TFConfigClusterResolver` to detect
-  configurations for multi-worker training. In multi-worker training mode, i.e.
-  `TFConfigClusterResolver` has detected 'TF_CONFIG' environment variable and
-  'TF_CONFIG' has a cluster spec, variables and updates to those variables are
-  assigned to parameter servers and other operations are assigned to workers.
-  In local training mode, variables are assigned to local CPU or the only GPU.
+  This strategy requires two jobs: workers and parameter servers. Variables and
+  updates to those variables will be assigned to parameter servers and other
+  operations are assigned to workers.
+
   When each worker has more than one GPU, operations will be replicated on these
-  GPUs. In both cases, operations are replicated but variables are not and these
-  workers share a common view for which parameter server a variable is assigned
+  GPUs. Even though operations may be replicated, variables are not and each
+  worker shares a common view for which parameter server a variable is assigned
   to.

-  This class assumes between-graph replication will be used and works on a graph
-  for a particular worker. Note that each graph and worker is independent.
-  This means that while each worker will synchronously compute a single gradient
-  update across all GPUs, updates between workers proceed asynchronously.
-  Operations that occur only on the first replica (such as incrementing the
-  global step), will occur on the first replica *of every worker*.
+  By default it uses `TFConfigClusterResolver` to detect configurations for
+  multi-worker training. This requires a 'TF_CONFIG' environment variable and
+  the 'TF_CONFIG' must have a cluster spec.
+
+  This class assumes each worker is running the same code independently, but
+  parameter servers are running a standard server. This means that while each
+  worker will synchronously compute a single gradient update across all GPUs,
+  updates between workers proceed asynchronously. Operations that occur only on
+  the first replica (such as incrementing the global step), will occur on the
+  first replica *of every worker*.

   It is expected to call `call_for_each_replica(fn, ...)` for any
   operations which potentially can be replicated across replicas (i.e. multiple
@@ -86,10 +86,21 @@ class ParameterServerStrategy(distribute_lib.Strategy):
   possibly create conflicts of device assignment.
   """

-  def __init__(self):
-    """Initializes this strategy with default TFConfigClusterResolver."""
-    super(ParameterServerStrategy, self).__init__(
-        ParameterServerStrategyExtended(self))
+  def __init__(self, cluster_resolver=None):
+    """Initializes this strategy.
+
+    Args:
+      cluster_resolver: Optional
+        `tf.distribute.cluster_resolver.ClusterResolver` object. Defaults to a
+        `tf.distribute.cluster_resolver.TFConfigClusterResolver`.
+    """
+    if cluster_resolver is None:
+      cluster_resolver = TFConfigClusterResolver()
+    if not cluster_resolver.cluster_spec():
+      raise ValueError("Cluster spec must be non-empty in `cluster_resolver`.")
+    extended = ParameterServerStrategyExtended(
+        self, cluster_resolver=cluster_resolver)
+    super(ParameterServerStrategy, self).__init__(extended)


 @tf_export(v1=["distribute.experimental.ParameterServerStrategy"])
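A hedged sketch of the new constructor contract (not part of this change): the V2 strategy is now multi-worker only, so the resolved cluster spec must be non-empty. The TF_CONFIG value below is a placeholder; the worker and ps addresses are hypothetical.

import json
import os

import tensorflow as tf

# Hypothetical two-worker / one-ps cluster; the addresses are placeholders.
os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {
        "worker": ["worker0.example.com:2222", "worker1.example.com:2222"],
        "ps": ["ps0.example.com:2222"],
    },
    "task": {"type": "worker", "index": 0},
})

# Uses TFConfigClusterResolver by default; raises ValueError when the resolver
# yields an empty cluster spec (local training now belongs to
# CentralStorageStrategy instead).
strategy = tf.distribute.experimental.ParameterServerStrategy()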
@@ -97,31 +108,41 @@ class ParameterServerStrategyV1(distribute_lib.StrategyV1):

   __doc__ = ParameterServerStrategy.__doc__

-  def __init__(self):
-    """Initializes this strategy with default TFConfigClusterResolver."""
+  def __init__(self, cluster_resolver=None):
+    """Initializes this strategy."""
     super(ParameterServerStrategyV1, self).__init__(
-        ParameterServerStrategyExtended(self))
+        ParameterServerStrategyExtended(
+            self, cluster_resolver=cluster_resolver))


 # TODO(josh11b): Switch to V2 when we no longer need to support tf.compat.v1.
 class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
-  """Implementation of ParameterServerStrategy."""
+  """Implementation of ParameterServerStrategy and CentralStorageStrategy."""

   def __init__(self,
                container_strategy,
-               cluster_resolver=TFConfigClusterResolver()):
+               cluster_resolver=None,
+               compute_devices=None,
+               parameter_device=None):
     super(ParameterServerStrategyExtended, self).__init__(container_strategy)
-    self._initialize_strategy(cluster_resolver)
+    self._initialize_strategy(
+        cluster_resolver=cluster_resolver,
+        compute_devices=compute_devices,
+        parameter_device=parameter_device)

     # We typically don't need to do all-reduce in this strategy.
     self._cross_device_ops = (
         cross_device_ops_lib.ReductionToOneDevice(reduce_to_device=_LOCAL_CPU))

-  def _initialize_strategy(self, cluster_resolver):
-    if cluster_resolver.cluster_spec().as_dict():
+  def _initialize_strategy(self,
+                           cluster_resolver=None,
+                           compute_devices=None,
+                           parameter_device=None):
+    if cluster_resolver and cluster_resolver.cluster_spec():
       self._initialize_multi_worker(cluster_resolver)
     else:
-      self._initialize_local(cluster_resolver)
+      self._initialize_local(
+          compute_devices, parameter_device, cluster_resolver=cluster_resolver)

   def _initialize_multi_worker(self, cluster_resolver):
     """Initialize devices for multiple workers.
@@ -214,43 +235,41 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
         num_ps_replicas, self._is_chief, self._device_map,
         self._variable_device)

-  def _initialize_local(self, cluster_resolver):
+  # TODO(yuefengz): get rid of cluster_resolver argument when contrib's
+  # version no longer depends on this class.
+  def _initialize_local(self,
+                        compute_devices,
+                        parameter_device,
+                        cluster_resolver=None):
     """Initialize internal devices for local training."""
     worker_device = device_util.canonicalize("/device:CPU:0")
     self._input_host_device = numpy_dataset.SingleDevice(worker_device)

-    # TODO(b/126786766): TFConfigClusterResolver returns wrong number of GPUs in
-    # some cases.
-    if isinstance(cluster_resolver, TFConfigClusterResolver):
-      num_gpus = context.num_gpus()
-    else:
-      num_gpus = cluster_resolver.num_accelerators().get("GPU", 0)
+    if compute_devices is None:
+      if not cluster_resolver:
+        num_gpus = context.num_gpus()
+      else:
+        num_gpus = cluster_resolver.num_accelerators().get("GPU", 0)
+      # Save the num_gpus_per_worker for configure method which is used by the
+      # contrib version.
+      self._num_gpus_per_worker = num_gpus

-    # Save the num_gpus_per_worker for configure method.
-    self._num_gpus_per_worker = num_gpus
+      compute_devices = device_util.local_devices_from_num_gpus(num_gpus)

-    # Define compute devices which is a list of device strings and one for each
-    # replica. When there are GPUs, replicate operations on these GPUs.
-    # Otherwise, place operations on CPU.
-    if num_gpus > 0:
-      compute_devices = tuple(map("/device:GPU:{}".format, range(num_gpus)))
-    else:
-      compute_devices = (_LOCAL_CPU,)
+    if parameter_device is None:
+      # If there is only one GPU, put everything on that GPU. Otherwise, place
+      # variables on CPU.
+      if len(compute_devices) == 1:
+        parameter_device = compute_devices[0]
+      else:
+        parameter_device = _LOCAL_CPU

     self._device_map = values.ReplicaDeviceMap(compute_devices)
     self._input_workers = input_lib.InputWorkers(
         self._device_map, [(worker_device, compute_devices)])

-    # If there is only one GPU, put everything on that GPU. Otherwise, place
-    # variables on CPU.
-    if num_gpus == 1:
-      assert len(compute_devices) == 1
-      self._variable_device = _LOCAL_GPU_0
-      self._parameter_devices = (_LOCAL_GPU_0,)
-    else:
-      self._variable_device = _LOCAL_CPU
-      self._parameter_devices = (_LOCAL_CPU,)
+    self._variable_device = parameter_device
+    self._parameter_devices = (parameter_device,)
     self._is_chief = True
     self._cluster_spec = None
     self._task_type = None
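Restated outside the class for clarity (a standalone sketch, not the TensorFlow code itself), the new local placement rule is: with exactly one compute device, variables share that device; otherwise they fall back to the local CPU.

def _pick_parameter_device(compute_devices):
  # Mirrors the rule applied above when parameter_device is not given.
  if len(compute_devices) == 1:
    return compute_devices[0]
  return "/device:CPU:0"

assert _pick_parameter_device(("/device:GPU:0",)) == "/device:GPU:0"
assert _pick_parameter_device(
    ("/device:GPU:0", "/device:GPU:1")) == "/device:CPU:0"
assert _pick_parameter_device(("/device:CPU:0",)) == "/device:CPU:0"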

View File

@@ -17,6 +17,7 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
+from tensorflow.python.distribute import central_storage_strategy
 from tensorflow.python.distribute import combinations
 from tensorflow.python.distribute import distribution_strategy_context
 from tensorflow.python.distribute import mirrored_strategy as mirrored_lib
@@ -100,6 +101,10 @@ mirrored_strategy_with_two_gpus = combinations.NamedDistribution(
     "Mirrored2GPUs",
     lambda: mirrored_lib.MirroredStrategy(["/gpu:0", "/gpu:1"]),
     required_gpus=2)
+central_storage_strategy_with_two_gpus = combinations.NamedDistribution(
+    "CentralStorage2GPUs",
+    lambda: central_storage_strategy.CentralStorageStrategy._from_num_gpus(2),  # pylint: disable=protected-access
+    required_gpus=2)

 gradient_descent_optimizer_v1_fn = combinations.NamedObject(
     "GradientDescentV1", lambda: gradient_descent.GradientDescentOptimizer(0.2))

View File

@@ -23,7 +23,6 @@ from absl.testing import parameterized
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.distribute import combinations
 from tensorflow.python.distribute import device_util
-from tensorflow.python.distribute import parameter_server_strategy
 from tensorflow.python.distribute import strategy_combinations
 from tensorflow.python.distribute import values
 from tensorflow.python.eager import context
@@ -40,16 +39,6 @@ from tensorflow.python.training import saver as saver_lib
 from tensorflow.python.training.tracking import util as trackable_utils


-# TODO(rchao): Merge parameter_server_strategy_with_two_gpus into
-# third_party/tensorflow/python/distribute/strategy_combinations.py
-# pylint: disable=g-long-lambda
-parameter_server_strategy_with_two_gpus = combinations.NamedDistribution(
-    "ParameterServer2GPUs",
-    lambda: parameter_server_strategy.ParameterServerStrategy(
-        num_gpus_per_worker=2),
-    required_gpus=2)
-
-
 class DistributedValuesTest(test.TestCase):

   def testGetEager(self):
@@ -561,7 +550,9 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
   @combinations.generate(
       combinations.combine(
-          distribution=[parameter_server_strategy_with_two_gpus],
+          distribution=[
+              strategy_combinations.central_storage_strategy_with_two_gpus
+          ],
           mode=["graph", "eager"]))
   def testAssignOutOfScope_aggregating(self, distribution):
     with distribution.scope():
@@ -577,7 +568,7 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
           strategy_combinations.mirrored_strategy_with_one_cpu,
           strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
           strategy_combinations.tpu_strategy,
-          parameter_server_strategy_with_two_gpus,
+          strategy_combinations.central_storage_strategy_with_two_gpus,
          ],
          mode=["graph", "eager"]))
   def testExtendsVariable(self, distribution):
@@ -591,7 +582,7 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
           strategy_combinations.mirrored_strategy_with_one_cpu,
           strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
           strategy_combinations.tpu_strategy,
-          parameter_server_strategy_with_two_gpus,
+          strategy_combinations.central_storage_strategy_with_two_gpus,
          ],
          mode=["graph", "eager"]))
   def testCheckpointing(self, distribution):

View File

@@ -318,10 +318,12 @@ py_library(
         "//tensorflow/python/distribute:combinations",
         "//tensorflow/python/distribute:distribute_config",
         "//tensorflow/python/distribute:distribute_coordinator",
+        "//tensorflow/python/distribute:distribute_lib",
         "//tensorflow/python/distribute:mirrored_strategy",
         "//tensorflow/python/distribute:multi_worker_test_base",
         "//tensorflow/python/distribute:parameter_server_strategy",
         "//tensorflow/python/distribute:strategy_combinations",
+        "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/keras",
     ],

View File

@@ -23,7 +23,7 @@ import numpy as np
 from tensorflow.python import keras
 from tensorflow.python.distribute import combinations
 from tensorflow.python.distribute import distribution_strategy_context as ds_context
-from tensorflow.python.distribute import parameter_server_strategy
+from tensorflow.python.distribute import strategy_combinations
 from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -36,15 +36,6 @@ from tensorflow.python.ops import variables
 from tensorflow.python.platform import test


-# TODO(rchao): Merge parameter_server_strategy_with_two_gpus into
-# third_party/tensorflow/python/distribute/strategy_combinations.py
-# pylint: disable=g-long-lambda
-parameter_server_strategy_with_two_gpus = combinations.NamedDistribution(
-    'ParameterServer2GPUs',
-    lambda: parameter_server_strategy.ParameterServerStrategy(),
-    required_gpus=2)
-
-
 def get_model():
   x = keras.layers.Input(shape=(3,), name='input')
   y = keras.layers.Dense(4, name='dense')(x)
@@ -57,7 +48,7 @@ class MirroredStrategyOptimizerV2Test(test.TestCase, parameterized.TestCase):
   @combinations.generate(
       combinations.combine(
           distribution=[
-              parameter_server_strategy_with_two_gpus,
+              strategy_combinations.central_storage_strategy_with_two_gpus,
          ],
          mode=['graph', 'eager']))
   def testKerasOptimizerWithUnequalInput(self, distribution):
@@ -114,7 +105,7 @@ class MirroredStrategyOptimizerV2Test(test.TestCase, parameterized.TestCase):
   @combinations.generate(
       combinations.combine(
           distribution=[
-              parameter_server_strategy_with_two_gpus,
+              strategy_combinations.central_storage_strategy_with_two_gpus,
          ],
          mode=['graph', 'eager']))
   def testOptimizerWithKerasModelAndNumpyArrays(self, distribution):

View File

@@ -34,9 +34,11 @@ from tensorflow.python.distribute import collective_all_reduce_strategy as colle
 from tensorflow.python.distribute import combinations
 from tensorflow.python.distribute import distribute_coordinator as dc
 from tensorflow.python.distribute import distribute_coordinator_context as dc_context
+from tensorflow.python.distribute import distribute_lib
 from tensorflow.python.distribute import mirrored_strategy
 from tensorflow.python.distribute import multi_worker_test_base as test_base
 from tensorflow.python.distribute import parameter_server_strategy
+from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.keras import backend
@@ -51,6 +53,23 @@ from tensorflow.python.platform import test
 from tensorflow.python.util import nest


+# TODO(b/130375202): remove this class which is a temporary solution before we
+# get rid of configure method.
+class ParameterServerStrategy(distribute_lib.Strategy):
+  """Temporarily mock the original strategy to bypass cluster_spec check."""
+
+  def __init__(self, cluster_resolver=None):
+    """Initializes this strategy."""
+    # The `cluster_resolver` must be set so that
+    # `ParameterServerStrategyExtended` will keep num_gpus for `configure`
+    # method.
+    if cluster_resolver is None:
+      cluster_resolver = TFConfigClusterResolver()
+    extended = parameter_server_strategy.ParameterServerStrategyExtended(
+        self, cluster_resolver=cluster_resolver)
+    super(ParameterServerStrategy, self).__init__(extended)
+
+
 def _mnist_synthetic_dataset(batch_size, steps_per_epoch):
   # train dataset
   x_train = array_ops.ones([batch_size * steps_per_epoch, 28, 28, 1],
@@ -301,7 +320,7 @@ class KerasMultiWorkerTestStandaloneClient(test.TestCase,
           mode=['graph'],
           strategy_cls=[
               mirrored_strategy.MirroredStrategy,
-              parameter_server_strategy.ParameterServerStrategy,
+              ParameterServerStrategy,
               collective_strategy.CollectiveAllReduceStrategy,
          ],
          required_gpus=[0, 1]))
@@ -383,7 +402,7 @@ class KerasMultiWorkerTestIndependentWorker(test_base.IndependentWorkerTestBase,
   @combinations.generate(
       combinations.combine(
           mode=['graph'],
-          strategy_cls=[parameter_server_strategy.ParameterServerStrategy],
+          strategy_cls=[ParameterServerStrategy],
          required_gpus=[0, 1]))
   def testSimpleModelIndependentWorkerAsync(self, strategy_cls):
     num_workers = 2

View File

@@ -0,0 +1,67 @@
path: "tensorflow.distribute.experimental.CentralStorageStrategy"
tf_class {
is_instance: "<class \'tensorflow.python.distribute.central_storage_strategy.CentralStorageStrategyV1\'>"
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.StrategyV1\'>"
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.Strategy\'>"
is_instance: "<type \'object\'>"
member {
name: "extended"
mtype: "<type \'property\'>"
}
member {
name: "num_replicas_in_sync"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'compute_devices\', \'parameter_device\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "colocate_vars_with"
argspec: "args=[\'self\', \'colocate_with_variable\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "configure"
argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "experimental_local_results"
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "experimental_run"
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "experimental_run_v2"
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
}
member_method {
name: "group"
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "make_dataset_iterator"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "make_input_fn_iterator"
argspec: "args=[\'self\', \'input_fn\', \'replication_mode\'], varargs=None, keywords=None, defaults=[\'InputReplicationMode.PER_WORKER\'], "
}
member_method {
name: "reduce"
argspec: "args=[\'self\', \'reduce_op\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "scope"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "unwrap"
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "update_config_proto"
argspec: "args=[\'self\', \'config_proto\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@@ -14,7 +14,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'cluster_resolver\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "colocate_vars_with"

View File

@@ -1,5 +1,9 @@
 path: "tensorflow.distribute.experimental"
 tf_module {
+  member {
+    name: "CentralStorageStrategy"
+    mtype: "<type \'type\'>"
+  }
   member {
     name: "CollectiveCommunication"
     mtype: "<class \'enum.EnumMeta\'>"

View File

@@ -0,0 +1,66 @@
path: "tensorflow.distribute.experimental.CentralStorageStrategy"
tf_class {
is_instance: "<class \'tensorflow.python.distribute.central_storage_strategy.CentralStorageStrategy\'>"
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.Strategy\'>"
is_instance: "<type \'object\'>"
member {
name: "extended"
mtype: "<type \'property\'>"
}
member {
name: "num_replicas_in_sync"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'compute_devices\', \'parameter_device\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "colocate_vars_with"
argspec: "args=[\'self\', \'colocate_with_variable\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "configure"
argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "experimental_local_results"
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "experimental_run"
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "experimental_run_v2"
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
}
member_method {
name: "group"
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "make_dataset_iterator"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "make_input_fn_iterator"
argspec: "args=[\'self\', \'input_fn\', \'replication_mode\'], varargs=None, keywords=None, defaults=[\'InputReplicationMode.PER_WORKER\'], "
}
member_method {
name: "reduce"
argspec: "args=[\'self\', \'reduce_op\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "scope"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "unwrap"
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "update_config_proto"
argspec: "args=[\'self\', \'config_proto\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@@ -13,7 +13,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'cluster_resolver\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "colocate_vars_with"

View File

@@ -1,5 +1,9 @@
 path: "tensorflow.distribute.experimental"
 tf_module {
+  member {
+    name: "CentralStorageStrategy"
+    mtype: "<type \'type\'>"
+  }
   member {
     name: "CollectiveCommunication"
     mtype: "<class \'enum.EnumMeta\'>"