From f5547e8125d92b39331cbc73309eede3bd6b5f2e Mon Sep 17 00:00:00 2001
From: Yujing Zhang
Date: Wed, 10 Jun 2020 14:33:04 -0700
Subject: [PATCH] Introduce PackedDistributedVariable which packs multiple
 variables distributed across devices. Introduce PackedVarAndDevice which
 represents a packed variable on a given device.

PiperOrigin-RevId: 315769635
Change-Id: Ia63b72610afeb7139bd8370bc47067a1fb165307
---
 tensorflow/python/distribute/BUILD            |  31 ++
 .../distribute/packed_distributed_variable.py | 338 ++++++++++++++++++
 .../packed_distributed_variable_test.py       | 107 ++++++
 tensorflow/python/distribute/values.py        |  12 +
 tensorflow/python/distribute/values_test.py   |  23 ++
 .../python/ops/resource_variable_ops.py       |   3 +
 6 files changed, 514 insertions(+)
 create mode 100644 tensorflow/python/distribute/packed_distributed_variable.py
 create mode 100644 tensorflow/python/distribute/packed_distributed_variable_test.py

diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD
index 977452cdad1..7451d5b0408 100644
--- a/tensorflow/python/distribute/BUILD
+++ b/tensorflow/python/distribute/BUILD
@@ -675,12 +675,25 @@ py_library(
     ],
 )
 
+py_library(
+    name = "packed_distributed_variable",
+    srcs = ["packed_distributed_variable.py"],
+    deps = [
+        ":device_util",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:resource_variable_ops",
+        "//tensorflow/python/eager:context",
+    ],
+)
+
 py_library(
     name = "values",
     srcs = ["values.py"],
     deps = [
         ":device_util",
         ":distribute_lib",
+        ":packed_distributed_variable",
         ":reduce_util",
         ":values_util",
         "//tensorflow/python:array_ops",
@@ -1710,6 +1723,24 @@ py_library(
     deps = ["@six_archive//:six"],
 )
 
+py_test(
+    name = "packed_distributed_variable_test",
+    srcs = ["packed_distributed_variable_test.py"],
+    tags = [
+        "nomac",  # TODO(b/145922293): This test causes a Python segfault on macOS.
+    ],
+    deps = [
+        ":device_util",
+        ":packed_distributed_variable",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:resource_variable_ops",
+        "//tensorflow/python/eager:context",
+        "//tensorflow/python/eager:def_function",
+    ],
+)
+
 py_test(
     name = "multi_process_runner_test",
     srcs = ["multi_process_runner_test.py"],
diff --git a/tensorflow/python/distribute/packed_distributed_variable.py b/tensorflow/python/distribute/packed_distributed_variable.py
new file mode 100644
index 00000000000..62512cb4414
--- /dev/null
+++ b/tensorflow/python/distribute/packed_distributed_variable.py
@@ -0,0 +1,338 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A variable which packs a list of variables distributed across devices."""

+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.distribute import device_util
+from tensorflow.python.eager import context
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
+
+
+class PackedDistributedVariable(resource_variable_ops.BaseResourceVariable):
+  """A variable which packs multiple variables distributed across devices.
+
+  Packing is supported only when eager execution is enabled.
+  For op-by-op execution, it uses the unpacked handle on the current device;
+  for function execution, it uses the packed handle to reduce the overhead of
+  function calls.
+  """
+
+  def __init__(self, distributed_variables=None, name=None, **unused_kwargs):
+    """Packs a list of variables which are distributed across devices.
+
+    Args:
+      distributed_variables: A list of distributed Variables to pack.
+      name: Optional name for the variable. Defaults to `'Variable'` and gets
+        uniquified automatically.
+    """
+    if not context.executing_eagerly():
+      raise ValueError(
+          "PackedDistributedVariable should be created in eager mode.")
+    if not distributed_variables:
+      raise ValueError("Expected a non-empty list of variables to pack.")
+    for i, var in enumerate(distributed_variables):
+      if not resource_variable_ops.is_resource_variable(var):
+        raise ValueError("Expected a list of ResourceVariables to pack, "
+                         "but the %d-th variable is %s" % (i, type(var)))
+
+    self._distributed_variables = distributed_variables
+    self._devices = [v.device for v in distributed_variables]
+    with ops.init_scope():
+      with ops.name_scope(name, "Variable", skip_on_eager=False) as name:
+        handle = ops.pack_eager_tensors(
+            [var.handle for var in distributed_variables])
+        handle_name = ops.name_from_scope_name(name)
+        unique_id = "%s_%d" % (handle_name, ops.uid())
+      super(PackedDistributedVariable, self).__init__(
+          trainable=distributed_variables[0].trainable,
+          shape=distributed_variables[0].shape,
+          dtype=distributed_variables[0].dtype,
+          handle=handle,
+          synchronization=distributed_variables[0].synchronization,
+          constraint=distributed_variables[0].constraint,
+          aggregation=distributed_variables[0].aggregation,
+          distribute_strategy=distributed_variables[0]._distribute_strategy,  # pylint: disable=protected-access
+          name=name,
+          unique_id=unique_id,
+          handle_name=handle_name,
+          graph_element=None,
+          initial_value=None,
+          initializer_op=None,
+          is_initialized_op=None,
+          cached_value=None,
+          caching_device=None,
+          is_distributed_variables=True)
+
+  @property
+  def devices(self):
+    return self._devices
+
+  def get_var_on_device(self, device):
+    for i, d in enumerate(self._devices):
+      if d == device:
+        return self._distributed_variables[i]
+    raise ValueError("Device %s is not found." % device)
+
+  def get_var_on_current_device(self):
+    current_device = device_util.canonicalize(device_util.current())
+    return self.get_var_on_device(current_device)
+
+  def initial_value(self, device):
+    """Returns the Tensor used as the initial value for the variable."""
+    return self.get_var_on_device(device).initial_value
+
+  @property
+  def handle(self):
+    return self._handle
+
+  def _read_variable_op(self):
+    if context.executing_eagerly():
+      return self.get_var_on_current_device().value()
+    else:
+      return 
super(PackedDistributedVariable, self)._read_variable_op() + + def value(self): + return self._read_variable_op() + + def is_initialized(self, name=None): + if context.executing_eagerly(): + result = self._distributed_variables[0].is_initialized() + for v in self._distributed_variables[1:-1]: + result = math_ops.logical_and(result, v.is_initialized()) + result = math_ops.logical_and( + result, self._distributed_variables[-1].is_initialized(), name=name) + else: + with ops.device(self._devices[0]): + result = super(PackedDistributedVariable, self).is_initialized(name) + for d in self._devices[1:-1]: + with ops.device(d): + initialized = super(PackedDistributedVariable, + self).is_initialized(name) + result = math_ops.logical_and(result, initialized) + with ops.device(self._devices[-1]): + initialized = super(PackedDistributedVariable, + self).is_initialized(name) + result = math_ops.logical_and(result, initialized, name=name) + return result + + def _update(self, update_fn, value, **kwargs): + if context.executing_eagerly(): + return update_fn(self.get_var_on_current_device(), value, **kwargs) + else: + return update_fn(super(PackedDistributedVariable, self), value, **kwargs) + + def assign_sub(self, delta, use_locking=None, name=None, read_value=True): + assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw) + return self._update( + update_fn=assign_sub_fn, + value=delta, + use_locking=use_locking, + name=name, + read_value=read_value) + + def assign_add(self, delta, use_locking=None, name=None, read_value=True): + assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw) + return self._update( + update_fn=assign_add_fn, + value=delta, + use_locking=use_locking, + name=name, + read_value=read_value) + + def assign(self, value, use_locking=None, name=None, read_value=True): + assign_fn = lambda var, *a, **kw: var.assign(*a, **kw) + return self._update( + update_fn=assign_fn, + value=value, + use_locking=use_locking, + name=name, + read_value=read_value) + + def scatter_sub(self, sparse_delta, use_locking=False, name=None): + scatter_sub_fn = lambda var, *a, **kw: var.scatter_sub(*a, **kw) + return self._update( + update_fn=scatter_sub_fn, + value=sparse_delta, + use_locking=use_locking, + name=name) + + def scatter_add(self, sparse_delta, use_locking=False, name=None): + scatter_add_fn = lambda var, *a, **kw: var.scatter_add(*a, **kw) + return self._update( + update_fn=scatter_add_fn, + value=sparse_delta, + use_locking=use_locking, + name=name) + + def scatter_mul(self, sparse_delta, use_locking=False, name=None): + scatter_mul_fn = lambda var, *a, **kw: var.scatter_mul(*a, **kw) + return self._update( + update_fn=scatter_mul_fn, + value=sparse_delta, + use_locking=use_locking, + name=name) + + def scatter_div(self, sparse_delta, use_locking=False, name=None): + scatter_div_fn = lambda var, *a, **kw: var.scatter_div(*a, **kw) + return self._update( + update_fn=scatter_div_fn, + value=sparse_delta, + use_locking=use_locking, + name=name) + + def scatter_min(self, sparse_delta, use_locking=False, name=None): + scatter_min_fn = lambda var, *a, **kw: var.scatter_min(*a, **kw) + return self._update( + update_fn=scatter_min_fn, + value=sparse_delta, + use_locking=use_locking, + name=name) + + def scatter_max(self, sparse_delta, use_locking=False, name=None): + scatter_max_fn = lambda var, *a, **kw: var.scatter_max(*a, **kw) + return self._update( + update_fn=scatter_max_fn, + value=sparse_delta, + use_locking=use_locking, + name=name) + + def scatter_update(self, sparse_delta, 
use_locking=False, name=None): + scatter_update_fn = lambda var, *a, **kw: var.scatter_update(*a, **kw) + return self._update( + update_fn=scatter_update_fn, + value=sparse_delta, + use_locking=use_locking, + name=name) + + def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False): + if context.executing_eagerly(): + return self.get_var_on_current_device()._dense_var_to_tensor( # pylint: disable=protected-access + dtype=dtype, + name=name, + as_ref=as_ref) + else: + return super(PackedDistributedVariable, self)._dense_var_to_tensor( # pylint: disable=protected-access + dtype=dtype, + name=name, + as_ref=as_ref) + + +class PackedVarAndDevice(object): + """Holds a packed distributed variable and a device.""" + + def __init__(self, var, device): + self._var = var + self._device = device + + def __getattr__(self, name): + return getattr(self._var, name) + + def var(self): + return self._var + + def value(self): + with ops.device(self._device): + return self._var.value() + + def read_value(self): + with ops.device(self._device): + return self._var.read_value() + + @property + def initial_value(self): + return self._var.initial_value(self._device) + + def initialized_value(self): + with ops.device(self._device): + return self._var.initialized_value() + + @property + def device(self): + return self._device + + @property + def handle(self): + return self._var.handle + + @property + def op(self): + with ops.device(self._device): + return self._var.op + + def assign_sub(self, delta, use_locking=None, name=None, read_value=True): + with ops.device(self._device): + return self._var.assign_sub(delta, use_locking, name, read_value) + + def assign_add(self, delta, use_locking=None, name=None, read_value=True): + with ops.device(self._device): + return self._var.assign_add(delta, use_locking, name, read_value) + + def assign(self, value, use_locking=None, name=None, read_value=True): + with ops.device(self._device): + return self._var.assign(value, use_locking, name, read_value) + + def scatter_sub(self, sparse_delta, use_locking=False, name=None): + with ops.device(self._device): + return self._var.scatter_sub(sparse_delta, use_locking, name) + + def scatter_add(self, sparse_delta, use_locking=False, name=None): + with ops.device(self._device): + return self._var.scatter_add(sparse_delta, use_locking, name) + + def scatter_mul(self, sparse_delta, use_locking=False, name=None): + with ops.device(self._device): + return self._var.scatter_mul(sparse_delta, use_locking, name) + + def scatter_div(self, sparse_delta, use_locking=False, name=None): + with ops.device(self._device): + return self._var.scatter_div(sparse_delta, use_locking, name) + + def scatter_min(self, sparse_delta, use_locking=False, name=None): + with ops.device(self._device): + return self._var.scatter_min(sparse_delta, use_locking, name) + + def scatter_max(self, sparse_delta, use_locking=False, name=None): + with ops.device(self._device): + return self._var.scatter_max(sparse_delta, use_locking, name) + + def scatter_update(self, sparse_delta, use_locking=False, name=None): + with ops.device(self._device): + return self._var.scatter_update(sparse_delta, use_locking, name) + + def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False): + with ops.device(self._device): + return self._var._dense_var_to_tensor( # pylint: disable=protected-access + dtype=dtype, + name=name, + as_ref=as_ref) + + def _as_graph_element(self): + return self._var._as_graph_element() # pylint: disable=protected-access + + +def 
_tensor_conversion_packed_var_and_device(var, + dtype=None, + name=None, + as_ref=False): + return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref) # pylint: disable=protected-access + + +ops.register_tensor_conversion_function( + PackedVarAndDevice, _tensor_conversion_packed_var_and_device) diff --git a/tensorflow/python/distribute/packed_distributed_variable_test.py b/tensorflow/python/distribute/packed_distributed_variable_test.py new file mode 100644 index 00000000000..d29d19960a5 --- /dev/null +++ b/tensorflow/python/distribute/packed_distributed_variable_test.py @@ -0,0 +1,107 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.distribute import device_util +from tensorflow.python.distribute import packed_distributed_variable +from tensorflow.python.eager import context +from tensorflow.python.eager import def_function +from tensorflow.python.framework import config +from tensorflow.python.framework import ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.platform import test + + +class PackedDistributedVariableTest(test.TestCase): + + def setUp(self): + super(PackedDistributedVariableTest, self).setUp() + cpus = config.list_physical_devices('CPU') + # Set 2 virtual CPUs + config.set_logical_device_configuration(cpus[0], [ + context.LogicalDeviceConfiguration(), + context.LogicalDeviceConfiguration(), + ]) + + def testPackedVariable(self): + with ops.device('/cpu:0'): + v0 = resource_variable_ops.ResourceVariable(1.0, name='var0') + with ops.device('/cpu:1'): + v1 = resource_variable_ops.ResourceVariable(2.0, name='var1') + + packed_var = packed_distributed_variable.PackedDistributedVariable([v0, v1]) + self.assertTrue(packed_var.handle.is_packed) + self.assertTrue(packed_var.is_initialized) + + with ops.device('/cpu:0'): + self.assertAllEqual(packed_var.get_var_on_current_device(), v0) + val0 = packed_var.assign(2.0).assign_add(1.0) + self.assertAllEqual(val0, 3.0) + + with ops.device('/cpu:1'): + self.assertAllEqual(packed_var.get_var_on_current_device(), v1) + val0 = packed_var.assign(2.0).assign_add(1.0) + self.assertAllEqual(val0, 3.0) + + @def_function.function + def update_var(): + with ops.device('/cpu:0'): + packed_var.assign_add(3.0).assign_sub(1.0) + read0 = packed_var.value() + with ops.device('/cpu:1'): + packed_var.assign_sub(4.0).assign_sub(2.0) + read1 = packed_var.value() + + return read0, read1 + + self.assertAllEqual(update_var(), (5.0, -3.0)) + + def testPackedVarAndDevice(self): + device0 = device_util.canonicalize('/cpu:0') + device1 = device_util.canonicalize('/cpu:1') + + with ops.device(device0): + v0 = resource_variable_ops.ResourceVariable(1.0) + with ops.device(device1): + v1 = 
resource_variable_ops.ResourceVariable(2.0)
+
+    packed_var = packed_distributed_variable.PackedDistributedVariable([v0, v1])
+
+    packed_var0 = packed_distributed_variable.PackedVarAndDevice(
+        packed_var, device0)
+    self.assertTrue(packed_var0.handle.is_packed)
+    self.assertAllEqual(math_ops.mul(packed_var0, 2.0), 2.0)
+
+    packed_var1 = packed_distributed_variable.PackedVarAndDevice(
+        packed_var, device1)
+    self.assertAllEqual(packed_var1.assign(3.0), 3.0)
+
+    @def_function.function
+    def func():
+      var0 = packed_distributed_variable.PackedVarAndDevice(packed_var, device0)
+      var0.assign_add(3.0)
+      var1 = packed_distributed_variable.PackedVarAndDevice(packed_var, device1)
+      return var0.value(), math_ops.add(var1, 2.0)
+
+    self.assertAllEqual(func(), (4.0, 5.0))
+
+
+if __name__ == '__main__':
+  ops.enable_eager_execution()
+  test.main()
diff --git a/tensorflow/python/distribute/values.py b/tensorflow/python/distribute/values.py
index dc3aac57e38..90210e9041e 100644
--- a/tensorflow/python/distribute/values.py
+++ b/tensorflow/python/distribute/values.py
@@ -22,6 +22,7 @@ from __future__ import print_function
 from tensorflow.python.distribute import device_util
 from tensorflow.python.distribute import distribute_lib
 from tensorflow.python.distribute import distribution_strategy_context as ds_context
+from tensorflow.python.distribute import packed_distributed_variable as packed
 from tensorflow.python.distribute import reduce_util
 from tensorflow.python.distribute import values_util
 from tensorflow.python.eager import context
@@ -419,6 +420,17 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
     self._aggregation = aggregation
     super(DistributedVariable, self).__init__(values)
     self._common_name = self._primary.name.split(":")[0]
+
+    # A packed variable is used to reduce the overhead of function execution:
+    # for a DistributedVariable, only the single packed handle is captured
+    # into a function graph. Packing is only supported in eager mode.
+    if ops.executing_eagerly_outside_functions() and getattr(
+        strategy, "_enable_packed_variable_in_eager_mode", False):
+      name = "%s/packed/" % self._common_name
+      self._packed_var = packed.PackedDistributedVariable(values, name=name)
+    else:
+      self._packed_var = None
+
     # tf.keras keeps track of variables initialized using this attribute. When
     # tf.keras gets the default session, it initializes all uninitialized vars.
# We need to make _keras_initialized a member of DistributedVariable because diff --git a/tensorflow/python/distribute/values_test.py b/tensorflow/python/distribute/values_test.py index e91d89abc7d..8ac779f17c0 100644 --- a/tensorflow/python/distribute/values_test.py +++ b/tensorflow/python/distribute/values_test.py @@ -31,6 +31,7 @@ from tensorflow.python import tf2 from tensorflow.python.distribute import combinations from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import distribution_strategy_context +from tensorflow.python.distribute import packed_distributed_variable as packed from tensorflow.python.distribute import strategy_combinations from tensorflow.python.distribute import tpu_strategy from tensorflow.python.distribute import tpu_values @@ -40,6 +41,7 @@ from tensorflow.python.eager import context from tensorflow.python.eager import def_function from tensorflow.python.eager import test from tensorflow.python.framework import constant_op +from tensorflow.python.framework import device from tensorflow.python.framework import dtypes from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops @@ -703,6 +705,27 @@ class DistributedVariableTest(test.TestCase, parameterized.TestCase): self.evaluate( distribution.experimental_local_results(distribution.run(assign))) + def testPackedVariable(self, distribution, synchronization, aggregation): + with distribution.scope(): + v0 = variables_lib.Variable( + 0., synchronization=synchronization, aggregation=aggregation) + self.assertEqual(v0._packed_var, None) + + device_type = device.DeviceSpec.from_string(v0._devices[0]).device_type + for d in v0._devices: + if device.DeviceSpec.from_string(d).device_type != device_type: + self.skipTest("Packing variables on devices of different types " + "is not supported yet.") + + distribution._enable_packed_variable_in_eager_mode = True + with distribution.scope(): + v1 = variables_lib.Variable( + 0., synchronization=synchronization, aggregation=aggregation) + if ops.executing_eagerly_outside_functions(): + self.assertIsInstance(v1._packed_var, packed.PackedDistributedVariable) + else: + self.assertEqual(v1._packed_var, None) + class MirroredVariableTest(test.TestCase, parameterized.TestCase): diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py index a9e146595ce..25f6347f034 100644 --- a/tensorflow/python/ops/resource_variable_ops.py +++ b/tensorflow/python/ops/resource_variable_ops.py @@ -281,6 +281,9 @@ class EagerResourceDeleter(object): # Resources follow object-identity when executing eagerly, so it is safe to # delete the resource we have a handle to. try: + # A packed EagerTensor doesn't own any resource. + if isinstance(self._handle, ops.EagerTensor) and self._handle.is_packed: + return # This resource was created in eager mode. However, this destructor may be # running in graph mode (especially during unit tests). To clean up # successfully, we switch back into eager mode temporarily.
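
Usage sketch (not part of the patch): a minimal eager-mode walkthrough of the
API introduced above, assembled from the unit tests in this change. The
two-virtual-CPU setup and device strings mirror
packed_distributed_variable_test.py; anything beyond what those tests show is
an assumption, not the canonical API surface.

    from tensorflow.python.distribute import device_util
    from tensorflow.python.distribute import packed_distributed_variable
    from tensorflow.python.eager import context
    from tensorflow.python.framework import config
    from tensorflow.python.framework import ops
    from tensorflow.python.ops import resource_variable_ops

    # Packed variables require eager execution (enforced in __init__).
    ops.enable_eager_execution()

    # Expose two virtual CPUs so there are two devices to pack across,
    # as in the test's setUp.
    cpus = config.list_physical_devices('CPU')
    config.set_logical_device_configuration(cpus[0], [
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration(),
    ])

    # One component ResourceVariable per device.
    with ops.device('/cpu:0'):
      v0 = resource_variable_ops.ResourceVariable(1.0)
    with ops.device('/cpu:1'):
      v1 = resource_variable_ops.ResourceVariable(2.0)

    # Pack the per-device handles into a single packed handle.
    packed_var = packed_distributed_variable.PackedDistributedVariable([v0, v1])
    assert packed_var.handle.is_packed

    # Op-by-op, the packed variable dispatches to the component variable on
    # the current device, so this updates v0 only (v0 becomes 4.0).
    with ops.device('/cpu:0'):
      packed_var.assign_add(3.0)

    # PackedVarAndDevice pins the packed variable to one device; inside a
    # tf.function only the single packed handle is captured, instead of one
    # handle per device.
    device0 = device_util.canonicalize('/cpu:0')
    var0 = packed_distributed_variable.PackedVarAndDevice(packed_var, device0)
    print(var0.value())  # 4.0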