From cb2cf6f56fc538c935c422a32eb1fab530e170b3 Mon Sep 17 00:00:00 2001 From: Ran Chen Date: Wed, 29 Jul 2020 22:23:50 -0700 Subject: [PATCH] Add back the lock in CollectiveKeys MWMS is using CollectiveKeys directly when broadcasting variable initial values, so the lock in CollectiveAllReduce is not enough. This change also acquires the lock in all methods of CollectiveKeys, instead of only in get_group_key(). PiperOrigin-RevId: 323938375 Change-Id: I15ea98ff62952d0c3bd4d33f74067b4bad03d7cb --- .../python/distribute/cross_device_utils.py | 40 ++++++++++++++----- 1 file changed, 29 insertions(+), 11 deletions(-) diff --git a/tensorflow/python/distribute/cross_device_utils.py b/tensorflow/python/distribute/cross_device_utils.py index 1d5c2c8f452..dc6dc4071bd 100644 --- a/tensorflow/python/distribute/cross_device_utils.py +++ b/tensorflow/python/distribute/cross_device_utils.py @@ -19,6 +19,8 @@ from __future__ import division from __future__ import print_function import collections as pycoll +import copy +import threading from tensorflow.python.distribute import all_reduce from tensorflow.python.distribute import values as value_lib @@ -244,6 +246,8 @@ class CollectiveKeys(object): "Graph key": an integer key that is unique key graph. This is used to support multiple graphs per client session. It must be non-zero and set in the `config` argument of each call to `session.run`. + + This class is thread safe. """ def __init__(self, @@ -264,6 +268,7 @@ class CollectiveKeys(object): assert op_instance_key_start != variable_instance_key_start self._op_instance_key = op_instance_key_start self._variable_instance_key = variable_instance_key_start + self._lock = threading.Lock() def get_group_key(self, devices): """Returns a group key for the set of devices. @@ -282,23 +287,36 @@ class CollectiveKeys(object): # task_type and task_id. 
names = sorted(['%s:%d' % (d.device_type, d.device_index) for d in parsed]) key_id = ','.join(names) - if key_id not in self._group_key_table: - new_key = self._group_key - self._group_key += 1 - self._group_key_table[key_id] = new_key - return self._group_key_table[key_id] + with self._lock: + if key_id not in self._group_key_table: + new_key = self._group_key + self._group_key += 1 + self._group_key_table[key_id] = new_key + return self._group_key_table[key_id] def get_op_instance_key(self): """Returns a new instance key for use in defining a collective op.""" - v = self._op_instance_key - self._op_instance_key += 1 - return v + with self._lock: + v = self._op_instance_key + self._op_instance_key += 1 + return v def get_variable_instance_key(self): """Returns a new instance key for use in creating a Variable.""" - v = self._variable_instance_key - self._variable_instance_key += 1 - return v + with self._lock: + v = self._variable_instance_key + self._variable_instance_key += 1 + return v + + def __deepcopy__(self, memo): + # distribute_coordinator deep-copies the strategy object, so + # CollectiveKeys needs to support deep copy as well. + copied = CollectiveKeys() + copied._group_key = self._group_key + copied._group_key_table = copy.deepcopy(self._group_key_table, memo) + copied._op_instance_key = self._op_instance_key + copied._variable_instance_key = self._variable_instance_key + return copied def build_collective_reduce(input_tensors,