PSv2: API update: Split `ParameterServerClient` in user code into separate `Client` and `ParameterServerStrategy` objects, since we expect users to invoke methods on both objects, per discussions. The implication is that we now expose both the `Client` and `ParameterServerStrategy` APIs instead of only `ParameterServerClient`. Concretely:

[Before change]
```
client = tf.distribute.ParameterServerClient(cluster_resolver=...)
distributed_dataset = client.create_per_worker_dataset(dataset_fn=...)
distributed_iterator = iter(distributed_dataset)  # per-worker iterator used below
with client.strategy.scope():
  model, optimizer, metrics = ...

@tf.function
def worker_fn(iterator):
  def train_fn(iterator):
    # Grab a batch, compute gradients, apply gradients, update metrics, etc.
  client.strategy.run(train_fn, args=(iterator,))

for epoch in range(num_epoch):
  for step in range(steps_per_epoch):
    client.schedule(worker_fn, args=(distributed_iterator,))
  client.join()
```

[After change]
```
strategy = tf.distribute.experimental.ParameterServerStrategy(cluster_resolver=...)
client = tf.distribute.Client(strategy=strategy)
distributed_dataset = client.create_per_worker_dataset(dataset_fn=...)
distributed_iterator = iter(distributed_dataset)  # per-worker iterator used below
with strategy.scope():
  model, optimizer, metrics = ...

@tf.function
def worker_fn(iterator):
  def train_fn(iterator):
    # Grab a batch, compute gradients, apply gradients, update metrics, etc.
  strategy.run(train_fn, args=(iterator,))

for epoch in range(num_epoch):
  for step in range(steps_per_epoch):
    client.schedule(worker_fn, args=(distributed_iterator,))
  client.join()
```
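
As the tests below show, `client.schedule` returns a `RemoteValue` whose result can be retrieved with `fetch()`. A minimal sketch of that pattern (assuming `worker_fn` is written to return a value):
```
# Sketch only: fetch() blocks until the scheduled function has finished
# executing on a worker, then returns its result.
result = client.schedule(worker_fn, args=(distributed_iterator,))
client.join()           # wait for all in-flight scheduled functions
value = result.fetch()
```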

PiperOrigin-RevId: 332065946
Change-Id: Ibce265e29d585fda8829f9c33409667992231480
Author: Rick Chao (2020-09-16 12:43:43 -07:00), committed by TensorFlower Gardener
parent 8aefe626a9
commit 96a993a650
10 changed files with 50 additions and 107 deletions

View File

@@ -147,7 +147,7 @@ py_library(
":mirrored_strategy",
":one_device_strategy",
":sharded_variable",
"//tensorflow/python/distribute/client:parameter_server_client",
"//tensorflow/python/distribute/client",
"//tensorflow/python/distribute/experimental",
],
)

View File

@@ -7,23 +7,13 @@ package(
exports_files(["LICENSE"])
py_library(
name = "parameter_server_client",
srcs = ["parameter_server_client.py"],
srcs_version = "PY2AND3",
deps = [
":client",
":utils",
"//tensorflow/python/distribute:parameter_server_strategy_v2",
],
)
py_library(
name = "client",
srcs = ["client.py"],
srcs_version = "PY2AND3",
deps = [
":metric_utils",
":utils",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
"//tensorflow/python:func_graph",
@@ -68,7 +58,6 @@ tf_py_test(
tags = ["no_oss"], # TODO(b/162119374)
deps = [
":client",
":parameter_server_client",
"//tensorflow/python:check_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
@@ -94,7 +83,6 @@ tf_py_test(
shard_count = 2,
tags = ["no_oss"], # TODO(b/162119374)
deps = [
":parameter_server_client",
":remote_eager_lib",
":utils",
"//tensorflow/python:dtypes",
@@ -102,7 +90,9 @@ tf_py_test(
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/distribute:multi_process_runner",
"//tensorflow/python/distribute:multi_worker_test_base",
"//tensorflow/python/distribute:parameter_server_strategy_v2",
"//tensorflow/python/distribute:sharded_variable",
"//tensorflow/python/distribute/client",
"//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
"//tensorflow/python/eager:def_function",
"//tensorflow/python/eager:test",

View File

@@ -384,7 +384,7 @@ class _CoordinatedClosureQueue(object):
if _CLOSURE_QUEUE_MAX_SIZE <= 0:
logging.warning(
"In ParameterServerClient, creating an infinite closure queue can "
"In a `Client`, creating an infinite closure queue can "
"consume a significant amount of memory and even lead to OOM.")
self._queue = queue.Queue(maxsize=_CLOSURE_QUEUE_MAX_SIZE)
self._error = None
@@ -846,9 +846,7 @@ class Client(object):
functions to be executed, and fetch the results of the functions.
Currently, `Client` is not supported to be used in a standalone manner.
It should be used in conjunction with `ParameterServerStrategyV2`. The
recommended way of using the combination is through a `ParameterServerClient`
object. Please see `ParameterServerClient` for more information.
It should be used in conjunction with `ParameterServerStrategyV2`.
This is currently under development, and the API as well as implementation
is subject to changes.

View File

@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Multi-process runner tests for parameter_server_client.py."""
"""Multi-process runner tests for `Client` with `ParameterServerStrategyV2`."""
from __future__ import absolute_import
from __future__ import division
@@ -23,8 +23,8 @@ from absl import logging
from tensorflow.python.compat import v2_compat
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute.client import client
from tensorflow.python.distribute.client import parameter_server_client
from tensorflow.python.distribute import parameter_server_strategy_v2
from tensorflow.python.distribute.client import client as client_lib
from tensorflow.python.distribute.client import utils
from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
from tensorflow.python.eager import def_function
@@ -51,9 +51,11 @@ class ClientMprTest(test.TestCase):
cluster_resolver = TFConfigClusterResolver()
if cluster_resolver.task_type != "chief":
utils.start_server(cluster_resolver, "grpc")
ps_client = parameter_server_client.ParameterServerClient(
strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
cluster_resolver)
with ps_client._strategy.scope():
ps_client = client_lib.Client(strategy)
with strategy.scope():
v = variables.Variable(initial_value=0, dtype=dtypes.int32)
@def_function.function
@@ -77,7 +79,7 @@
while ps_client.cluster._closure_queue._error is None:
time.sleep(1)
ps_client.schedule(worker_fn)
except client.ParameterServerFailureError:
except client_lib.ParameterServerFailureError:
# The following verifies that after PS fails, continue executing
# functions on workers should fail and indicate it's PS failure.
for worker_id in range(3):
@@ -87,7 +89,7 @@
# failure.
worker_fn()
except Exception as e: # pylint: disable=broad-except
if client._is_ps_failure(e):
if client_lib._is_ps_failure(e):
if worker_id < 2:
continue
logging.info("_test_translate_ps_failure_error ends properly.")

View File

@@ -1,55 +0,0 @@
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Parameter server client module.
This is currently under development and the API is subject to change.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.distribute import parameter_server_strategy_v2
from tensorflow.python.distribute.client import client
class ParameterServerClient(client.Client):
"""A client that uses `ParameterServerStrategy` to distribute tasks.
Parameter server training refers to the distributed training architecture
that requires two jobs in the cluster: workers and parameter servers. The
variables and updates to those variables are assigned on the parameter
servers' tasks, and the actual computation intensive operations are assigned
on worker tasks. In TF2, parameter server training only starts up one
client process, to drive and coordinate the workers and parameter servers.
This is referred to as single-client architecture, as opposed to multi-client
approach which is seen more often in traditional TensorFlow distributed
training, including `tf.estimator.Estimator` and `tf.keras` with
`tf.distribute.experimental.MultiWorkerMirroredStrategy`.
`ParameterServerClient` is a `Client` that uses `ParameterServerStrategy` as
the underlying strategy to distribute, and is the starting point of parameter
server training/evaluation.
If 'TF_CONFIG' environment variable is used, provide a
`TFConfigClusterResolver` to detect configurations for multi-worker training.
"""
def __init__(self, cluster_resolver, variable_partitioner=None):
super(ParameterServerClient, self).__init__(
parameter_server_strategy_v2.ParameterServerStrategyV2(
cluster_resolver, variable_partitioner))
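
With `ParameterServerClient` deleted, the equivalent construction under the new API is the following sketch (mirroring the replacement made in `make_client` in the tests below):
```
# Sketch of the post-change equivalent of ParameterServerClient.
strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
    cluster_resolver, variable_partitioner)
client = client_lib.Client(strategy)
```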

View File

@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for parameter_server_client.py."""
"""Tests for `Client` when used together with `ParameterServerStrategyV2."""
from __future__ import absolute_import
from __future__ import division
@@ -26,8 +26,8 @@ from absl import logging
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute.client import client
from tensorflow.python.distribute.client import parameter_server_client
from tensorflow.python.distribute import parameter_server_strategy_v2
from tensorflow.python.distribute.client import client as client_lib
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
@@ -93,7 +93,9 @@ def make_client(num_workers, num_ps):
]
cluster_resolver = SimpleClusterResolver(
ClusterSpec(cluster_def), rpc_layer="grpc")
return parameter_server_client.ParameterServerClient(cluster_resolver)
strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
cluster_resolver)
return client_lib.Client(strategy)
class ParameterServerClientTest(TestCaseWithErrorReportingThread):
@@ -102,13 +104,14 @@ class ParameterServerClientTest(TestCaseWithErrorReportingThread):
def setUpClass(cls):
super(ParameterServerClientTest, cls).setUpClass()
cls.client = make_client(num_workers=3, num_ps=2)
cls.strategy = cls.client.strategy
def testBasic(self):
self.client._strategy.extended._variable_count = 0
with self.client.strategy.scope():
self.strategy.extended._variable_count = 0
with self.strategy.scope():
v1 = variables.Variable(initial_value=0.0)
v2 = variables.Variable(initial_value=1.0)
self.assertEqual(self.client._strategy.extended._variable_count, 2)
self.assertEqual(self.strategy.extended._variable_count, 2)
@def_function.function
def worker_fn():
@@ -139,7 +142,7 @@ class ParameterServerClientTest(TestCaseWithErrorReportingThread):
def input_fn():
return dataset_ops.DatasetV2.range(1, 2)
with self.client.strategy.scope():
with self.strategy.scope():
v = variables.Variable(initial_value=0, dtype=dtypes.int64)
@def_function.function
@@ -163,7 +166,7 @@ class ParameterServerClientTest(TestCaseWithErrorReportingThread):
def input_fn():
return dataset_ops.DatasetV2.from_tensor_slices([2] * 10)
with self.client.strategy.scope():
with self.strategy.scope():
v = variables.Variable(initial_value=0, dtype=dtypes.int32)
# TODO(yuefengz): the following tf.function has a return value which is None
@@ -268,7 +271,7 @@ class ParameterServerClientTest(TestCaseWithErrorReportingThread):
class LimitedClosureQueueSizeBasicTest(ParameterServerClientTest):
"""Test basic functionality works with explicit maximum closure queue size.
Execute the same set of test cases as in ParameterServerClientTest, with an
Execute the same set of test cases as in `ParameterServerClientTest`, with an
explicit size limit for the closure queue. Note that even when the queue size
is set to infinite, there is still a maximum practical size (depends on host
memory limit) that might cause the queue.put operations to be blocking when
@@ -279,8 +282,9 @@ class LimitedClosureQueueSizeBasicTest(ParameterServerClientTest):
@classmethod
def setUpClass(cls):
super(LimitedClosureQueueSizeBasicTest, cls).setUpClass()
client._CLOSURE_QUEUE_MAX_SIZE = 2
client_lib._CLOSURE_QUEUE_MAX_SIZE = 2
cls.client = make_client(num_workers=3, num_ps=2)
cls.strategy = cls.client.strategy
class ErrorReportingTest(TestCaseWithErrorReportingThread):
@@ -289,8 +293,9 @@ class ErrorReportingTest(TestCaseWithErrorReportingThread):
def setUpClass(cls):
super(ErrorReportingTest, cls).setUpClass()
cls.client = make_client(num_workers=3, num_ps=2)
cls.strategy = cls.client.strategy
with cls.client.strategy.scope():
with cls.strategy.scope():
cls.iteration = variables.Variable(initial_value=0.0)
@def_function.function
@@ -373,10 +378,10 @@ class ErrorReportingTest(TestCaseWithErrorReportingThread):
self.client.join()
result = self.client.schedule(func, args=(aborted,))
with self.assertRaises(client.InputError):
with self.assertRaises(client_lib.InputError):
result.fetch()
with self.assertRaises(client.InputError):
with self.assertRaises(client_lib.InputError):
self.client.join()
def testCancellation(self):
@@ -388,7 +393,7 @@ class ErrorReportingTest(TestCaseWithErrorReportingThread):
with self.assertRaises(errors.InvalidArgumentError):
self.client.join()
with self.assertRaises(client.FunctionRetryableError):
with self.assertRaises(client_lib.FunctionRetryableError):
long_function.fetch()
for _ in range(3):
@@ -406,8 +411,9 @@ class LimitedClosureQueueErrorTest(ErrorReportingTest):
@classmethod
def setUpClass(cls):
super(LimitedClosureQueueErrorTest, cls).setUpClass()
client._CLOSURE_QUEUE_MAX_SIZE = 2
client_lib._CLOSURE_QUEUE_MAX_SIZE = 2
cls.client = make_client(num_workers=3, num_ps=2)
cls.strategy = cls.client.strategy
with cls.client.strategy.scope():
cls.iteration = variables.Variable(initial_value=0.0)
@@ -419,10 +425,11 @@ class StrategyRunTest(test.TestCase):
def setUpClass(cls):
super(StrategyRunTest, cls).setUpClass()
cls.client = make_client(num_workers=1, num_ps=1)
cls.strategy = cls.client.strategy
def testStrategyRun(self):
self.assertFalse(distribution_strategy_context.in_cross_replica_context())
with self.client._strategy.scope():
with self.strategy.scope():
self.assertTrue(distribution_strategy_context.in_cross_replica_context())
v = variables.Variable(initial_value=1)
@@ -435,11 +442,11 @@ class StrategyRunTest(test.TestCase):
distribution_strategy_context.in_cross_replica_context())
return input_tensor + v
return self.client._strategy.run(replica_fn, args=(input_tensor,))
return self.strategy.run(replica_fn, args=(input_tensor,))
# Asserting scheduling in scope has the expected behavior.
result = self.client.schedule(worker_fn, args=(constant_op.constant(3),))
self.assertIsInstance(result, client.RemoteValue)
self.assertIsInstance(result, client_lib.RemoteValue)
self.assertEqual(result.fetch(), 4)
# Asserting scheduling out of scope has the expected behavior.

View File

@@ -42,10 +42,8 @@ class ParameterServerStrategyV2(distribute_lib.Strategy):
"""An asynchronous multi-worker parameter server tf.distribute strategy.
Currently, `ParameterServerStrategyV2` is not supported to be used as a
standalone tf.distribute strategy. It must be used in conjunction with
`Client`. The recommended way of using the combination is through a
`ParameterServerClient` object. Please see `Client` and
`ParameterServerClient` for more information.
standalone tf.distribute strategy. It should be used in conjunction with
`Client`. Please see `Client` for more information.
This is currently under development, and the API as well as implementation
is subject to changes.

View File

@@ -854,8 +854,9 @@ py_test(
"//tensorflow/python/compat:v2_compat",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/distribute:multi_worker_test_base",
"//tensorflow/python/distribute:parameter_server_strategy_v2",
"//tensorflow/python/distribute:sharded_variable",
"//tensorflow/python/distribute/client:parameter_server_client",
"//tensorflow/python/distribute/client",
"//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
"//tensorflow/python/eager:backprop",
"//tensorflow/python/eager:def_function",

View File

@@ -26,7 +26,8 @@ from tensorflow.python import keras
from tensorflow.python.compat import v2_compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute.client import parameter_server_client
from tensorflow.python.distribute import parameter_server_strategy_v2
from tensorflow.python.distribute.client import client as client_lib
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function
@@ -51,7 +52,8 @@ def make_client(num_workers, num_ps):
]
cluster_resolver = SimpleClusterResolver(
ClusterSpec(cluster_def), rpc_layer="grpc")
return parameter_server_client.ParameterServerClient(cluster_resolver)
return client_lib.Client(
parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver))
class KPLTest(test.TestCase):

View File

@@ -149,8 +149,8 @@ COMMON_PIP_DEPS = [
"//tensorflow/tools/common:public_api",
"//tensorflow/tools/common:test_module1",
"//tensorflow/tools/common:traverse",
"//tensorflow/python/distribute:parameter_server_strategy_v2",
"//tensorflow/python/distribute/client:client",
"//tensorflow/python/distribute/client:parameter_server_client",
"//tensorflow/python/distribute/client:remote_eager_lib",
"//tensorflow/python/distribute/client:metric_utils",
]