PSv2: Move TF2 parameter server training main library code into OSS.

PiperOrigin-RevId: 324103111
Change-Id: Ic6c9eb64e851873c1ee6e768f645b9367528470d
Rick Chao 2020-07-30 16:06:49 -07:00 committed by TensorFlower Gardener
parent b92c68721b
commit a4725be4a7
10 changed files with 2518 additions and 0 deletions


@@ -1760,3 +1760,21 @@ distribute_py_test(
"@absl_py//absl/testing:parameterized",
],
)
py_library(
name = "parameter_server_strategy_v2",
srcs = ["parameter_server_strategy_v2.py"],
srcs_version = "PY3",
deps = [
":parameter_server_strategy",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:util",
"//tensorflow/python:variables",
"//tensorflow/python/distribute:distribute_lib",
"//tensorflow/python/distribute:input_lib",
"//tensorflow/python/distribute:sharded_variable",
"//tensorflow/python/distribute:values",
],
)


@@ -0,0 +1,102 @@
package(
default_visibility = ["//tensorflow:internal"],
licenses = ["notice"], # Apache 2.0
)
exports_files(["LICENSE"])
py_library(
name = "parameter_server_client",
srcs = ["parameter_server_client.py"],
srcs_version = "PY3",
deps = [
":client",
"//tensorflow/python/distribute:parameter_server_strategy_v2",
],
)
py_library(
name = "client",
srcs = ["client.py"],
srcs_version = "PY3",
deps = [
":metric_utils",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
"//tensorflow/python:func_graph",
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python:training_server_lib",
"//tensorflow/python:util",
"//tensorflow/python/distribute:input_lib",
"//tensorflow/python/distribute:parameter_server_strategy_v2",
"//tensorflow/python/distribute:values",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:def_function",
"//tensorflow/python/eager:executor",
"//tensorflow/python/eager:function",
"//tensorflow/python/eager:remote",
"@absl_py//absl/logging",
"@six_archive//:six",
],
)
py_test(
name = "client_test",
size = "small",
srcs = ["client_test.py"],
python_version = "PY3",
shard_count = 12,
deps = [
":client",
"//tensorflow/python:client_testlib",
"//tensorflow/python:training_lib",
"//tensorflow/python:util",
"//tensorflow/python/eager:def_function",
"@absl_py//absl/logging",
],
)
py_test(
name = "parameter_server_client_test",
srcs = ["parameter_server_client_test.py"],
python_version = "PY3",
shard_count = 14,
tags = ["no_oss"], # TODO(b/162119374)
deps = [
":parameter_server_client",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:init_ops_v2",
"//tensorflow/python:training_server_lib",
"//tensorflow/python:variables",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/distribute:multi_worker_test_base",
"//tensorflow/python/distribute:sharded_variable",
"//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
"//tensorflow/python/eager:def_function",
"//tensorflow/python/eager:test",
],
)
py_library(
name = "metric_utils",
srcs = ["metric_utils.py"],
srcs_version = "PY3",
deps = [
"//tensorflow/python/eager:monitoring",
],
)
py_test(
name = "metric_utils_test",
srcs = ["metric_utils_test.py"],
python_version = "PY3",
deps = [
":client",
":metric_utils",
"//tensorflow/python:training_server_lib",
"//tensorflow/python/distribute:multi_worker_test_base",
"//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
"//tensorflow/python/eager:test",
],
)

File diff suppressed because it is too large


@@ -0,0 +1,388 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for client.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import threading
import time
from absl import logging
from tensorflow.python.distribute.client import client
from tensorflow.python.eager import def_function
from tensorflow.python.platform import test
from tensorflow.python.training import coordinator
from tensorflow.python.util import nest
class CoordinatedClosureQueueTest(test.TestCase):
def testBasic(self):
queue = client._CoordinatedClosureQueue()
closure1 = self._create_closure()
queue.put(closure1)
self.assertIs(closure1, queue.get())
self.assertFalse(queue.done())
queue.put_back(closure1)
self.assertEqual(closure1, queue.get())
queue.mark_finished(closure1)
self.assertTrue(queue.done())
queue.wait()
  def testProcessAtLeastOnce(self):
closure_queue = client._CoordinatedClosureQueue()
labels = ['A', 'B', 'C', 'D', 'E']
processed_count = collections.defaultdict(int)
coord = coordinator.Coordinator(clean_stop_exception_types=[])
def process_queue():
with coord.stop_on_exception():
has_been_put_back = False
while True:
closure = closure_queue.get(timeout=30)
if closure is None:
break
if not has_been_put_back:
has_been_put_back = True
closure_queue.put_back(closure)
continue
closure._function()
closure_queue.mark_finished(closure)
def get_func(label):
def func():
logging.info('Label: %s, before waiting 3 sec', label)
time.sleep(3)
processed_count[label] += 1
logging.info('Label: %s, after waiting 3 sec', label)
return func
for label in labels:
closure_queue.put(client.Closure(get_func(label)))
t1 = threading.Thread(target=process_queue, daemon=True)
t1.start()
t2 = threading.Thread(target=process_queue, daemon=True)
t2.start()
# Make sure multiple wait() calls are fine.
closure_queue.wait()
closure_queue.wait()
closure_queue.wait()
closure_queue.wait()
self.assertEqual(processed_count, collections.Counter(labels))
coord.join([t1, t2])
def testNotifyBeforeWait(self):
closure_queue = client._CoordinatedClosureQueue()
def func():
logging.info('func running')
coord = coordinator.Coordinator(clean_stop_exception_types=[])
def process_queue():
with coord.stop_on_exception():
closure = closure_queue.get()
closure_queue.mark_finished(closure)
closure_queue.put(client.Closure(func))
t = threading.Thread(target=process_queue)
t.start()
coord.join([t])
    # This asserts that a wait() call made after the function has been
    # processed returns without blocking.
closure_queue.wait()
def testWaitRaiseErrorAfterMarkFailure(self):
closure_queue = client._CoordinatedClosureQueue()
closure_queue.put(self._create_closure())
closure = closure_queue.get()
wait_finish_event = threading.Event()
coord = coordinator.Coordinator(clean_stop_exception_types=[])
# Using a thread to verify that closure_queue.wait() will not return until
# all inflight closures are finished.
    def mark_failed_fn():
with coord.stop_on_exception():
self.assertFalse(wait_finish_event.is_set())
try:
raise ValueError('Some error.')
except ValueError as e:
closure_queue.mark_failed(e, closure)
wait_finish_event.wait()
    t = threading.Thread(target=mark_failed_fn)
t.start()
with self.assertRaises(ValueError):
closure_queue.wait()
wait_finish_event.set()
coord.join([t])
self.assertTrue(closure_queue.done())
def _create_closure(self):
@def_function.function()
def some_function():
return 1.0
return client.Closure(some_function)
def _put_two_closures_and_get_one(self):
closure_queue = client._CoordinatedClosureQueue()
closure1 = self._create_closure()
closure_queue.put(closure1)
closure2 = self._create_closure()
closure_queue.put(closure2)
closure_got = closure_queue.get() # returns closure1
self.assertIs(closure_got, closure1)
self.assertIsNot(closure_got, closure2)
return closure_queue, closure1, closure2
def testPutRaiseError(self):
closure_queue, closure1, closure2 = self._put_two_closures_and_get_one()
closure_queue.mark_failed(ValueError(), closure1)
with self.assertRaises(ValueError):
closure_queue.put(self._create_closure())
self.assertTrue(closure_queue.done())
with self.assertRaisesRegex(
client.FunctionRetryableError,
'The corresponding function is cancelled. Please reschedule the '
'function.'):
closure2._fetch_output_remote_values()
# The error is cleared.
closure_queue.put(self._create_closure())
def testWaitRaiseError(self):
closure_queue, closure1, closure2 = self._put_two_closures_and_get_one()
closure_queue.mark_failed(ValueError(), closure1)
with self.assertRaises(ValueError):
closure_queue.wait()
self.assertTrue(closure_queue.done())
with self.assertRaisesRegex(
client.FunctionRetryableError,
'The corresponding function is cancelled. Please reschedule the '
'function.'):
closure2._fetch_output_remote_values()
# The error is cleared.
closure_queue.wait()
def testDoneRaiseError(self):
closure_queue, closure1, _ = self._put_two_closures_and_get_one()
closure_queue.get()
self.assertFalse(closure_queue.done())
closure_queue.mark_failed(ValueError(), closure1)
with self.assertRaises(ValueError):
closure_queue.done()
def _test_error_reporting_and_cancel_flow(self, call_wait):
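    # A summary of the protocol exercised below (as inferred from the
    # assertions in this test): when a closure fails, the queue records the
    # error, bumps its error generation, cancels closures still in the queue,
    # and rejects `put_back` of closures from the old generation; `wait` and
    # `put` then surface the recorded error once.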
closure_queue, closure1, closure2 = self._put_two_closures_and_get_one()
closure_queue.put(self._create_closure())
closure_queue.get()
    # At this moment, there are two inflight closures and one in the queue.
self.assertEqual(closure_queue._inflight_closure_count, 2)
# Simulating closure1 fails.
try:
raise ValueError('Some error.')
except ValueError as e:
nest.map_structure(lambda x: x._set_error(e),
closure1._output_remote_values)
self.assertEqual(closure_queue._error_generation, 0) # pylint: disable=g-assert-in-except
closure_queue.mark_failed(e, closure1)
self.assertEqual(closure_queue._error_generation, 1)
    # At this moment, there is one inflight closure and nothing in the queue
    # (the closures in the queue have been removed and cancelled).
self.assertTrue(closure_queue._queue.empty())
    # The inflight count doesn't include out-of-generation closures.
self.assertEqual(closure_queue._inflight_closure_count, 1)
coord = coordinator.Coordinator(clean_stop_exception_types=[])
closure3 = self._create_closure()
with self.assertRaises(ValueError):
# Verifying `wait()` or `put()` raises even if one closure is in
# flight.
if call_wait:
closure_queue.wait()
else:
closure_queue.put(closure3)
# At this moment, there is one inflight, nothing in queue.
self.assertTrue(closure_queue._queue.empty())
self.assertEqual(closure_queue._inflight_closure_count, 1)
# This asserts that closure1 has errored.
with self.assertRaisesRegex(ValueError, 'Some error.'):
closure1._fetch_output_remote_values()
# The following asserts that closure3 should have been cancelled.
if not call_wait:
with self.assertRaisesRegex(
client.FunctionRetryableError,
'The corresponding function is cancelled. Please reschedule the '
'function.'):
closure3._fetch_output_remote_values()
# Closure2 is inflight, so it shouldn't be ready.
self.assertEqual(closure2._output_remote_values._status,
client._RemoteValueStatus.NOT_READY)
# And `wait` should block because closure2 is not back yet.
self.assertFalse(closure_queue.wait(timeout=20))
    # Now assume closure2 fails due to worker preemption and is attempted to
    # be put back, but it ends up being cancelled.
self.assertEqual(closure2._error_generation, 0)
self.assertEqual(closure_queue._error_generation, 1)
closure_queue.put_back(closure2)
with self.assertRaisesRegex(
client.FunctionRetryableError,
'The corresponding function is cancelled. Please reschedule the '
'function.'):
closure2._fetch_output_remote_values()
# At this moment, there is nothing inflight, and the queue is also empty
# (because closure2 should not be added back to the queue).
self.assertTrue(closure_queue._queue.empty())
self.assertEqual(closure_queue._inflight_closure_count, 0)
closure4 = self._create_closure()
e = threading.Event()
def get_fn():
with coord.stop_on_exception():
        # This should end up getting closure4, not closure2, because closure2
        # has been cancelled and should not be handed out again.
closure_got = closure_queue.get()
e.set()
self.assertEqual(closure_got._error_generation, 1)
self.assertEqual(closure_queue._error_generation, 1)
self.assertIs(closure4, closure_got)
self.assertIsNot(closure2, closure_got)
t = threading.Thread(target=get_fn)
t.start()
time.sleep(10)
    # `closure_got = closure_queue.get()` in the thread above should still be
    # blocked at this point; it is unblocked by `closure_queue.put(closure4)`
    # below.
    self.assertFalse(e.is_set())
closure_queue.put(closure4)
self.assertTrue(e.wait())
coord.join([t])
self.assertEqual(closure_queue._inflight_closure_count, 1)
closure_queue.mark_finished(closure4)
# The queue is now cleared and nothing inflight.
self.assertEqual(closure_queue._inflight_closure_count, 0)
closure_queue.wait()
def testWaitRaiseErrorAfterAnErrorIsReported(self):
self._test_error_reporting_and_cancel_flow(call_wait=True)
def testPutRaiseErrorAfterAnErrorIsReported(self):
self._test_error_reporting_and_cancel_flow(call_wait=False)
def testStateIsRestoredAfterJoinIsCalled(self):
closure_queue, closure1, closure2 = self._put_two_closures_and_get_one()
closure_queue.get()
self.assertEqual(closure_queue._inflight_closure_count, 2)
closure_queue.mark_failed(ValueError('test error'), closure1)
with self.assertRaises(ValueError):
closure_queue.put(self._create_closure())
closure_queue.mark_failed(ValueError('test error'), closure2)
    # closure2's error is from a previous generation, so the following put
    # should not raise, and _error should have been cleared.
self.assertIsNone(closure_queue._error)
closure_queue.put(self._create_closure())
self.assertIsNone(closure_queue._error)
def testStateIsRestoredAfterJoinIsCalled_WaitShouldReturn(self):
closure_queue, closure1, closure2 = self._put_two_closures_and_get_one()
closure_queue.put(self._create_closure())
closure_queue.get() # got closure2
self.assertFalse(closure_queue._queue.empty()) # still has closure3
self.assertEqual(closure_queue._inflight_closure_count, 2) # closure1,2
closure_queue.mark_failed(ValueError('test error'), closure1)
self.assertTrue(closure_queue._queue.empty()) # closure3 cancelled
self.assertEqual(closure_queue._inflight_closure_count, 1)
with self.assertRaises(ValueError):
closure_queue.wait() # reports error from closure1
# `wait` should block because closure2 is not back yet, even if closure2
# was sent inflight before the error.
self.assertFalse(closure_queue.wait(timeout=20))
self.assertEqual(closure_queue._inflight_closure_count, 1)
closure_queue.mark_finished(closure2)
closure_queue.wait() # wait should pass immediately
self.assertEqual(closure_queue._inflight_closure_count, 0)
  def testThreadSafety(self):
thread_count = 10
queue = client._CoordinatedClosureQueue()
# Each thread performs 20 queue actions: 10 are `put_back` and 10 are
# `mark_finished`.
action_count = 20
def func():
for i in range(action_count):
closure = queue.get()
if i % 2 == 0:
queue.put_back(closure)
else:
queue.mark_finished(closure)
    threads = [threading.Thread(target=func) for _ in range(thread_count)]
for t in threads:
t.start()
for _ in range(thread_count * action_count // 2):
queue.put(self._create_closure())
queue.wait()
self.assertTrue(queue.done())
if __name__ == '__main__':
test.main()


@@ -0,0 +1,79 @@
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Metrics collecting utilities for single client training."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
from tensorflow.python.eager import monitoring
from tensorflow.python.util import tf_contextlib
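# Metrics collection is off by default; tests (e.g. metric_utils_test) set
# `enable_metrics = True` before exercising the client.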
enable_metrics = False
# Time in seconds to bucket the distribution of execution time. Range from
# 0.001s (i.e., 1ms) to 1000s.
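# With monitoring.ExponentialBuckets(scale=0.001, growth_factor=10,
# bucket_count=6), bucket boundaries fall at 0.001 * 10**i seconds, matching
# the 1ms-1000s range noted above (larger values go to an overflow bucket).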
_time_buckets = monitoring.ExponentialBuckets(0.001, 10, 6)
_function_tracing_sampler = monitoring.Sampler(
'/tensorflow/api/ps_strategy/client/function_tracing', _time_buckets,
'Sampler to track the time (in seconds) for tracing functions.')
_closure_execution_sampler = monitoring.Sampler(
'/tensorflow/api/ps_strategy/client/closure_execution', _time_buckets,
'Sampler to track the time (in seconds) for executing closures.')
_remote_value_fetch_sampler = monitoring.Sampler(
'/tensorflow/api/ps_strategy/client/remote_value_fetch', _time_buckets,
'Sampler to track the time (in seconds) for fetching remote_value.')
_METRICS_MAPPING = {
'function_tracing': _function_tracing_sampler,
'closure_execution': _closure_execution_sampler,
'remote_value_fetch': _remote_value_fetch_sampler
}
@tf_contextlib.contextmanager
def monitored_timer(metric_name, state_tracker=None):
"""Monitor the execution time and collect it into the specified metric."""
if not enable_metrics:
yield
else:
start_time = time.time()
start_state = state_tracker() if state_tracker else None
yield
duration_sec = time.time() - start_time
    # If a state_tracker is provided, record the metric only if the end state
    # differs from the start state.
if state_tracker is None or state_tracker() != start_state:
metric = _METRICS_MAPPING[metric_name]
metric.get_cell().add(duration_sec)
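# A minimal usage sketch (the helper below is hypothetical, not part of this
# module's API): time a block under one of the metrics in _METRICS_MAPPING.
# The sample is recorded only when `enable_metrics` is True.
def _example_monitored_timer_usage():
  with monitored_timer('closure_execution'):
    time.sleep(0.01)  # Stand-in for real work, e.g. executing a closure.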
def get_metric_summary(metric_name):
"""Get summary for the specified metric."""
metric = _METRICS_MAPPING[metric_name]
histogram_proto = metric.get_cell().value()
ret = dict()
ret['min'] = histogram_proto.min
ret['max'] = histogram_proto.max
ret['num'] = histogram_proto.num
ret['sum'] = histogram_proto.sum
# TODO(haoyuzhang): consider reporting the distribution in buckets.
return ret
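# Reading a summary back (a sketch; values depend on what has been recorded):
#
#   summary = get_metric_summary('function_tracing')
#   print(summary['num'], summary['sum'])  # sample count, total seconds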


@@ -0,0 +1,69 @@
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for metrics collecting in client."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute.client import client
from tensorflow.python.distribute.client import metric_utils
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.training.server_lib import ClusterSpec
class MetricUtilsTest(test.TestCase):
def testClientMetrics(self):
metric_utils.enable_metrics = True
cluster_def = multi_worker_test_base.create_in_process_cluster(
num_workers=1, num_ps=1, rpc_layer='grpc')
cluster_def['chief'] = [
'localhost:%d' % multi_worker_test_base.pick_unused_port()
]
cluster_resolver = SimpleClusterResolver(
ClusterSpec(cluster_def), rpc_layer='grpc')
cluster = client.Cluster(cluster_resolver)
@def_function.function
def func():
time.sleep(0.5)
return 3
result = cluster.schedule(func, args=None, kwargs=None)
result = cluster.schedule(func, args=None, kwargs=None)
cluster.join()
self.assertEqual(result._get_value().numpy(), 3)
    # The function should be traced exactly once; closure execution and
    # remote_value fetching happen once per `schedule` call (twice here).
metric_tracing = metric_utils.get_metric_summary('function_tracing')
self.assertEqual(metric_tracing['num'], 1)
# Tracing time should be longer than the sleep time in Python function.
self.assertGreater(metric_tracing['sum'], 0.5)
metric_closure = metric_utils.get_metric_summary('closure_execution')
self.assertEqual(metric_closure['num'], 2)
metric_remote_value = metric_utils.get_metric_summary('remote_value_fetch')
self.assertEqual(metric_remote_value['num'], 2)
if __name__ == '__main__':
test.main()


@@ -0,0 +1,55 @@
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Parameter server client module.
This is currently under development and the API is subject to change.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.distribute import parameter_server_strategy_v2
from tensorflow.python.distribute.client import client
class ParameterServerClient(client.Client):
"""A client that uses `ParameterServerStrategy` to distribute tasks.

  Parameter server training refers to the distributed training architecture
  that requires two types of jobs in the cluster: workers and parameter
  servers. Variables, and updates to those variables, are placed on parameter
  server tasks, while the computation-intensive operations are placed on
  worker tasks. In TF2, parameter server training starts up only one client
  process, which drives and coordinates the workers and parameter servers.
  This is referred to as the single-client architecture, as opposed to the
  multi-client approach, which is more common in traditional TensorFlow
  distributed training, including `tf.estimator.Estimator` and `tf.keras` with
  `tf.distribute.experimental.MultiWorkerMirroredStrategy`.

  `ParameterServerClient` is a `Client` that uses `ParameterServerStrategy`
  as the underlying strategy for distribution, and is the starting point of
  parameter server training/evaluation.

  If the 'TF_CONFIG' environment variable is set, provide a
  `TFConfigClusterResolver` to detect the configuration for multi-worker
  training.
"""
def __init__(self, cluster_resolver):
super(ParameterServerClient, self).__init__(
parameter_server_strategy_v2.ParameterServerStrategyV2(
cluster_resolver))
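
# A minimal usage sketch (hedged: the cluster setup is an assumption; see the
# class docstring and the `Client` docs for details):
#
#   resolver = TFConfigClusterResolver()  # requires 'TF_CONFIG' to be set
#   ps_client = ParameterServerClient(resolver)
#
#   with ps_client.context():
#     v = variables.Variable(initial_value=0.0)  # placed on a parameter server
#
#   @def_function.function
#   def step_fn():
#     v.assign_add(1.0)
#     return v.read_value()
#
#   result = ps_client.schedule(step_fn)  # dispatched to an available worker
#   ps_client.join()                      # block until scheduled work is done
#   print(ps_client.fetch(result))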


@@ -0,0 +1,405 @@
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for parameter_server_client.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import logging
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import sharded_variable
from tensorflow.python.distribute.client import client
from tensorflow.python.distribute.client import parameter_server_client
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import init_ops_v2
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
from tensorflow.python.training.server_lib import ClusterSpec
def make_client(num_workers, num_ps):
# TODO(rchao): Test the internal rpc_layer version.
cluster_def = multi_worker_test_base.create_in_process_cluster(
num_workers=num_workers, num_ps=num_ps, rpc_layer="grpc")
cluster_def["chief"] = [
"localhost:%d" % multi_worker_test_base.pick_unused_port()
]
cluster_resolver = SimpleClusterResolver(
ClusterSpec(cluster_def), rpc_layer="grpc")
return parameter_server_client.ParameterServerClient(cluster_resolver)
class ParameterServerClientTest(test.TestCase):
@classmethod
def setUpClass(cls):
super(ParameterServerClientTest, cls).setUpClass()
cls.client = make_client(num_workers=3, num_ps=2)
def testBasic(self):
self.client._strategy.extended._variable_count = 0
with self.client.context():
v1 = variables.Variable(initial_value=0.0)
v2 = variables.Variable(initial_value=1.0)
self.assertEqual(self.client._strategy.extended._variable_count, 2)
@def_function.function
def worker_fn():
v1.assign_add(0.1)
v2.assign_sub(0.2)
return v1.read_value() / v2.read_value()
results = self.client.schedule(worker_fn)
logging.info("Results of experimental_run_v2: %f",
self.client.fetch(results))
self.assertAlmostEqual(v1.read_value().numpy(), 0.1, delta=1e-6)
self.assertAlmostEqual(v2.read_value().numpy(), 0.8, delta=1e-6)
def testFnReturnNestedValues(self):
x = constant_op.constant(1)
@def_function.function
def f():
return x + 1, (x + 2, x + 3), [x + 4], {"v": x}
got = self.client.schedule(f)
want = 2, (3, 4), [5], {"v": 1}
self.assertEqual(self.client.fetch(got), want)
def testInputFunction(self):
def input_fn():
return dataset_ops.DatasetV2.range(1, 2)
with self.client.context():
v = variables.Variable(initial_value=0, dtype=dtypes.int64)
@def_function.function
def worker_fn(iterator):
x = next(iterator)
v.assign_add(x)
return x
distributed_dataset = self.client.create_per_worker_dataset(input_fn)
result = self.client.schedule(worker_fn, args=(iter(distributed_dataset),))
result = self.client.fetch(result)
self.assertEqual(result, (1,))
result = self.client.schedule(worker_fn, args=(iter(distributed_dataset),))
result = self.client.fetch(result)
self.assertEqual(result, (1,))
self.assertAlmostEqual(v.read_value().numpy(), 2, delta=1e-6)
def testAsyncScheduleAndJoin(self):
def input_fn():
return dataset_ops.DatasetV2.from_tensor_slices([2] * 10)
with self.client.context():
v = variables.Variable(initial_value=0, dtype=dtypes.int32)
# TODO(yuefengz): the following tf.function has a return value which is None
# in its structured_outputs.
@def_function.function
def worker_fn(iterator):
x = next(iterator)
v.assign_add(x)
distributed_dataset = self.client.create_per_worker_dataset(input_fn)
iterator = iter(distributed_dataset)
# Verifying joining without any scheduling doesn't hang.
self.client.join()
self.assertEqual(v.read_value().numpy(), 0)
for _ in range(5):
self.client.schedule(worker_fn, args=(iterator,))
self.client.join()
    # With 5 additions of 2, the value should be 2*5 = 10.
self.assertEqual(v.read_value().numpy(), 10)
for _ in range(5):
self.client.schedule(worker_fn, args=(iterator,))
    # Verify that multiple join() calls are fine.
self.client.join()
self.client.join()
self.client.join()
self.assertTrue(self.client.done())
# Likewise, it's now 20.
self.assertEqual(v.read_value().numpy(), 20)
def testInputFunctionWithMap(self):
self._map_fn_tracing_count = 0
def input_fn():
def map_fn(x):
self._map_fn_tracing_count += 1
return x + 10
return dataset_ops.DatasetV2.range(0, 10).map(map_fn)
@def_function.function
def worker_fn(iterator):
return next(iterator)
distributed_dataset = (
self.client.create_per_worker_dataset(input_fn))
result = self.client.schedule(
worker_fn, args=(iter(distributed_dataset),))
self.assertEqual(result.fetch(), (10,))
self.assertEqual(self._map_fn_tracing_count, 1)
def testInputFunctionCreateVariables(self):
def input_fn():
v = variables.Variable(initial_value=0.0)
return v.read_value()
with self.assertRaises(ValueError):
self.client.create_per_worker_dataset(input_fn)
class LimitedClosureQueueSizeBasicTest(ParameterServerClientTest):
"""Test basic functionality works with explicit maximum closure queue size.
Execute the same set of test cases as in ParameterServerClientTest, with an
  explicit size limit for the closure queue. Note that even when the queue
  size is set to infinite, there is still a practical maximum (depending on
  the host memory limit) that can cause `queue.put` operations to block when
  scheduling a large number of closures on a big cluster. These tests make
  sure that the client does not run into deadlocks in such scenarios.
"""
@classmethod
def setUpClass(cls):
super(LimitedClosureQueueSizeBasicTest, cls).setUpClass()
client._CLOSURE_QUEUE_MAX_SIZE = 2
cls.client = make_client(num_workers=3, num_ps=2)
class VariablePartitioningScopeTest(test.TestCase):
@classmethod
def setUpClass(cls):
super(VariablePartitioningScopeTest, cls).setUpClass()
cls.client = make_client(num_workers=3, num_ps=2)
def testBasic(self):
with self.client.context():
with self.client.experimental_variable_partitioning_scope():
init1 = init_ops_v2.Constant([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
v1 = variables.Variable(
initial_value=lambda: init1(shape=(5, 2), dtype=dtypes.int64),
shape=(5, 2),
dtype=dtypes.int64)
init2 = init_ops_v2.Constant([0, 1, 2, 3, 4, 5])
v2 = variables.Variable(
initial_value=lambda: init2(shape=(6, 1), dtype=dtypes.int64),
shape=(6, 1),
dtype=dtypes.int64)
self.assertIsInstance(v1, sharded_variable.ShardedVariable)
self.assertLen(v1.variables, 2)
self.assertRegex(v1.variables[0].device, "/job:ps/replica:0/task:0")
self.assertRegex(v1.variables[1].device, "/job:ps/replica:0/task:1")
self.assertAllEqual(v1.variables[0].read_value().numpy(),
[[0, 1], [2, 3], [4, 5]])
self.assertAllEqual(v1.variables[1].read_value().numpy(), [[6, 7], [8, 9]])
self.assertIsInstance(v2, sharded_variable.ShardedVariable)
self.assertLen(v2.variables, 2)
self.assertRegex(v2.variables[0].device, "/job:ps/replica:0/task:0")
self.assertRegex(v2.variables[1].device, "/job:ps/replica:0/task:1")
self.assertAllEqual(v2.variables[0].read_value().numpy(), [[0], [1], [2]])
self.assertAllEqual(v2.variables[1].read_value().numpy(), [[3], [4], [5]])
def testSurplusPS(self):
with self.client.context():
with self.client.experimental_variable_partitioning_scope():
initializer = init_ops_v2.Constant([0])
v = variables.Variable(
initial_value=lambda: initializer(shape=(1,), dtype=dtypes.int64),
shape=(1,),
dtype=dtypes.int64)
self.assertIsInstance(v, sharded_variable.ShardedVariable)
self.assertLen(v.variables, 1)
self.assertRegex(v.variables[0].device, "/job:ps/replica:0/task:0")
self.assertAllEqual(v.variables[0].read_value().numpy(), [0])
def testInvalidArgument(self):
with self.assertRaisesRegex(ValueError, "initial_value"):
with self.client.experimental_variable_partitioning_scope():
variables.Variable(initial_value=[0, 1, 2], shape=(3,))
with self.assertRaisesRegex(ValueError, "shape"):
with self.client.experimental_variable_partitioning_scope():
initializer = init_ops_v2.Constant([0, 1, 2])
variables.Variable(
initial_value=lambda: initializer(shape=(3,), dtype=dtypes.int64),
dtype=dtypes.int64)
def testPerWorkerValue(self):
var_shape = tuple()
var_dtype = dtypes.float32
var_name = "var"
def create_var():
var = variables.Variable(
initial_value=0.0, dtype=var_dtype, name=var_name)
self.assertIn("worker", var.device)
return var
worker_local_var = self.client._create_per_worker_resources(create_var)
# The following is a workaround to allow `worker_local_var` to be passed in
# as args to the `client.schedule` method which requires tensor specs to
# trace tf.function but _create_worker_resources' return values don't have
# tensor specs. We can get rid of this workaround once
# _create_worker_resources is able to infer the tensor spec of the return
# value of the function passed in. See b/154675763.
for var in worker_local_var._values:
var._set_type_spec(tensor_spec.TensorSpec(var_shape, var_dtype, var_name))
def worker_fn(var):
var.assign_add(1.0)
for _ in range(10):
      # Which copy of `worker_local_var` is used depends on which worker
      # `worker_fn` gets scheduled on.
self.client.schedule(worker_fn, args=(worker_local_var,))
self.client.join()
var_sum = sum(self.client.fetch(worker_local_var._values))
self.assertEqual(var_sum, 10.0)
class ErrorReportingTest(test.TestCase):
@classmethod
def setUpClass(cls):
super(ErrorReportingTest, cls).setUpClass()
cls.client = make_client(num_workers=3, num_ps=2)
with cls.client.context():
cls.iteration = variables.Variable(initial_value=0.0)
@def_function.function
def _normal_function(self):
x = random_ops.random_uniform((2, 10))
y = random_ops.random_uniform((10, 2))
self.iteration.assign_add(1.0)
return math_ops.reduce_mean(math_ops.matmul(x, y))
@def_function.function
def _error_function(self):
x = random_ops.random_uniform((2, 10))
y = random_ops.random_uniform((10, 2))
check_ops.assert_non_positive_v2(math_ops.reduce_sum(math_ops.matmul(x, y)))
self.iteration.assign_add(1.0)
return self.iteration
def testJoinRaiseError(self):
for _ in range(3):
self.client.schedule(self._normal_function)
self.client.schedule(self._error_function)
with self.assertRaises(errors.InvalidArgumentError):
self.client.join()
def testScheduleRaiseError(self):
for _ in range(3):
self.client.schedule(self._normal_function)
self.client.schedule(self._error_function)
with self.assertRaises(errors.InvalidArgumentError):
while True:
self.client.schedule(self._normal_function)
def testErrorWillbeCleared(self):
self.skipTest("b/157597579")
self.client.schedule(self._error_function)
with self.assertRaises(errors.InvalidArgumentError):
self.client.join()
for _ in range(3):
self.client.schedule(self._normal_function)
self.client.schedule(self._error_function)
with self.assertRaises(errors.InvalidArgumentError):
self.client.join()
def testFutureReturnError(self):
result = self.client.schedule(self._error_function)
with self.assertRaises(errors.InvalidArgumentError):
result.fetch()
# Clear the error.
with self.assertRaises(errors.InvalidArgumentError):
self.client.join()
def testInputError(self):
aborted = self.client.schedule(self._error_function)
@def_function.function
def func(x):
return x + 1.0
with self.assertRaises(errors.InvalidArgumentError):
self.client.join()
result = self.client.schedule(func, args=(aborted,))
with self.assertRaises(client.InputError):
result.fetch()
with self.assertRaises(client.InputError):
self.client.join()
class LimitedClosureQueueErrorTest(ErrorReportingTest):
"""Test error reporting works with explicit maximum closure queue size.
Execute the same set of test cases as in ErrorReportingTest, with an explicit
size limit for the closure queue.
"""
@classmethod
def setUpClass(cls):
super(LimitedClosureQueueErrorTest, cls).setUpClass()
client._CLOSURE_QUEUE_MAX_SIZE = 2
cls.client = make_client(num_workers=3, num_ps=2)
with cls.client.context():
cls.iteration = variables.Variable(initial_value=0.0)
if __name__ == "__main__":
test.main()


@@ -0,0 +1,202 @@
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Parameter server strategy V2 class.
This is currently under development and the API is subject to change.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import logging
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import parameter_server_strategy
from tensorflow.python.distribute import sharded_variable
from tensorflow.python.framework import ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.util import tf_contextlib
# pylint: disable=protected-access
class ParameterServerStrategyV2(distribute_lib.Strategy):
"""An asynchronous multi-worker parameter server tf.distribute strategy.

  Currently, `ParameterServerStrategyV2` cannot be used as a standalone
  tf.distribute strategy. It must be used in conjunction with `Client`. The
  recommended way of using the combination is through a
  `ParameterServerClient` object. Please see `Client` and
  `ParameterServerClient` for more information.

  This is currently under development, and both the API and the
  implementation are subject to change.
"""
def __init__(self, cluster_resolver):
"""Initializes the V2 parameter server strategy.
Args:
cluster_resolver: a `tf.distribute.cluster_resolver.ClusterResolver`
object.
"""
self._extended = ParameterServerStrategyV2Extended(self, cluster_resolver)
self._cluster_resolver = cluster_resolver
self._verify_args_and_config(cluster_resolver)
logging.info(
"ParameterServerStrategyV2 is initialized with cluster_spec: "
"%s", cluster_resolver.cluster_spec())
super(ParameterServerStrategyV2, self).__init__(self._extended)
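
  # Construction sketch (an illustration; `ParameterServerClient` is the
  # recommended entry point):
  #
  #   strategy = ParameterServerStrategyV2(
  #       TFConfigClusterResolver())  # cluster with "worker" and "ps" jobs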
@tf_contextlib.contextmanager
def experimental_variable_partitioning_scope(self):
"""A context manager for creating `ShardedVariable`.

    Variables created inside a `with experimental_variable_partitioning_scope()`
    block will be of type `ShardedVariable`, and their values are partitioned
    among the parameter servers along the first (outermost) axis. The number
    of shards is equal to the number of parameter servers.

    Variables created within this scope must be initialized with a callable as
    `initial_value` and a known shape.

    The 'div' partition strategy is used to partition the variable: assuming
    consecutive integer ids along the first axis of the variable, ids are
    assigned to shards contiguously while keeping the shard sizes as even as
    possible. If the number of shards does not evenly divide the number of
    ids, each of the first several shards is assigned one more id. For
    instance, a variable whose first dimension is 13 has 13 ids, and they are
    split across 5 shards as:
    `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`.

Yields:
A context manager for creating `ShardedVariable`.
"""
with variable_scope.variable_creator_scope(
self._extended._make_sharded_variable_creator()):
yield
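
  # Example (a sketch mirroring the partitioning description above): with two
  # parameter servers, a (5, 2) variable created in this scope becomes a
  # `ShardedVariable` with shards of shapes (3, 2) and (2, 2):
  #
  #   with strategy.experimental_variable_partitioning_scope():
  #     init = init_ops_v2.Constant(list(range(10)))
  #     v = variables.Variable(
  #         initial_value=lambda: init(shape=(5, 2), dtype=dtypes.int64),
  #         shape=(5, 2), dtype=dtypes.int64)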
def _verify_args_and_config(self, cluster_resolver):
if not cluster_resolver.cluster_spec():
raise ValueError("Cluster spec must be non-empty in `cluster_resolver`.")
if self.extended._num_gpus_per_worker > 1:
raise NotImplementedError("Multi-gpu is not supported yet.")
class ParameterServerStrategyV2Extended(
parameter_server_strategy.ParameterServerStrategyExtended):
"""Extended class for ParameterServerStrategyV2.
Please see `tf.distribute.StrategyExtended` doc for more information.
"""
def __init__(self, container_strategy, cluster_resolver):
"""Initialization of ParameterServerStrategyV2Extended."""
super(ParameterServerStrategyV2Extended, self).__init__(container_strategy)
self._num_ps = len(cluster_resolver.cluster_spec().as_dict().get("ps", []))
self._variable_count = 0
def _create_variable(self, next_creator, **kwargs):
if "colocate_with" in kwargs:
colocate_with = kwargs["colocate_with"]
# Clear the variable scope to avoid possible conflicts between device
# scope and colocation scope.
with ops.device(None):
with ops.colocate_with(colocate_with):
var = next_creator(**kwargs)
logging.debug(
"Creating variable (name:%s, shape:%r) that colocates with %s",
var.name, var.shape, kwargs["colocate_with"].name)
return var
# Clear the colocation scope to avoid possible conflicts between device
# scope and colocation scope.
with ops.colocate_with(None, ignore_existing=True):
with ops.device("/job:ps/task:%d" %
(self._variable_count % self._num_ps)):
var = next_creator(**kwargs)
logging.debug(
"Creating variable (name:%s, shape:%r) on /job:ps/task:%d",
var.name, var.shape, (self._variable_count % self._num_ps))
self._variable_count += 1
return var
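
  # Placement above is round-robin over ps tasks by creation order: with N ps
  # tasks, the k-th variable created lands on /job:ps/task:(k % N).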
def _make_sharded_variable_creator(self):
"""Returns a function conforming to the `variable_creator` signature.
The returned function creates `ShardedVariable` when called.
"""
def sharded_variable_creator(next_creator, **kwargs):
if "shape" not in kwargs or kwargs["shape"] is None:
raise ValueError("shape must be explicitly specified when creating "
"sharded variables")
init_fn = kwargs.get("initial_value", None)
      # We intentionally don't allow a non-callable initial_value, to ensure
      # the value is created on the PS rather than on the client. If it were
      # created on the client, it would need to be sent to the PS for variable
      # initialization, which is inefficient and can hit the 2GB limit on
      # protobuf serialization.
if init_fn is None or not callable(init_fn):
raise ValueError("initial_value must be specified as a callable when "
"creating sharded variables")
# Use "div" partition strategy to partition the variable.
full_shape = kwargs["shape"]
if self._num_ps < full_shape[0]:
num_shards = self._num_ps
else:
num_shards = full_shape[0]
offsets = []
base = full_shape[0] // num_shards
extra = full_shape[0] % num_shards
for i in range(num_shards):
if i == 0:
offsets.append(0)
else:
prev_shard_size = base + (1 if i - 1 < extra else 0)
offsets.append(offsets[i - 1] + prev_shard_size)
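      # For example, full_shape[0] == 13 with num_shards == 5 gives shard
      # sizes [3, 3, 3, 2, 2] and offsets [0, 3, 6, 9, 11].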
# Note: The way we initialize sharded variables is suboptimal, as it
# needs to create the full value tensor separately on each PS which the
# variable is going to be placed on. The full value could be very large
# and consume a lot of memory. The ideal way is to only create what's
# needed on the shard, however that's not practical because:
# 1. Initializers don't have sharded behavior support, even though some
# initializers (e.g, uniform) can be used directly.
# 2. tf.Variable signature requires "initial_value" to be either a value
# or a callable without arguments, meaning it is not straightforward
# to make the sharded component from it.
def init_shard_fn(shard_index):
full_value = init_fn()
if shard_index < num_shards - 1:
return full_value[offsets[shard_index]:offsets[shard_index + 1]]
else:
return full_value[offsets[shard_index]:]
      var_list = []
      for i in range(num_shards):
        kwargs["shape"] = None
        # Bind the current shard index via a default argument; a bare
        # `lambda: init_shard_fn(i)` would late-bind `i` if the initializer
        # were invoked after the loop advances.
        kwargs["initial_value"] = lambda i=i: init_shard_fn(i)
        var_list.append(next_creator(**kwargs))
      return sharded_variable.ShardedVariable(var_list)
return sharded_variable_creator
def _call_for_each_replica(self, fn, args, kwargs):
# TODO(rchao): Consider implementing sync PS training.
raise NotImplementedError("Sync PS training is not implemented yet.")


@@ -150,6 +150,9 @@ COMMON_PIP_DEPS = [
"//tensorflow/tools/docs:generate_lib",
"//tensorflow/tools/docs:parser",
"//tensorflow/tools/docs:py_guide_parser",
"//tensorflow/python/distribute/client:client",
"//tensorflow/python/distribute/client:parameter_server_client",
"//tensorflow/python/distribute/client:metric_utils",
]
# On Windows, python binary is a zip file of runfiles tree.