PSv2: API update: Divide ParameterServerClient in user code to Client and ParameterServerStrategy objects, since we expect users to invoke methods on both objects per discussions. The implication is that we will be exposing both Client and ParameterServerStrategy APIs instead of only ParameterServerClient. Concretely,
[Before change] ``` client = tf.distribute.ParameterServerClient(cluster_resolver=...) distributed_dataset = client.create_per_worker_dataset(dataset_fn=...) with client.strategy.scope(): model, optimizer, metrics = ... @tf.function def worker_fn(iterator): def train_fn(iterator): # grab a batch, calculate gradient, applying gradient, metrics update etc. client.strategy.run(train_fn, args=(iterator,)) for epoch in range(num_epoch): for step in range(steps_per_epoch): client.schedule(worker_fn, args=(distributed_iterator,)) client.join() ``` [After change] ``` strategy = tf.distribute.experimental.ParameterServerStrategy(cluster_resolver=...) client = tf.distribute.Client(strategy=strategy) distributed_dataset = client.create_per_worker_dataset(dataset_fn=...) with strategy.scope(): model, optimizer, metrics = ... @tf.function def worker_fn(iterator): def train_fn(iterator): # grab a batch, calculate gradient, applying gradient, metrics update etc. strategy.run(train_fn, args=(iterator,)) for epoch in range(num_epoch): for step in range(steps_per_epoch): client.schedule(worker_fn, args=(distributed_iterator,)) client.join() ``` PiperOrigin-RevId: 332065946 Change-Id: Ibce265e29d585fda8829f9c33409667992231480
This commit is contained in:
parent
8aefe626a9
commit
96a993a650
@ -147,7 +147,7 @@ py_library(
|
||||
":mirrored_strategy",
|
||||
":one_device_strategy",
|
||||
":sharded_variable",
|
||||
"//tensorflow/python/distribute/client:parameter_server_client",
|
||||
"//tensorflow/python/distribute/client",
|
||||
"//tensorflow/python/distribute/experimental",
|
||||
],
|
||||
)
|
||||
|
@ -7,23 +7,13 @@ package(
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
py_library(
|
||||
name = "parameter_server_client",
|
||||
srcs = ["parameter_server_client.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":client",
|
||||
":utils",
|
||||
"//tensorflow/python/distribute:parameter_server_strategy_v2",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "client",
|
||||
srcs = ["client.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":metric_utils",
|
||||
":utils",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:func_graph",
|
||||
@ -68,7 +58,6 @@ tf_py_test(
|
||||
tags = ["no_oss"], # TODO(b/162119374)
|
||||
deps = [
|
||||
":client",
|
||||
":parameter_server_client",
|
||||
"//tensorflow/python:check_ops",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:dtypes",
|
||||
@ -94,7 +83,6 @@ tf_py_test(
|
||||
shard_count = 2,
|
||||
tags = ["no_oss"], # TODO(b/162119374)
|
||||
deps = [
|
||||
":parameter_server_client",
|
||||
":remote_eager_lib",
|
||||
":utils",
|
||||
"//tensorflow/python:dtypes",
|
||||
@ -102,7 +90,9 @@ tf_py_test(
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//tensorflow/python/distribute:multi_process_runner",
|
||||
"//tensorflow/python/distribute:multi_worker_test_base",
|
||||
"//tensorflow/python/distribute:parameter_server_strategy_v2",
|
||||
"//tensorflow/python/distribute:sharded_variable",
|
||||
"//tensorflow/python/distribute/client",
|
||||
"//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
|
||||
"//tensorflow/python/eager:def_function",
|
||||
"//tensorflow/python/eager:test",
|
||||
|
@ -384,7 +384,7 @@ class _CoordinatedClosureQueue(object):
|
||||
|
||||
if _CLOSURE_QUEUE_MAX_SIZE <= 0:
|
||||
logging.warning(
|
||||
"In ParameterServerClient, creating an infinite closure queue can "
|
||||
"In a `Client`, creating an infinite closure queue can "
|
||||
"consume a significant amount of memory and even lead to OOM.")
|
||||
self._queue = queue.Queue(maxsize=_CLOSURE_QUEUE_MAX_SIZE)
|
||||
self._error = None
|
||||
@ -846,9 +846,7 @@ class Client(object):
|
||||
functions to be executed, and fetch the results of the functions.
|
||||
|
||||
Currently, `Client` is not supported to be used in a standalone manner.
|
||||
It should be used in conjunction with `ParameterServerStrategyV2`. The
|
||||
recommended way of using the combination is through a `ParameterServerClient`
|
||||
object. Please see `ParameterServerClient` for more information.
|
||||
It should be used in conjunction with `ParameterServerStrategyV2`.
|
||||
|
||||
This is currently under development, and the API as well as implementation
|
||||
is subject to changes.
|
||||
|
@ -13,7 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Multi-process runner tests for parameter_server_client.py."""
|
||||
"""Multi-process runner tests for `Client` with `ParameterServerStrategyV2`."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
@ -23,8 +23,8 @@ from absl import logging
|
||||
from tensorflow.python.compat import v2_compat
|
||||
from tensorflow.python.distribute import multi_process_runner
|
||||
from tensorflow.python.distribute import multi_worker_test_base
|
||||
from tensorflow.python.distribute.client import client
|
||||
from tensorflow.python.distribute.client import parameter_server_client
|
||||
from tensorflow.python.distribute import parameter_server_strategy_v2
|
||||
from tensorflow.python.distribute.client import client as client_lib
|
||||
from tensorflow.python.distribute.client import utils
|
||||
from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
|
||||
from tensorflow.python.eager import def_function
|
||||
@ -51,9 +51,11 @@ class ClientMprTest(test.TestCase):
|
||||
cluster_resolver = TFConfigClusterResolver()
|
||||
if cluster_resolver.task_type != "chief":
|
||||
utils.start_server(cluster_resolver, "grpc")
|
||||
ps_client = parameter_server_client.ParameterServerClient(
|
||||
strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
|
||||
cluster_resolver)
|
||||
with ps_client._strategy.scope():
|
||||
ps_client = client_lib.Client(strategy)
|
||||
|
||||
with strategy.scope():
|
||||
v = variables.Variable(initial_value=0, dtype=dtypes.int32)
|
||||
|
||||
@def_function.function
|
||||
@ -77,7 +79,7 @@ class ClientMprTest(test.TestCase):
|
||||
while ps_client.cluster._closure_queue._error is None:
|
||||
time.sleep(1)
|
||||
ps_client.schedule(worker_fn)
|
||||
except client.ParameterServerFailureError:
|
||||
except client_lib.ParameterServerFailureError:
|
||||
# The following verifies that after PS fails, continue executing
|
||||
# functions on workers should fail and indicate it's PS failure.
|
||||
for worker_id in range(3):
|
||||
@ -87,7 +89,7 @@ class ClientMprTest(test.TestCase):
|
||||
# failure.
|
||||
worker_fn()
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
if client._is_ps_failure(e):
|
||||
if client_lib._is_ps_failure(e):
|
||||
if worker_id < 2:
|
||||
continue
|
||||
logging.info("_test_translate_ps_failure_error ends properly.")
|
||||
|
@ -1,55 +0,0 @@
|
||||
# Lint as: python3
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Parameter server client module.
|
||||
|
||||
This is currently under development and the API is subject to change.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.distribute import parameter_server_strategy_v2
|
||||
from tensorflow.python.distribute.client import client
|
||||
|
||||
|
||||
class ParameterServerClient(client.Client):
|
||||
"""A client that uses `ParameterServerStrategy` to distribute tasks.
|
||||
|
||||
Parameter server training refers to the distributed training architecture
|
||||
that requires two jobs in the cluster: workers and parameter servers. The
|
||||
variables and updates to those variables are assigned on the parameter
|
||||
servers' tasks, and the actual computation intensive operations are assigned
|
||||
on worker tasks. In TF2, parameter server training only starts up one
|
||||
client process, to drive and coordinate the workers and parameter servers.
|
||||
This is referred to as single-client architecture, as opposed to multi-client
|
||||
approach which is seen more often in traditional TensorFlow distributed
|
||||
training, including `tf.estimator.Estimator` and `tf.keras` with
|
||||
`tf.distribute.experimental.MultiWorkerMirroredStrategy`.
|
||||
|
||||
`ParameterServerClient` is a `Client` that uses `ParameterServerStrategy` as
|
||||
the underlying strategy to distribute, and is the starting point of parameter
|
||||
server training/evaluation.
|
||||
|
||||
If 'TF_CONFIG' environment variable is used, provide a
|
||||
`TFConfigClusterResolver` to detect configurations for multi-worker training.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, cluster_resolver, variable_partitioner=None):
|
||||
super(ParameterServerClient, self).__init__(
|
||||
parameter_server_strategy_v2.ParameterServerStrategyV2(
|
||||
cluster_resolver, variable_partitioner))
|
@ -13,7 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for parameter_server_client.py."""
|
||||
"""Tests for `Client` when used together with `ParameterServerStrategyV2."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
@ -26,8 +26,8 @@ from absl import logging
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.distribute import distribution_strategy_context
|
||||
from tensorflow.python.distribute import multi_worker_test_base
|
||||
from tensorflow.python.distribute.client import client
|
||||
from tensorflow.python.distribute.client import parameter_server_client
|
||||
from tensorflow.python.distribute import parameter_server_strategy_v2
|
||||
from tensorflow.python.distribute.client import client as client_lib
|
||||
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import test
|
||||
@ -93,7 +93,9 @@ def make_client(num_workers, num_ps):
|
||||
]
|
||||
cluster_resolver = SimpleClusterResolver(
|
||||
ClusterSpec(cluster_def), rpc_layer="grpc")
|
||||
return parameter_server_client.ParameterServerClient(cluster_resolver)
|
||||
strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
|
||||
cluster_resolver)
|
||||
return client_lib.Client(strategy)
|
||||
|
||||
|
||||
class ParameterServerClientTest(TestCaseWithErrorReportingThread):
|
||||
@ -102,13 +104,14 @@ class ParameterServerClientTest(TestCaseWithErrorReportingThread):
|
||||
def setUpClass(cls):
|
||||
super(ParameterServerClientTest, cls).setUpClass()
|
||||
cls.client = make_client(num_workers=3, num_ps=2)
|
||||
cls.strategy = cls.client.strategy
|
||||
|
||||
def testBasic(self):
|
||||
self.client._strategy.extended._variable_count = 0
|
||||
with self.client.strategy.scope():
|
||||
self.strategy.extended._variable_count = 0
|
||||
with self.strategy.scope():
|
||||
v1 = variables.Variable(initial_value=0.0)
|
||||
v2 = variables.Variable(initial_value=1.0)
|
||||
self.assertEqual(self.client._strategy.extended._variable_count, 2)
|
||||
self.assertEqual(self.strategy.extended._variable_count, 2)
|
||||
|
||||
@def_function.function
|
||||
def worker_fn():
|
||||
@ -139,7 +142,7 @@ class ParameterServerClientTest(TestCaseWithErrorReportingThread):
|
||||
def input_fn():
|
||||
return dataset_ops.DatasetV2.range(1, 2)
|
||||
|
||||
with self.client.strategy.scope():
|
||||
with self.strategy.scope():
|
||||
v = variables.Variable(initial_value=0, dtype=dtypes.int64)
|
||||
|
||||
@def_function.function
|
||||
@ -163,7 +166,7 @@ class ParameterServerClientTest(TestCaseWithErrorReportingThread):
|
||||
def input_fn():
|
||||
return dataset_ops.DatasetV2.from_tensor_slices([2] * 10)
|
||||
|
||||
with self.client.strategy.scope():
|
||||
with self.strategy.scope():
|
||||
v = variables.Variable(initial_value=0, dtype=dtypes.int32)
|
||||
|
||||
# TODO(yuefengz): the following tf.function has a return value which is None
|
||||
@ -268,7 +271,7 @@ class ParameterServerClientTest(TestCaseWithErrorReportingThread):
|
||||
class LimitedClosureQueueSizeBasicTest(ParameterServerClientTest):
|
||||
"""Test basic functionality works with explicit maximum closure queue size.
|
||||
|
||||
Execute the same set of test cases as in ParameterServerClientTest, with an
|
||||
Execute the same set of test cases as in `ParameterServerClientTest`, with an
|
||||
explicit size limit for the closure queue. Note that even when the queue size
|
||||
is set to infinite, there is still a maximum practical size (depends on host
|
||||
memory limit) that might cause the queue.put operations to be blocking when
|
||||
@ -279,8 +282,9 @@ class LimitedClosureQueueSizeBasicTest(ParameterServerClientTest):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super(LimitedClosureQueueSizeBasicTest, cls).setUpClass()
|
||||
client._CLOSURE_QUEUE_MAX_SIZE = 2
|
||||
client_lib._CLOSURE_QUEUE_MAX_SIZE = 2
|
||||
cls.client = make_client(num_workers=3, num_ps=2)
|
||||
cls.strategy = cls.client.strategy
|
||||
|
||||
|
||||
class ErrorReportingTest(TestCaseWithErrorReportingThread):
|
||||
@ -289,8 +293,9 @@ class ErrorReportingTest(TestCaseWithErrorReportingThread):
|
||||
def setUpClass(cls):
|
||||
super(ErrorReportingTest, cls).setUpClass()
|
||||
cls.client = make_client(num_workers=3, num_ps=2)
|
||||
cls.strategy = cls.client.strategy
|
||||
|
||||
with cls.client.strategy.scope():
|
||||
with cls.strategy.scope():
|
||||
cls.iteration = variables.Variable(initial_value=0.0)
|
||||
|
||||
@def_function.function
|
||||
@ -373,10 +378,10 @@ class ErrorReportingTest(TestCaseWithErrorReportingThread):
|
||||
self.client.join()
|
||||
|
||||
result = self.client.schedule(func, args=(aborted,))
|
||||
with self.assertRaises(client.InputError):
|
||||
with self.assertRaises(client_lib.InputError):
|
||||
result.fetch()
|
||||
|
||||
with self.assertRaises(client.InputError):
|
||||
with self.assertRaises(client_lib.InputError):
|
||||
self.client.join()
|
||||
|
||||
def testCancellation(self):
|
||||
@ -388,7 +393,7 @@ class ErrorReportingTest(TestCaseWithErrorReportingThread):
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
self.client.join()
|
||||
|
||||
with self.assertRaises(client.FunctionRetryableError):
|
||||
with self.assertRaises(client_lib.FunctionRetryableError):
|
||||
long_function.fetch()
|
||||
|
||||
for _ in range(3):
|
||||
@ -406,8 +411,9 @@ class LimitedClosureQueueErrorTest(ErrorReportingTest):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super(LimitedClosureQueueErrorTest, cls).setUpClass()
|
||||
client._CLOSURE_QUEUE_MAX_SIZE = 2
|
||||
client_lib._CLOSURE_QUEUE_MAX_SIZE = 2
|
||||
cls.client = make_client(num_workers=3, num_ps=2)
|
||||
cls.strategy = cls.client.strategy
|
||||
|
||||
with cls.client.strategy.scope():
|
||||
cls.iteration = variables.Variable(initial_value=0.0)
|
||||
@ -419,10 +425,11 @@ class StrategyRunTest(test.TestCase):
|
||||
def setUpClass(cls):
|
||||
super(StrategyRunTest, cls).setUpClass()
|
||||
cls.client = make_client(num_workers=1, num_ps=1)
|
||||
cls.strategy = cls.client.strategy
|
||||
|
||||
def testStrategyRun(self):
|
||||
self.assertFalse(distribution_strategy_context.in_cross_replica_context())
|
||||
with self.client._strategy.scope():
|
||||
with self.strategy.scope():
|
||||
self.assertTrue(distribution_strategy_context.in_cross_replica_context())
|
||||
v = variables.Variable(initial_value=1)
|
||||
|
||||
@ -435,11 +442,11 @@ class StrategyRunTest(test.TestCase):
|
||||
distribution_strategy_context.in_cross_replica_context())
|
||||
return input_tensor + v
|
||||
|
||||
return self.client._strategy.run(replica_fn, args=(input_tensor,))
|
||||
return self.strategy.run(replica_fn, args=(input_tensor,))
|
||||
|
||||
# Asserting scheduling in scope has the expected behavior.
|
||||
result = self.client.schedule(worker_fn, args=(constant_op.constant(3),))
|
||||
self.assertIsInstance(result, client.RemoteValue)
|
||||
self.assertIsInstance(result, client_lib.RemoteValue)
|
||||
self.assertEqual(result.fetch(), 4)
|
||||
|
||||
# Asserting scheduling out of scope has the expected behavior.
|
||||
|
@ -42,10 +42,8 @@ class ParameterServerStrategyV2(distribute_lib.Strategy):
|
||||
"""An asynchronous multi-worker parameter server tf.distribute strategy.
|
||||
|
||||
Currently, `ParameterServerStrategyV2` is not supported to be used as a
|
||||
standalone tf.distribute strategy. It must be used in conjunction with
|
||||
`Client`. The recommended way of using the combination is through a
|
||||
`ParameterServerClient` object. Please see `Client` and
|
||||
`ParameterServerClient` for more information.
|
||||
standalone tf.distribute strategy. It should be used in conjunction with
|
||||
`Client`. Please see `Client` for more information.
|
||||
|
||||
This is currently under development, and the API as well as implementation
|
||||
is subject to changes.
|
||||
|
@ -854,8 +854,9 @@ py_test(
|
||||
"//tensorflow/python/compat:v2_compat",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//tensorflow/python/distribute:multi_worker_test_base",
|
||||
"//tensorflow/python/distribute:parameter_server_strategy_v2",
|
||||
"//tensorflow/python/distribute:sharded_variable",
|
||||
"//tensorflow/python/distribute/client:parameter_server_client",
|
||||
"//tensorflow/python/distribute/client",
|
||||
"//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
|
||||
"//tensorflow/python/eager:backprop",
|
||||
"//tensorflow/python/eager:def_function",
|
||||
|
@ -26,7 +26,8 @@ from tensorflow.python import keras
|
||||
from tensorflow.python.compat import v2_compat
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.distribute import multi_worker_test_base
|
||||
from tensorflow.python.distribute.client import parameter_server_client
|
||||
from tensorflow.python.distribute import parameter_server_strategy_v2
|
||||
from tensorflow.python.distribute.client import client as client_lib
|
||||
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import def_function
|
||||
@ -51,7 +52,8 @@ def make_client(num_workers, num_ps):
|
||||
]
|
||||
cluster_resolver = SimpleClusterResolver(
|
||||
ClusterSpec(cluster_def), rpc_layer="grpc")
|
||||
return parameter_server_client.ParameterServerClient(cluster_resolver)
|
||||
return client_lib.Client(
|
||||
parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver))
|
||||
|
||||
|
||||
class KPLTest(test.TestCase):
|
||||
|
@ -149,8 +149,8 @@ COMMON_PIP_DEPS = [
|
||||
"//tensorflow/tools/common:public_api",
|
||||
"//tensorflow/tools/common:test_module1",
|
||||
"//tensorflow/tools/common:traverse",
|
||||
"//tensorflow/python/distribute:parameter_server_strategy_v2",
|
||||
"//tensorflow/python/distribute/client:client",
|
||||
"//tensorflow/python/distribute/client:parameter_server_client",
|
||||
"//tensorflow/python/distribute/client:remote_eager_lib",
|
||||
"//tensorflow/python/distribute/client:metric_utils",
|
||||
]
|
||||
|
Loading…
Reference in New Issue
Block a user