PSv2: Move TF2 parameter server training main library code into OSS.
PiperOrigin-RevId: 324103111
Change-Id: Ic6c9eb64e851873c1ee6e768f645b9367528470d
parent: b92c68721b
commit: a4725be4a7
tensorflow/python/distribute/BUILD
@@ -1760,3 +1760,21 @@ distribute_py_test(
        "@absl_py//absl/testing:parameterized",
    ],
)

py_library(
    name = "parameter_server_strategy_v2",
    srcs = ["parameter_server_strategy_v2.py"],
    srcs_version = "PY3",
    deps = [
        ":parameter_server_strategy",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:util",
        "//tensorflow/python:variables",
        "//tensorflow/python/distribute:distribute_lib",
        "//tensorflow/python/distribute:input_lib",
        "//tensorflow/python/distribute:sharded_variable",
        "//tensorflow/python/distribute:values",
    ],
)
tensorflow/python/distribute/client/BUILD (new file, 102 lines)
@@ -0,0 +1,102 @@
package(
    default_visibility = ["//tensorflow:internal"],
    licenses = ["notice"],  # Apache 2.0
)

exports_files(["LICENSE"])

py_library(
    name = "parameter_server_client",
    srcs = ["parameter_server_client.py"],
    srcs_version = "PY3",
    deps = [
        ":client",
        "//tensorflow/python/distribute:parameter_server_strategy_v2",
    ],
)

py_library(
    name = "client",
    srcs = ["client.py"],
    srcs_version = "PY3",
    deps = [
        ":metric_utils",
        "//tensorflow/python:errors",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:func_graph",
        "//tensorflow/python:resource_variable_ops",
        "//tensorflow/python:training_server_lib",
        "//tensorflow/python:util",
        "//tensorflow/python/distribute:input_lib",
        "//tensorflow/python/distribute:parameter_server_strategy_v2",
        "//tensorflow/python/distribute:values",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/eager:def_function",
        "//tensorflow/python/eager:executor",
        "//tensorflow/python/eager:function",
        "//tensorflow/python/eager:remote",
        "@absl_py//absl/logging",
        "@six_archive//:six",
    ],
)

py_test(
    name = "client_test",
    size = "small",
    srcs = ["client_test.py"],
    python_version = "PY3",
    shard_count = 12,
    deps = [
        ":client",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:training_lib",
        "//tensorflow/python:util",
        "//tensorflow/python/eager:def_function",
        "@absl_py//absl/logging",
    ],
)

py_test(
    name = "parameter_server_client_test",
    srcs = ["parameter_server_client_test.py"],
    python_version = "PY3",
    shard_count = 14,
    tags = ["no_oss"],  # TODO(b/162119374)
    deps = [
        ":parameter_server_client",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:init_ops_v2",
        "//tensorflow/python:training_server_lib",
        "//tensorflow/python:variables",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/distribute:multi_worker_test_base",
        "//tensorflow/python/distribute:sharded_variable",
        "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
        "//tensorflow/python/eager:def_function",
        "//tensorflow/python/eager:test",
    ],
)

py_library(
    name = "metric_utils",
    srcs = ["metric_utils.py"],
    srcs_version = "PY3",
    deps = [
        "//tensorflow/python/eager:monitoring",
    ],
)

py_test(
    name = "metric_utils_test",
    srcs = ["metric_utils_test.py"],
    python_version = "PY3",
    deps = [
        ":client",
        ":metric_utils",
        "//tensorflow/python:training_server_lib",
        "//tensorflow/python/distribute:multi_worker_test_base",
        "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
        "//tensorflow/python/eager:test",
    ],
)
tensorflow/python/distribute/client/client.py (new file, 1197 lines; diff suppressed because it is too large)
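Since the client.py diff is suppressed, here is a minimal, hypothetical sketch of the _CoordinatedClosureQueue contract, inferred from the behavior exercised by client_test.py below rather than taken from the actual implementation: closures are put and fetched; a fetched closure is either marked finished, put back for retry, or marked failed, which cancels everything still queued and surfaces the error on the next wait/put/done call. The error-generation bookkeeping that the tests also exercise is omitted for brevity.

# Hypothetical sketch only; the real _CoordinatedClosureQueue lives in the
# suppressed client.py and additionally tracks error generations.
import queue
import threading


class CoordinatedClosureQueueSketch(object):
  """Illustrates the put/get/put_back/mark_finished/mark_failed contract."""

  def __init__(self):
    self._queue = queue.Queue()
    self._inflight = 0
    self._error = None
    self._cv = threading.Condition()

  def put(self, closure):
    with self._cv:
      self._raise_if_error()  # A recorded failure re-raises on `put`.
      self._queue.put(closure)
      self._cv.notify_all()

  def get(self, timeout=None):
    with self._cv:
      while self._queue.empty():
        if not self._cv.wait(timeout=timeout):
          return None  # Timed out with nothing to process.
      self._inflight += 1
      return self._queue.get()

  def put_back(self, closure):
    with self._cv:
      self._inflight -= 1
      self._queue.put(closure)  # Re-queue for at-least-once processing.
      self._cv.notify_all()

  def mark_finished(self, closure):
    with self._cv:
      self._inflight -= 1
      self._cv.notify_all()

  def mark_failed(self, e, closure):
    with self._cv:
      self._inflight -= 1
      self._error = e  # Surfaced by the next `wait`, `put`, or `done`.
      while not self._queue.empty():
        self._queue.get()  # Cancel everything still queued.
      self._cv.notify_all()

  def wait(self, timeout=None):
    with self._cv:
      while self._error is None and (
          self._inflight > 0 or not self._queue.empty()):
        if not self._cv.wait(timeout=timeout):
          return False  # Timed out while work was still outstanding.
      self._raise_if_error()
      return True

  def done(self):
    with self._cv:
      self._raise_if_error()
      return self._inflight == 0 and self._queue.empty()

  def _raise_if_error(self):
    if self._error is not None:
      e, self._error = self._error, None  # The error is cleared once raised.
      raise e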
tensorflow/python/distribute/client/client_test.py (new file, 388 lines)
@@ -0,0 +1,388 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for client.py."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import threading
import time
from absl import logging

from tensorflow.python.distribute.client import client
from tensorflow.python.eager import def_function
from tensorflow.python.platform import test
from tensorflow.python.training import coordinator
from tensorflow.python.util import nest


class CoordinatedClosureQueueTest(test.TestCase):

  def testBasic(self):
    queue = client._CoordinatedClosureQueue()
    closure1 = self._create_closure()
    queue.put(closure1)
    self.assertIs(closure1, queue.get())
    self.assertFalse(queue.done())
    queue.put_back(closure1)
    self.assertEqual(closure1, queue.get())
    queue.mark_finished(closure1)
    self.assertTrue(queue.done())
    queue.wait()
  def testProcessAtLeastOnce(self):
    closure_queue = client._CoordinatedClosureQueue()
    labels = ['A', 'B', 'C', 'D', 'E']
    processed_count = collections.defaultdict(int)

    coord = coordinator.Coordinator(clean_stop_exception_types=[])

    def process_queue():
      with coord.stop_on_exception():
        has_been_put_back = False
        while True:
          closure = closure_queue.get(timeout=30)
          if closure is None:
            break
          if not has_been_put_back:
            has_been_put_back = True
            closure_queue.put_back(closure)
            continue
          closure._function()
          closure_queue.mark_finished(closure)

    def get_func(label):

      def func():
        logging.info('Label: %s, before waiting 3 sec', label)
        time.sleep(3)
        processed_count[label] += 1
        logging.info('Label: %s, after waiting 3 sec', label)

      return func

    for label in labels:
      closure_queue.put(client.Closure(get_func(label)))
    t1 = threading.Thread(target=process_queue, daemon=True)
    t1.start()
    t2 = threading.Thread(target=process_queue, daemon=True)
    t2.start()

    # Make sure multiple wait() calls are fine.
    closure_queue.wait()
    closure_queue.wait()
    closure_queue.wait()
    closure_queue.wait()

    self.assertEqual(processed_count, collections.Counter(labels))

    coord.join([t1, t2])

  def testNotifyBeforeWait(self):
    closure_queue = client._CoordinatedClosureQueue()

    def func():
      logging.info('func running')

    coord = coordinator.Coordinator(clean_stop_exception_types=[])

    def process_queue():
      with coord.stop_on_exception():
        closure = closure_queue.get()
        closure_queue.mark_finished(closure)

    closure_queue.put(client.Closure(func))
    t = threading.Thread(target=process_queue)
    t.start()
    coord.join([t])

    # This test asserts that waiting at the time the function has been
    # processed doesn't time out.
    closure_queue.wait()
  def testWaitRaiseErrorAfterMarkFailure(self):
    closure_queue = client._CoordinatedClosureQueue()
    closure_queue.put(self._create_closure())
    closure = closure_queue.get()

    wait_finish_event = threading.Event()
    coord = coordinator.Coordinator(clean_stop_exception_types=[])

    # Using a thread to verify that closure_queue.wait() will not return until
    # all inflight closures are finished.

    def mark_finished_fn():
      with coord.stop_on_exception():
        self.assertFalse(wait_finish_event.is_set())
        try:
          raise ValueError('Some error.')
        except ValueError as e:
          closure_queue.mark_failed(e, closure)
        wait_finish_event.wait()

    t = threading.Thread(target=mark_finished_fn)
    t.start()

    with self.assertRaises(ValueError):
      closure_queue.wait()
    wait_finish_event.set()

    coord.join([t])
    self.assertTrue(closure_queue.done())

  def _create_closure(self):

    @def_function.function()
    def some_function():
      return 1.0

    return client.Closure(some_function)

  def _put_two_closures_and_get_one(self):
    closure_queue = client._CoordinatedClosureQueue()
    closure1 = self._create_closure()
    closure_queue.put(closure1)

    closure2 = self._create_closure()
    closure_queue.put(closure2)

    closure_got = closure_queue.get()  # Returns closure1.
    self.assertIs(closure_got, closure1)
    self.assertIsNot(closure_got, closure2)
    return closure_queue, closure1, closure2
  def testPutRaiseError(self):
    closure_queue, closure1, closure2 = self._put_two_closures_and_get_one()

    closure_queue.mark_failed(ValueError(), closure1)

    with self.assertRaises(ValueError):
      closure_queue.put(self._create_closure())

    self.assertTrue(closure_queue.done())

    with self.assertRaisesRegex(
        client.FunctionRetryableError,
        'The corresponding function is cancelled. Please reschedule the '
        'function.'):
      closure2._fetch_output_remote_values()

    # The error is cleared.
    closure_queue.put(self._create_closure())

  def testWaitRaiseError(self):
    closure_queue, closure1, closure2 = self._put_two_closures_and_get_one()

    closure_queue.mark_failed(ValueError(), closure1)

    with self.assertRaises(ValueError):
      closure_queue.wait()
    self.assertTrue(closure_queue.done())

    with self.assertRaisesRegex(
        client.FunctionRetryableError,
        'The corresponding function is cancelled. Please reschedule the '
        'function.'):
      closure2._fetch_output_remote_values()

    # The error is cleared.
    closure_queue.wait()

  def testDoneRaiseError(self):
    closure_queue, closure1, _ = self._put_two_closures_and_get_one()
    closure_queue.get()

    self.assertFalse(closure_queue.done())
    closure_queue.mark_failed(ValueError(), closure1)
    with self.assertRaises(ValueError):
      closure_queue.done()
  def _test_error_reporting_and_cancel_flow(self, call_wait):
    closure_queue, closure1, closure2 = self._put_two_closures_and_get_one()
    closure_queue.put(self._create_closure())
    closure_queue.get()
    # At this moment, there are two inflight, one in queue.
    self.assertEqual(closure_queue._inflight_closure_count, 2)

    # Simulate closure1 failing.
    try:
      raise ValueError('Some error.')
    except ValueError as e:
      nest.map_structure(lambda x: x._set_error(e),
                         closure1._output_remote_values)
      self.assertEqual(closure_queue._error_generation, 0)  # pylint: disable=g-assert-in-except
      closure_queue.mark_failed(e, closure1)
      self.assertEqual(closure_queue._error_generation, 1)
    # At this moment, there is one inflight and nothing in the queue (the
    # queued closures should have been removed and cancelled).
    self.assertTrue(closure_queue._queue.empty())
    # Doesn't include out-of-generation closures.
    self.assertEqual(closure_queue._inflight_closure_count, 1)

    coord = coordinator.Coordinator(clean_stop_exception_types=[])
    closure3 = self._create_closure()

    with self.assertRaises(ValueError):
      # Verify that `wait()` or `put()` raises even if one closure is in
      # flight.
      if call_wait:
        closure_queue.wait()
      else:
        closure_queue.put(closure3)
    # At this moment, there is one inflight, nothing in queue.
    self.assertTrue(closure_queue._queue.empty())
    self.assertEqual(closure_queue._inflight_closure_count, 1)

    # This asserts that closure1 has errored.
    with self.assertRaisesRegex(ValueError, 'Some error.'):
      closure1._fetch_output_remote_values()

    # The following asserts that closure3 should have been cancelled.
    if not call_wait:
      with self.assertRaisesRegex(
          client.FunctionRetryableError,
          'The corresponding function is cancelled. Please reschedule the '
          'function.'):
        closure3._fetch_output_remote_values()

    # Closure2 is inflight, so it shouldn't be ready.
    self.assertEqual(closure2._output_remote_values._status,
                     client._RemoteValueStatus.NOT_READY)

    # And `wait` should block because closure2 is not back yet.
    self.assertFalse(closure_queue.wait(timeout=20))

    # Now assume that closure2 fails due to worker preemption and is put back,
    # but it ends up getting cancelled because it belongs to the old error
    # generation.
    self.assertEqual(closure2._error_generation, 0)
    self.assertEqual(closure_queue._error_generation, 1)
    closure_queue.put_back(closure2)

    with self.assertRaisesRegex(
        client.FunctionRetryableError,
        'The corresponding function is cancelled. Please reschedule the '
        'function.'):
      closure2._fetch_output_remote_values()

    # At this moment, there is nothing inflight, and the queue is also empty
    # (because closure2 should not be added back to the queue).
    self.assertTrue(closure_queue._queue.empty())
    self.assertEqual(closure_queue._inflight_closure_count, 0)

    closure4 = self._create_closure()

    e = threading.Event()

    def get_fn():
      with coord.stop_on_exception():
        # This should end up getting closure4, not closure2, because closure2
        # has been cancelled and should not be returned.
        closure_got = closure_queue.get()
        e.set()
        self.assertEqual(closure_got._error_generation, 1)
        self.assertEqual(closure_queue._error_generation, 1)
        self.assertIs(closure4, closure_got)
        self.assertIsNot(closure2, closure_got)

    t = threading.Thread(target=get_fn)
    t.start()

    time.sleep(10)

    # Make sure `closure_got = closure_queue.get()` is unblocked as a result
    # of `closure_queue.put(closure4)`.
    self.assertFalse(e.is_set())
    closure_queue.put(closure4)
    self.assertTrue(e.wait())
    coord.join([t])

    self.assertEqual(closure_queue._inflight_closure_count, 1)
    closure_queue.mark_finished(closure4)
    # The queue is now cleared and nothing is inflight.
    self.assertEqual(closure_queue._inflight_closure_count, 0)
    closure_queue.wait()
  def testWaitRaiseErrorAfterAnErrorIsReported(self):
    self._test_error_reporting_and_cancel_flow(call_wait=True)

  def testPutRaiseErrorAfterAnErrorIsReported(self):
    self._test_error_reporting_and_cancel_flow(call_wait=False)

  def testStateIsRestoredAfterJoinIsCalled(self):
    closure_queue, closure1, closure2 = self._put_two_closures_and_get_one()
    closure_queue.get()
    self.assertEqual(closure_queue._inflight_closure_count, 2)
    closure_queue.mark_failed(ValueError('test error'), closure1)
    with self.assertRaises(ValueError):
      closure_queue.put(self._create_closure())
    closure_queue.mark_failed(ValueError('test error'), closure2)

    # closure2's error comes from the previous error generation, so the
    # following put should not raise, and _error should have been cleared.
    self.assertIsNone(closure_queue._error)
    closure_queue.put(self._create_closure())
    self.assertIsNone(closure_queue._error)
  def testStateIsRestoredAfterJoinIsCalled_WaitShouldReturn(self):
    closure_queue, closure1, closure2 = self._put_two_closures_and_get_one()
    closure_queue.put(self._create_closure())  # Puts closure3.
    closure_queue.get()  # Got closure2.
    self.assertFalse(closure_queue._queue.empty())  # Still has closure3.
    self.assertEqual(closure_queue._inflight_closure_count, 2)  # closure1, 2.
    closure_queue.mark_failed(ValueError('test error'), closure1)
    self.assertTrue(closure_queue._queue.empty())  # closure3 cancelled.
    self.assertEqual(closure_queue._inflight_closure_count, 1)
    with self.assertRaises(ValueError):
      closure_queue.wait()  # Reports the error from closure1.

    # `wait` should block because closure2 is not back yet, even though
    # closure2 was sent inflight before the error.
    self.assertFalse(closure_queue.wait(timeout=20))
    self.assertEqual(closure_queue._inflight_closure_count, 1)
    closure_queue.mark_finished(closure2)
    closure_queue.wait()  # Should pass immediately.
    self.assertEqual(closure_queue._inflight_closure_count, 0)
  def testThreadSafety(self):
    thread_count = 10
    queue = client._CoordinatedClosureQueue()

    # Each thread performs 20 queue actions: 10 are `put_back` and 10 are
    # `mark_finished`.
    action_count = 20

    def func():
      for i in range(action_count):
        closure = queue.get()
        if i % 2 == 0:
          queue.put_back(closure)
        else:
          queue.mark_finished(closure)

    threads = [threading.Thread(target=func) for _ in range(thread_count)]
    for t in threads:
      t.start()

    for _ in range(thread_count * action_count // 2):
      queue.put(self._create_closure())
    queue.wait()
    self.assertTrue(queue.done())


if __name__ == '__main__':
  test.main()
tensorflow/python/distribute/client/metric_utils.py (new file, 79 lines)
@@ -0,0 +1,79 @@
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Metrics collecting utilities for single client training."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time

from tensorflow.python.eager import monitoring
from tensorflow.python.util import tf_contextlib

enable_metrics = False

# Time in seconds to bucket the distribution of execution time. Ranges from
# 0.001s (i.e., 1ms) to 1000s.
_time_buckets = monitoring.ExponentialBuckets(0.001, 10, 6)

_function_tracing_sampler = monitoring.Sampler(
    '/tensorflow/api/ps_strategy/client/function_tracing', _time_buckets,
    'Sampler to track the time (in seconds) for tracing functions.')

_closure_execution_sampler = monitoring.Sampler(
    '/tensorflow/api/ps_strategy/client/closure_execution', _time_buckets,
    'Sampler to track the time (in seconds) for executing closures.')

_remote_value_fetch_sampler = monitoring.Sampler(
    '/tensorflow/api/ps_strategy/client/remote_value_fetch', _time_buckets,
    'Sampler to track the time (in seconds) for fetching remote_value.')

_METRICS_MAPPING = {
    'function_tracing': _function_tracing_sampler,
    'closure_execution': _closure_execution_sampler,
    'remote_value_fetch': _remote_value_fetch_sampler
}


@tf_contextlib.contextmanager
def monitored_timer(metric_name, state_tracker=None):
  """Monitor the execution time and collect it into the specified metric."""
  if not enable_metrics:
    yield
  else:
    start_time = time.time()
    start_state = state_tracker() if state_tracker else None
    yield
    duration_sec = time.time() - start_time
    # If a state_tracker is provided, record the metric only if the end state
    # differs from the start state.
    if state_tracker is None or state_tracker() != start_state:
      metric = _METRICS_MAPPING[metric_name]
      metric.get_cell().add(duration_sec)


def get_metric_summary(metric_name):
  """Get summary for the specified metric."""
  metric = _METRICS_MAPPING[metric_name]
  histogram_proto = metric.get_cell().value()
  ret = dict()
  ret['min'] = histogram_proto.min
  ret['max'] = histogram_proto.max
  ret['num'] = histogram_proto.num
  ret['sum'] = histogram_proto.sum
  # TODO(haoyuzhang): consider reporting the distribution in buckets.
  return ret
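A short usage sketch (illustrative, not part of the commit) of the two helpers above; metrics are off by default, so `enable_metrics` must be flipped first:

# Hypothetical usage of monitored_timer/get_metric_summary.
from tensorflow.python.distribute.client import metric_utils

metric_utils.enable_metrics = True  # Sampling is disabled by default.

# The optional state_tracker suppresses recording when the tracked state did
# not change across the block; here it changes, so the duration is recorded.
ticks = [0]

def state_tracker():
  return ticks[0]

with metric_utils.monitored_timer('closure_execution', state_tracker):
  ticks[0] += 1  # Stand-in for executing a scheduled closure.

summary = metric_utils.get_metric_summary('closure_execution')
print(summary['num'], summary['sum'])  # 1 sample; total seconds spent.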
tensorflow/python/distribute/client/metric_utils_test.py (new file, 69 lines)
@@ -0,0 +1,69 @@
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for metrics collecting in client."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute.client import client
from tensorflow.python.distribute.client import metric_utils
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.training.server_lib import ClusterSpec


class MetricUtilsTest(test.TestCase):

  def testClientMetrics(self):
    metric_utils.enable_metrics = True

    cluster_def = multi_worker_test_base.create_in_process_cluster(
        num_workers=1, num_ps=1, rpc_layer='grpc')
    cluster_def['chief'] = [
        'localhost:%d' % multi_worker_test_base.pick_unused_port()
    ]
    cluster_resolver = SimpleClusterResolver(
        ClusterSpec(cluster_def), rpc_layer='grpc')
    cluster = client.Cluster(cluster_resolver)

    @def_function.function
    def func():
      time.sleep(0.5)
      return 3

    result = cluster.schedule(func, args=None, kwargs=None)
    result = cluster.schedule(func, args=None, kwargs=None)
    cluster.join()
    self.assertEqual(result._get_value().numpy(), 3)

    # The function should have been traced exactly once; since it was
    # scheduled twice, closure execution and remote_value fetching should
    # each have happened twice.
    metric_tracing = metric_utils.get_metric_summary('function_tracing')
    self.assertEqual(metric_tracing['num'], 1)
    # Tracing time should be longer than the sleep time in the Python
    # function.
    self.assertGreater(metric_tracing['sum'], 0.5)
    metric_closure = metric_utils.get_metric_summary('closure_execution')
    self.assertEqual(metric_closure['num'], 2)
    metric_remote_value = metric_utils.get_metric_summary('remote_value_fetch')
    self.assertEqual(metric_remote_value['num'], 2)


if __name__ == '__main__':
  test.main()
tensorflow/python/distribute/client/parameter_server_client.py (new file, 55 lines)
@@ -0,0 +1,55 @@
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Parameter server client module.

This is currently under development and the API is subject to change.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.distribute import parameter_server_strategy_v2
from tensorflow.python.distribute.client import client


class ParameterServerClient(client.Client):
  """A client that uses `ParameterServerStrategy` to distribute tasks.

  Parameter server training refers to the distributed training architecture
  that requires two jobs in the cluster: workers and parameter servers. The
  variables, and the updates to those variables, are assigned to the
  parameter servers' tasks, while the computation-intensive operations are
  assigned to worker tasks. In TF2, parameter server training starts up only
  one client process, to drive and coordinate the workers and parameter
  servers. This is referred to as the single-client architecture, as opposed
  to the multi-client approach seen more often in traditional TensorFlow
  distributed training, including `tf.estimator.Estimator` and `tf.keras`
  with `tf.distribute.experimental.MultiWorkerMirroredStrategy`.

  `ParameterServerClient` is a `Client` that uses `ParameterServerStrategy`
  as the underlying strategy to distribute, and is the starting point of
  parameter server training/evaluation.

  If the 'TF_CONFIG' environment variable is used, provide a
  `TFConfigClusterResolver` to detect configurations for multi-worker
  training.
  """

  def __init__(self, cluster_resolver):
    super(ParameterServerClient, self).__init__(
        parameter_server_strategy_v2.ParameterServerStrategyV2(
            cluster_resolver))
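A minimal driver sketch (illustrative; the context/schedule/join/fetch calls follow the tests in this commit, and the cluster setup is assumed to come from TF_CONFIG via a TFConfigClusterResolver, as the docstring suggests):

# Hypothetical single-client driver for parameter server training.
from tensorflow.python.distribute.client import parameter_server_client
from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
from tensorflow.python.eager import def_function
from tensorflow.python.ops import variables

ps_client = parameter_server_client.ParameterServerClient(
    TFConfigClusterResolver())  # Assumes TF_CONFIG defines worker/ps jobs.

with ps_client.context():
  step = variables.Variable(initial_value=0.0)  # Placed on a parameter server.

@def_function.function
def train_step():
  step.assign_add(1.0)
  return step.read_value()

# `schedule` dispatches the function to a free worker and returns a
# future-like remote value; `join` blocks until all scheduled work finishes.
result = ps_client.schedule(train_step)
ps_client.join()
print(ps_client.fetch(result))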
tensorflow/python/distribute/client/parameter_server_client_test.py (new file, 405 lines)
@@ -0,0 +1,405 @@
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for parameter_server_client.py."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import logging
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import sharded_variable
from tensorflow.python.distribute.client import client
from tensorflow.python.distribute.client import parameter_server_client
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import init_ops_v2
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
from tensorflow.python.training.server_lib import ClusterSpec


def make_client(num_workers, num_ps):
  # TODO(rchao): Test the internal rpc_layer version.
  cluster_def = multi_worker_test_base.create_in_process_cluster(
      num_workers=num_workers, num_ps=num_ps, rpc_layer="grpc")
  cluster_def["chief"] = [
      "localhost:%d" % multi_worker_test_base.pick_unused_port()
  ]
  cluster_resolver = SimpleClusterResolver(
      ClusterSpec(cluster_def), rpc_layer="grpc")
  return parameter_server_client.ParameterServerClient(cluster_resolver)


class ParameterServerClientTest(test.TestCase):

  @classmethod
  def setUpClass(cls):
    super(ParameterServerClientTest, cls).setUpClass()
    cls.client = make_client(num_workers=3, num_ps=2)

  def testBasic(self):
    self.client._strategy.extended._variable_count = 0
    with self.client.context():
      v1 = variables.Variable(initial_value=0.0)
      v2 = variables.Variable(initial_value=1.0)
    self.assertEqual(self.client._strategy.extended._variable_count, 2)

    @def_function.function
    def worker_fn():
      v1.assign_add(0.1)
      v2.assign_sub(0.2)
      return v1.read_value() / v2.read_value()

    results = self.client.schedule(worker_fn)
    logging.info("Results of worker_fn: %f", self.client.fetch(results))

    self.assertAlmostEqual(v1.read_value().numpy(), 0.1, delta=1e-6)
    self.assertAlmostEqual(v2.read_value().numpy(), 0.8, delta=1e-6)

  def testFnReturnNestedValues(self):
    x = constant_op.constant(1)

    @def_function.function
    def f():
      return x + 1, (x + 2, x + 3), [x + 4], {"v": x}

    got = self.client.schedule(f)
    want = 2, (3, 4), [5], {"v": 1}
    self.assertEqual(self.client.fetch(got), want)
  def testInputFunction(self):

    def input_fn():
      return dataset_ops.DatasetV2.range(1, 2)

    with self.client.context():
      v = variables.Variable(initial_value=0, dtype=dtypes.int64)

    @def_function.function
    def worker_fn(iterator):
      x = next(iterator)
      v.assign_add(x)
      return x

    distributed_dataset = self.client.create_per_worker_dataset(input_fn)
    result = self.client.schedule(worker_fn, args=(iter(distributed_dataset),))
    result = self.client.fetch(result)
    self.assertEqual(result, (1,))
    result = self.client.schedule(worker_fn, args=(iter(distributed_dataset),))
    result = self.client.fetch(result)
    self.assertEqual(result, (1,))

    self.assertAlmostEqual(v.read_value().numpy(), 2, delta=1e-6)
  def testAsyncScheduleAndJoin(self):

    def input_fn():
      return dataset_ops.DatasetV2.from_tensor_slices([2] * 10)

    with self.client.context():
      v = variables.Variable(initial_value=0, dtype=dtypes.int32)

    # TODO(yuefengz): the following tf.function has a return value which is
    # None in its structured_outputs.
    @def_function.function
    def worker_fn(iterator):
      x = next(iterator)
      v.assign_add(x)

    distributed_dataset = self.client.create_per_worker_dataset(input_fn)

    iterator = iter(distributed_dataset)

    # Verify that joining without any scheduling doesn't hang.
    self.client.join()
    self.assertEqual(v.read_value().numpy(), 0)

    for _ in range(5):
      self.client.schedule(worker_fn, args=(iterator,))
    self.client.join()

    # With 5 additions it should be 2 * 5 = 10.
    self.assertEqual(v.read_value().numpy(), 10)

    for _ in range(5):
      self.client.schedule(worker_fn, args=(iterator,))

    # Verify that multiple `join` calls are fine.
    self.client.join()
    self.client.join()
    self.client.join()

    self.assertTrue(self.client.done())

    # Likewise, it's now 20.
    self.assertEqual(v.read_value().numpy(), 20)
  def testInputFunctionWithMap(self):
    self._map_fn_tracing_count = 0

    def input_fn():

      def map_fn(x):
        self._map_fn_tracing_count += 1
        return x + 10

      return dataset_ops.DatasetV2.range(0, 10).map(map_fn)

    @def_function.function
    def worker_fn(iterator):
      return next(iterator)

    distributed_dataset = (
        self.client.create_per_worker_dataset(input_fn))
    result = self.client.schedule(
        worker_fn, args=(iter(distributed_dataset),))
    self.assertEqual(result.fetch(), (10,))
    self.assertEqual(self._map_fn_tracing_count, 1)

  def testInputFunctionCreateVariables(self):

    def input_fn():
      v = variables.Variable(initial_value=0.0)
      return v.read_value()

    with self.assertRaises(ValueError):
      self.client.create_per_worker_dataset(input_fn)
class LimitedClosureQueueSizeBasicTest(ParameterServerClientTest):
  """Test basic functionality works with explicit maximum closure queue size.

  Execute the same set of test cases as in ParameterServerClientTest, with an
  explicit size limit for the closure queue. Note that even when the queue
  size is set to infinite, there is still a practical maximum size (depending
  on the host memory limit) that might cause `queue.put` operations to block
  when scheduling a large number of closures on a big cluster. These tests
  make sure that the client does not run into deadlocks in such scenarios.
  """

  @classmethod
  def setUpClass(cls):
    super(LimitedClosureQueueSizeBasicTest, cls).setUpClass()
    client._CLOSURE_QUEUE_MAX_SIZE = 2
    cls.client = make_client(num_workers=3, num_ps=2)
class VariablePartitioningScopeTest(test.TestCase):

  @classmethod
  def setUpClass(cls):
    super(VariablePartitioningScopeTest, cls).setUpClass()
    cls.client = make_client(num_workers=3, num_ps=2)

  def testBasic(self):
    with self.client.context():
      with self.client.experimental_variable_partitioning_scope():
        init1 = init_ops_v2.Constant([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
        v1 = variables.Variable(
            initial_value=lambda: init1(shape=(5, 2), dtype=dtypes.int64),
            shape=(5, 2),
            dtype=dtypes.int64)

        init2 = init_ops_v2.Constant([0, 1, 2, 3, 4, 5])
        v2 = variables.Variable(
            initial_value=lambda: init2(shape=(6, 1), dtype=dtypes.int64),
            shape=(6, 1),
            dtype=dtypes.int64)

    self.assertIsInstance(v1, sharded_variable.ShardedVariable)
    self.assertLen(v1.variables, 2)
    self.assertRegex(v1.variables[0].device, "/job:ps/replica:0/task:0")
    self.assertRegex(v1.variables[1].device, "/job:ps/replica:0/task:1")
    self.assertAllEqual(v1.variables[0].read_value().numpy(),
                        [[0, 1], [2, 3], [4, 5]])
    self.assertAllEqual(v1.variables[1].read_value().numpy(),
                        [[6, 7], [8, 9]])

    self.assertIsInstance(v2, sharded_variable.ShardedVariable)
    self.assertLen(v2.variables, 2)
    self.assertRegex(v2.variables[0].device, "/job:ps/replica:0/task:0")
    self.assertRegex(v2.variables[1].device, "/job:ps/replica:0/task:1")
    self.assertAllEqual(v2.variables[0].read_value().numpy(), [[0], [1], [2]])
    self.assertAllEqual(v2.variables[1].read_value().numpy(), [[3], [4], [5]])

  def testSurplusPS(self):
    with self.client.context():
      with self.client.experimental_variable_partitioning_scope():
        initializer = init_ops_v2.Constant([0])

        v = variables.Variable(
            initial_value=lambda: initializer(shape=(1,), dtype=dtypes.int64),
            shape=(1,),
            dtype=dtypes.int64)

    self.assertIsInstance(v, sharded_variable.ShardedVariable)
    self.assertLen(v.variables, 1)
    self.assertRegex(v.variables[0].device, "/job:ps/replica:0/task:0")
    self.assertAllEqual(v.variables[0].read_value().numpy(), [0])

  def testInvalidArgument(self):
    with self.assertRaisesRegex(ValueError, "initial_value"):
      with self.client.experimental_variable_partitioning_scope():
        variables.Variable(initial_value=[0, 1, 2], shape=(3,))

    with self.assertRaisesRegex(ValueError, "shape"):
      with self.client.experimental_variable_partitioning_scope():
        initializer = init_ops_v2.Constant([0, 1, 2])
        variables.Variable(
            initial_value=lambda: initializer(shape=(3,), dtype=dtypes.int64),
            dtype=dtypes.int64)

  def testPerWorkerValue(self):
    var_shape = tuple()
    var_dtype = dtypes.float32
    var_name = "var"

    def create_var():
      var = variables.Variable(
          initial_value=0.0, dtype=var_dtype, name=var_name)
      self.assertIn("worker", var.device)
      return var

    worker_local_var = self.client._create_per_worker_resources(create_var)

    # The following is a workaround to allow `worker_local_var` to be passed
    # in as args to the `client.schedule` method, which requires tensor specs
    # to trace a tf.function, but `_create_per_worker_resources`' return
    # values don't have tensor specs. We can get rid of this workaround once
    # `_create_per_worker_resources` is able to infer the tensor spec of the
    # return value of the function passed in. See b/154675763.
    for var in worker_local_var._values:
      var._set_type_spec(
          tensor_spec.TensorSpec(var_shape, var_dtype, var_name))

    def worker_fn(var):
      var.assign_add(1.0)

    for _ in range(10):
      # Which slice of `worker_local_var` will be used depends on which
      # worker the `worker_fn` gets scheduled on.
      self.client.schedule(worker_fn, args=(worker_local_var,))
    self.client.join()

    var_sum = sum(self.client.fetch(worker_local_var._values))
    self.assertEqual(var_sum, 10.0)
class ErrorReportingTest(test.TestCase):

  @classmethod
  def setUpClass(cls):
    super(ErrorReportingTest, cls).setUpClass()
    cls.client = make_client(num_workers=3, num_ps=2)

    with cls.client.context():
      cls.iteration = variables.Variable(initial_value=0.0)

  @def_function.function
  def _normal_function(self):
    x = random_ops.random_uniform((2, 10))
    y = random_ops.random_uniform((10, 2))
    self.iteration.assign_add(1.0)
    return math_ops.reduce_mean(math_ops.matmul(x, y))

  @def_function.function
  def _error_function(self):
    x = random_ops.random_uniform((2, 10))
    y = random_ops.random_uniform((10, 2))
    check_ops.assert_non_positive_v2(
        math_ops.reduce_sum(math_ops.matmul(x, y)))
    self.iteration.assign_add(1.0)
    return self.iteration

  def testJoinRaiseError(self):
    for _ in range(3):
      self.client.schedule(self._normal_function)
    self.client.schedule(self._error_function)
    with self.assertRaises(errors.InvalidArgumentError):
      self.client.join()

  def testScheduleRaiseError(self):
    for _ in range(3):
      self.client.schedule(self._normal_function)
    self.client.schedule(self._error_function)
    with self.assertRaises(errors.InvalidArgumentError):
      while True:
        self.client.schedule(self._normal_function)

  def testErrorWillBeCleared(self):
    self.skipTest("b/157597579")
    self.client.schedule(self._error_function)
    with self.assertRaises(errors.InvalidArgumentError):
      self.client.join()

    for _ in range(3):
      self.client.schedule(self._normal_function)
    self.client.schedule(self._error_function)
    with self.assertRaises(errors.InvalidArgumentError):
      self.client.join()

  def testFutureReturnError(self):
    result = self.client.schedule(self._error_function)

    with self.assertRaises(errors.InvalidArgumentError):
      result.fetch()

    # Clear the error.
    with self.assertRaises(errors.InvalidArgumentError):
      self.client.join()

  def testInputError(self):
    aborted = self.client.schedule(self._error_function)

    @def_function.function
    def func(x):
      return x + 1.0

    with self.assertRaises(errors.InvalidArgumentError):
      self.client.join()

    result = self.client.schedule(func, args=(aborted,))
    with self.assertRaises(client.InputError):
      result.fetch()

    with self.assertRaises(client.InputError):
      self.client.join()


class LimitedClosureQueueErrorTest(ErrorReportingTest):
  """Test error reporting works with explicit maximum closure queue size.

  Execute the same set of test cases as in ErrorReportingTest, with an
  explicit size limit for the closure queue.
  """

  @classmethod
  def setUpClass(cls):
    super(LimitedClosureQueueErrorTest, cls).setUpClass()
    client._CLOSURE_QUEUE_MAX_SIZE = 2
    cls.client = make_client(num_workers=3, num_ps=2)

    with cls.client.context():
      cls.iteration = variables.Variable(initial_value=0.0)


if __name__ == "__main__":
  test.main()
tensorflow/python/distribute/parameter_server_strategy_v2.py (new file, 202 lines)
@@ -0,0 +1,202 @@
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Parameter server strategy V2 class.

This is currently under development and the API is subject to change.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import logging
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import parameter_server_strategy
from tensorflow.python.distribute import sharded_variable
from tensorflow.python.framework import ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.util import tf_contextlib


# pylint: disable=protected-access
class ParameterServerStrategyV2(distribute_lib.Strategy):
  """An asynchronous multi-worker parameter server tf.distribute strategy.

  Currently, `ParameterServerStrategyV2` cannot be used as a standalone
  tf.distribute strategy. It must be used in conjunction with `Client`. The
  recommended way of using the combination is through a
  `ParameterServerClient` object. Please see `Client` and
  `ParameterServerClient` for more information.

  This is currently under development, and the API as well as the
  implementation are subject to change.
  """

  def __init__(self, cluster_resolver):
    """Initializes the V2 parameter server strategy.

    Args:
      cluster_resolver: a `tf.distribute.cluster_resolver.ClusterResolver`
        object.
    """
    self._extended = ParameterServerStrategyV2Extended(self, cluster_resolver)
    self._cluster_resolver = cluster_resolver
    self._verify_args_and_config(cluster_resolver)
    logging.info(
        "ParameterServerStrategyV2 is initialized with cluster_spec: "
        "%s", cluster_resolver.cluster_spec())
    super(ParameterServerStrategyV2, self).__init__(self._extended)

  @tf_contextlib.contextmanager
  def experimental_variable_partitioning_scope(self):
    """A context manager for creating `ShardedVariable`.

    Variables created inside a
    `with experimental_variable_partitioning_scope()` code block will be of
    type `ShardedVariable`, and their values are partitioned among the
    parameter servers along the first / outermost axis. The number of shards
    is equal to the number of parameter servers.

    Variables created within this scope must be initialized using a callable
    as `initial_value` and with a known shape.

    The "div" partition strategy is used to partition the variable. Assuming
    we assign consecutive integer ids along the first axis of the variable,
    ids are assigned to shards in a contiguous manner, while attempting to
    keep each shard size identical. If the ids do not evenly divide the
    number of shards, each of the first several shards is assigned one more
    id. For instance, a variable whose first dimension is 13 has 13 ids, and
    they are split across 5 shards as:
    `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`.

    Yields:
      A context manager for creating `ShardedVariable`.
    """
    with variable_scope.variable_creator_scope(
        self._extended._make_sharded_variable_creator()):
      yield

  def _verify_args_and_config(self, cluster_resolver):
    if not cluster_resolver.cluster_spec():
      raise ValueError("Cluster spec must be non-empty in `cluster_resolver`.")
    if self.extended._num_gpus_per_worker > 1:
      raise NotImplementedError("Multi-gpu is not supported yet.")
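A short sketch of the partitioning scope in use (illustrative; it mirrors VariablePartitioningScopeTest above and assumes a client `ps_client`, built as in the tests, over a cluster with 2 parameter servers):

# Hypothetical: create a ShardedVariable under the partitioning scope. With
# 2 parameter servers, the (4, 2) variable is split into two (2, 2) shards,
# one per PS task.
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import init_ops_v2
from tensorflow.python.ops import variables

with ps_client.context():
  with ps_client.experimental_variable_partitioning_scope():
    init = init_ops_v2.Constant([[0, 1], [2, 3], [4, 5], [6, 7]])
    v = variables.Variable(
        # A callable initial_value and an explicit shape are both required.
        initial_value=lambda: init(shape=(4, 2), dtype=dtypes.int64),
        shape=(4, 2),
        dtype=dtypes.int64)

# v is a sharded_variable.ShardedVariable; v.variables[0] lives on
# /job:ps/task:0 and v.variables[1] on /job:ps/task:1.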
class ParameterServerStrategyV2Extended(
    parameter_server_strategy.ParameterServerStrategyExtended):
  """Extended class for ParameterServerStrategyV2.

  Please see `tf.distribute.StrategyExtended` doc for more information.
  """

  def __init__(self, container_strategy, cluster_resolver):
    """Initialization of ParameterServerStrategyV2Extended."""
    super(ParameterServerStrategyV2Extended, self).__init__(container_strategy)
    self._num_ps = len(cluster_resolver.cluster_spec().as_dict().get("ps", []))
    self._variable_count = 0

  def _create_variable(self, next_creator, **kwargs):
    if "colocate_with" in kwargs:
      colocate_with = kwargs["colocate_with"]
      # Clear the variable scope to avoid possible conflicts between device
      # scope and colocation scope.
      with ops.device(None):
        with ops.colocate_with(colocate_with):
          var = next_creator(**kwargs)
          logging.debug(
              "Creating variable (name:%s, shape:%r) that colocates with %s",
              var.name, var.shape, kwargs["colocate_with"].name)
          return var

    # Clear the colocation scope to avoid possible conflicts between device
    # scope and colocation scope. Variables are placed on parameter servers
    # in a round-robin fashion.
    with ops.colocate_with(None, ignore_existing=True):
      with ops.device("/job:ps/task:%d" %
                      (self._variable_count % self._num_ps)):
        var = next_creator(**kwargs)
        logging.debug(
            "Creating variable (name:%s, shape:%r) on /job:ps/task:%d",
            var.name, var.shape, (self._variable_count % self._num_ps))
        self._variable_count += 1
        return var
  def _make_sharded_variable_creator(self):
    """Returns a function conforming to the `variable_creator` signature.

    The returned function creates `ShardedVariable` when called.
    """

    def sharded_variable_creator(next_creator, **kwargs):
      if "shape" not in kwargs or kwargs["shape"] is None:
        raise ValueError("shape must be explicitly specified when creating "
                         "sharded variables")
      init_fn = kwargs.get("initial_value", None)
      # We intentionally don't allow a non-callable initial_value, to ensure
      # the value is created on the PS and not on the client. If the value
      # were created on the client, it would need to be sent to the PS for
      # variable initialization, which is inefficient and can potentially hit
      # the 2GB limit on protobuf serialization.
      if init_fn is None or not callable(init_fn):
        raise ValueError("initial_value must be specified as a callable when "
                         "creating sharded variables")

      # Use the "div" partition strategy to partition the variable.
      full_shape = kwargs["shape"]
      if self._num_ps < full_shape[0]:
        num_shards = self._num_ps
      else:
        num_shards = full_shape[0]
      offsets = []
      base = full_shape[0] // num_shards
      extra = full_shape[0] % num_shards
      for i in range(num_shards):
        if i == 0:
          offsets.append(0)
        else:
          prev_shard_size = base + (1 if i - 1 < extra else 0)
          offsets.append(offsets[i - 1] + prev_shard_size)

      # Note: The way we initialize sharded variables is suboptimal, as it
      # needs to create the full value tensor separately on each PS which the
      # variable is going to be placed on. The full value could be very large
      # and consume a lot of memory. The ideal way is to only create what's
      # needed on the shard; however, that's not practical because:
      # 1. Initializers don't have sharded behavior support, even though some
      #    initializers (e.g., uniform) can be used directly.
      # 2. The tf.Variable signature requires "initial_value" to be either a
      #    value or a callable without arguments, meaning it is not
      #    straightforward to make the sharded component from it.
      def init_shard_fn(shard_index):
        full_value = init_fn()
        if shard_index < num_shards - 1:
          return full_value[offsets[shard_index]:offsets[shard_index + 1]]
        else:
          return full_value[offsets[shard_index]:]

      var_list = []
      for i in range(num_shards):
        kwargs["shape"] = None
        # Bind the current loop index via a default argument: the creator
        # invokes the callable immediately, but early binding keeps the
        # callable correct even if invocation were deferred.
        kwargs["initial_value"] = lambda i=i: init_shard_fn(i)
        var_list.append(next_creator(**kwargs))

      result = sharded_variable.ShardedVariable(var_list)
      return result

    return sharded_variable_creator

  def _call_for_each_replica(self, fn, args, kwargs):
    # TODO(rchao): Consider implementing sync PS training.
    raise NotImplementedError("Sync PS training is not implemented yet.")
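As a quick sanity check of the "div" offsets computation above (a standalone sketch, not part of the commit), a first dimension of 13 split over 5 parameter servers reproduces the docstring's example:

# Standalone check of the "div" partition offsets: 13 ids over 5 shards gives
# shard sizes [3, 3, 3, 2, 2], i.e. offsets [0, 3, 6, 9, 11], matching the
# docstring example [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]].
def div_offsets(dim0, num_ps):
  num_shards = min(num_ps, dim0)
  base, extra = divmod(dim0, num_shards)
  offsets = [0]
  for i in range(1, num_shards):
    prev_shard_size = base + (1 if i - 1 < extra else 0)
    offsets.append(offsets[i - 1] + prev_shard_size)
  return offsets

assert div_offsets(13, 5) == [0, 3, 6, 9, 11]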
tensorflow/tools/pip_package/BUILD
@@ -150,6 +150,9 @@ COMMON_PIP_DEPS = [
    "//tensorflow/tools/docs:generate_lib",
    "//tensorflow/tools/docs:parser",
    "//tensorflow/tools/docs:py_guide_parser",
    "//tensorflow/python/distribute/client:client",
    "//tensorflow/python/distribute/client:parameter_server_client",
    "//tensorflow/python/distribute/client:metric_utils",
]

# On Windows, python binary is a zip file of runfiles tree.