From d44cb28478dcde4e6516556898fc85994c2ffded Mon Sep 17 00:00:00 2001 From: Allen Lavoie Date: Mon, 28 Sep 2020 09:11:49 -0700 Subject: [PATCH] Parallel device: fix variable initialization in tf.function Switches ParallelDevice variables to be compatible with the tf.function variable creator scope, and adds a special case to handle conditional initialization of parallel variables. Adds TPU tests for the parallel device since that's a major constraint on the implementation (no uninitialized input to tf.cond). Rolling forward with some branching logic for Windows (may not be Windows-specific, but whatever combination of packages we test with there). PiperOrigin-RevId: 334170699 Change-Id: I541655bd8a116d013a5a3f62b645aa7242411a40 --- .../python/distribute/parallel_device/BUILD | 7 +- .../parallel_device/parallel_device.py | 12 +++ .../parallel_device/parallel_device_test.py | 80 +++++++++++++++---- .../distribute/parallel_device/saving.py | 71 ++++++++++------ tensorflow/python/eager/BUILD | 1 + tensorflow/python/eager/def_function.py | 53 +++++++++--- 6 files changed, 172 insertions(+), 52 deletions(-) diff --git a/tensorflow/python/distribute/parallel_device/BUILD b/tensorflow/python/distribute/parallel_device/BUILD index 5fc294a5f5c..331e0a3c3af 100644 --- a/tensorflow/python/distribute/parallel_device/BUILD +++ b/tensorflow/python/distribute/parallel_device/BUILD @@ -1,3 +1,5 @@ +load("//tensorflow/core/platform/default:distribute.bzl", "distribute_py_test") + package( default_visibility = ["//tensorflow:internal"], licenses = ["notice"], # Apache 2.0 @@ -17,6 +19,7 @@ py_library( ":saving", "//tensorflow/python:_pywrap_parallel_device", "//tensorflow/python/distribute:device_util", + "//tensorflow/python/tpu:tpu_ops", ], ) @@ -27,15 +30,13 @@ py_library( deps = ["//tensorflow/python:framework_ops"], ) -py_test( +distribute_py_test( name = "parallel_device_test", srcs = ["parallel_device_test.py"], python_version = "PY3", tags = [ # Dependencies aren't otherwise included in the pip package yet. "no_pip", - # MRO broken; needs investigation - "no_windows", ], deps = [ ":parallel_device", diff --git a/tensorflow/python/distribute/parallel_device/parallel_device.py b/tensorflow/python/distribute/parallel_device/parallel_device.py index 30381e2a95d..218bd68d824 100644 --- a/tensorflow/python/distribute/parallel_device/parallel_device.py +++ b/tensorflow/python/distribute/parallel_device/parallel_device.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import threading +import weakref from tensorflow.python import _pywrap_parallel_device from tensorflow.python.distribute import device_util @@ -32,6 +33,16 @@ from tensorflow.python.tpu.ops import tpu_ops _next_device_number = 0 _next_device_number_lock = threading.Lock() +_all_parallel_devices = weakref.WeakValueDictionary() + + +def unpack(tensor): + """Finds `tensor`'s parallel device and unpacks its components.""" + parallel_device = _all_parallel_devices.get(tensor.device, None) + if parallel_device is None: + raise ValueError("{} is not a parallel device".format(tensor.device)) + return parallel_device.unpack(tensor) + # TODO(allenl): Expand this docstring once things like getting components on and # off the device are stable. @@ -67,6 +78,7 @@ class ParallelDevice(object): self._device_ids = None self._device_scope = None self._saving_scope = None + _all_parallel_devices[self._name] = self def pack(self, tensors): """Create a tensor on the parallel device from a sequence of tensors. 
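For context, the module-level unpack helper added above lets callers recover a packed tensor's components without holding a reference to the ParallelDevice object: the device is looked up in the weak registry keyed on the tensor's device name. A minimal usage sketch, not part of the patch, assuming two logical CPUs have been configured the same way the test's setUp does:

    from tensorflow.python.distribute.parallel_device import parallel_device
    from tensorflow.python.eager import context
    from tensorflow.python.framework import constant_op

    # Assumption: split the first physical CPU into two logical CPUs so that
    # both "CPU:0" and "CPU:1" exist (mirrors _VirtualDeviceTestCase.setUp).
    ctx = context.context()
    cpus = ctx.list_physical_devices("CPU")
    ctx.set_logical_device_configuration(cpus[0], [
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration(),
    ])

    device = parallel_device.ParallelDevice(
        components=["/job:localhost/device:CPU:0", "CPU:1"])
    with device:
      x = device.pack([constant_op.constant(1.), constant_op.constant(2.)])
      y = x * x  # Executes once per component.
      # Method-based unpacking, as before.
      print([t.numpy() for t in device.unpack(y)])           # [1.0, 4.0]
      # New with this change: look the device up from the tensor itself.
      print([t.numpy() for t in parallel_device.unpack(y)])  # [1.0, 4.0]
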
diff --git a/tensorflow/python/distribute/parallel_device/parallel_device_test.py b/tensorflow/python/distribute/parallel_device/parallel_device_test.py index 066f9ea376c..cf86a7362fb 100644 --- a/tensorflow/python/distribute/parallel_device/parallel_device_test.py +++ b/tensorflow/python/distribute/parallel_device/parallel_device_test.py @@ -93,19 +93,30 @@ class _VirtualDeviceTestCase(test.TestCase): def setUp(self): super(_VirtualDeviceTestCase, self).setUp() - cpus = context.context().list_physical_devices("CPU") - # Set 4 virtual CPUs - context.context().set_logical_device_configuration(cpus[0], [ - context.LogicalDeviceConfiguration(), - context.LogicalDeviceConfiguration(), - context.LogicalDeviceConfiguration(), - context.LogicalDeviceConfiguration() - ]) + ctx = context.context() + if ctx.list_physical_devices("TPU"): + self.device_type = "TPU" + elif ctx.list_physical_devices("GPU"): + self.device_type = "GPU" + gpus = ctx.list_physical_devices(self.device_type) + ctx.set_logical_device_configuration(gpus[0], [ + context.LogicalDeviceConfiguration(memory_limit=100), + context.LogicalDeviceConfiguration(memory_limit=100), + ]) + else: + self.device_type = "CPU" + cpus = ctx.list_physical_devices("CPU") + ctx.set_logical_device_configuration(cpus[0], [ + context.LogicalDeviceConfiguration(), + context.LogicalDeviceConfiguration(), + ]) - self.device = parallel_device.ParallelDevice( - components=["/job:localhost/device:CPU:0", "CPU:1"]) - self.assertIn("CPU:0", self.device.components[0]) - self.assertIn("CPU:1", self.device.components[1]) + self.device = parallel_device.ParallelDevice(components=[ + "/job:localhost/device:{}:0".format(self.device_type), + self.device_type + ":1" + ]) + self.assertIn(self.device_type + ":0", self.device.components[0]) + self.assertIn(self.device_type + ":1", self.device.components[1]) class ParallelDeviceTests(_VirtualDeviceTestCase): @@ -124,10 +135,14 @@ class ParallelDeviceTests(_VirtualDeviceTestCase): def test_device_id(self): device_ids = self.device.unpack(self.device.device_ids) self.assertAllClose([0, 1], device_ids) - self.assertIn(self.device.components[0], device_ids[0].backing_device) - self.assertIn(self.device.components[1], device_ids[1].backing_device) + # TODO(allenl): Should device IDs be int64 so they can be placed on GPUs? + # Currently backing_device is CPU. + self.assertIn(self.device.components[0], device_ids[0].device) + self.assertIn(self.device.components[1], device_ids[1].device) def test_collective_reduce(self): + if self.device_type == "TPU": + self.skipTest("ParallelDevice collectives on TPUs need work") with self.device: x = self.device.pack( [constant_op.constant(-1.5), @@ -139,6 +154,8 @@ class ParallelDeviceTests(_VirtualDeviceTestCase): self.assertIn(self.device.components[1], outputs[1].backing_device) def test_collective_reduce_async_scope(self): + if self.device_type == "TPU": + self.skipTest("ParallelDevice collectives on TPUs need work") # Note that ops on the parallel device currently don't execute # asynchronously. The test is just that we don't get deadlocks. 
with context.async_scope(), self.device: @@ -152,6 +169,8 @@ class ParallelDeviceTests(_VirtualDeviceTestCase): self.assertIn(self.device.components[1], outputs[1].backing_device) def test_collective_reduce_async_context(self): + if self.device_type == "TPU": + self.skipTest("ParallelDevice collectives on TPUs need work") previous = config.get_synchronous_execution() try: context._reset_context() @@ -173,6 +192,8 @@ class ParallelDeviceTests(_VirtualDeviceTestCase): config.set_synchronous_execution(previous) def test_collective_in_function(self): + if self.device_type == "TPU": + self.skipTest("ParallelDevice collectives on TPUs need work") c = constant_op.constant([2]) @def_function.function @@ -313,6 +334,33 @@ class ParallelDeviceTests(_VirtualDeviceTestCase): return y, tape.gradient(y, x) self._assert_close_to_non_parallel(_test_fn) + def test_variable_created_in_function(self): + + class M(module.Module): + + def __init__(self): + self.v = None + self.w = None + self.x = None + self.z = None + + @def_function.function(autograph=False) + def __call__(self, x): + if self.v is None: + with ops.init_scope(): + initial_value = constant_op.constant(2.) + self.z = variables.Variable(initial_value) + self.x = variables.Variable(initial_value) + self.w = variables.Variable(lambda: constant_op.constant(2.)) + self.v = variables.Variable(constant_op.constant(2.)) + return x * self.v * self.w * self.x * self.z + + with self.device: + m = M() + packed_outputs = m(array_ops.ones([])) + outputs = self.device.unpack(packed_outputs) + self.assertAllClose([16., 16.], outputs) + class LayerTests(_VirtualDeviceTestCase): @@ -340,6 +388,8 @@ class LayerTests(_VirtualDeviceTestCase): self.assertIn(self.device.components[1], outputs[1].backing_device) def test_layer_sync_training(self): + if self.device_type == "TPU": + self.skipTest("ParallelDevice collectives on TPUs need work") with self.device: layer = _Dense(5) @@ -389,6 +439,8 @@ class LayerTests(_VirtualDeviceTestCase): self.assertIn(self.device.components[1], final_kernels[1].backing_device) def test_training_loop(self): + if self.device_type == "TPU": + self.skipTest("ParallelDevice collectives on TPUs need work") for _ in range(5): layer = _Dense(5) checkpoint = tracking.Checkpoint(layer=layer) diff --git a/tensorflow/python/distribute/parallel_device/saving.py b/tensorflow/python/distribute/parallel_device/saving.py index f1539e49651..5fdd7ae5d3a 100644 --- a/tensorflow/python/distribute/parallel_device/saving.py +++ b/tensorflow/python/distribute/parallel_device/saving.py @@ -20,6 +20,8 @@ from __future__ import print_function import contextlib import functools +import six +import wrapt from tensorflow.python.ops import gen_resource_variable_ops from tensorflow.python.ops import resource_variable_ops @@ -47,14 +49,32 @@ class _ParallelComponentSaveable(saveable_object.SaveableObject): resource=self._handle, value=restored_tensor) -class ParallelSavingMixin(resource_variable_ops.BaseResourceVariable): - """Mixin to to override variable checkpointing, saving each component.""" +_wrapt_type = type(wrapt.ObjectProxy) +_variable_type = type(resource_variable_ops.BaseResourceVariable) +if issubclass(_variable_type, _wrapt_type): + # Some wrapt versions do not have a meta-class, which would create an invalid + # MRO. + VariableProxyMetaClass = _variable_type +else: + class VariableProxyMetaClass(_wrapt_type, _variable_type): # pylint: disable=duplicate-bases + """A combined MetaClasses for ParallelVariable. 
- def __init__(self, parallel_device, expected_shape=None, use_resource=None, - **kwargs): - del expected_shape, use_resource - self._parallel_device = parallel_device - super(ParallelSavingMixin, self).__init__(**kwargs) + Satisfies the requirement "the metaclass of a derived class must be a + (non-strict) subclass of the metaclasses of all its bases." At the time of + writing these two MetaClasses are compatible (overriding different methods, + both relatively trivial). + """ + pass + + +class ParallelVariable( + six.with_metaclass(VariableProxyMetaClass, wrapt.ObjectProxy, + resource_variable_ops.BaseResourceVariable)): + """Overrides variable checkpointing, saving each component.""" + + def __init__(self, parallel_device, wrapped_variable): + self._self_parallel_device = parallel_device + super(ParallelVariable, self).__init__(wrapped_variable) # TODO(allenl): Consider either adding a boolean argument for # save-primary-only or looking at synchronization/aggregation properties. @@ -63,7 +83,8 @@ class ParallelSavingMixin(resource_variable_ops.BaseResourceVariable): component_saveables = {} # Create one SaveableObject per device, each one of which looks like a # regular ResourceVariable saveable. - for index, handle in enumerate(self._parallel_device.unpack(self.handle)): + for index, handle in enumerate( + self._self_parallel_device.unpack(self.handle)): if index == 0: # This is the name regular tf.Variables use to save. Using it for the # component on the first device means non-parallel tf.Variable objects @@ -80,26 +101,24 @@ class ParallelSavingMixin(resource_variable_ops.BaseResourceVariable): return component_saveables -class ParallelVariable( - ParallelSavingMixin, resource_variable_ops.ResourceVariable): - pass - - -class UninitializedParallelVariable( - ParallelSavingMixin, resource_variable_ops.UninitializedVariable): - pass - - -def _variable_creator(next_creator, parallel_device, initial_value=None, - **kwargs): - del next_creator - if initial_value is not None: +def _variable_creator(next_creator, parallel_device, **kwargs): + """Wraps intercepted variables to add parallel saving.""" + # Depending on the context (SavedModel loading, tf.function, etc.) we may get + # one of several different variable types. For variables placed on the + # parallel device we only want to affect saving and otherwise preserve + # behavior. This wrapping to override behavior is similar to tf.distribute's + # DistributedVariable, but much more limited. + variable = next_creator(**kwargs) + if variable.device == parallel_device._name: # Friend access; pylint: disable=protected-access return ParallelVariable( - parallel_device=parallel_device, initial_value=initial_value, **kwargs) + parallel_device=parallel_device, wrapped_variable=variable) else: - # SavedModel loading does not pass an initial value. - return UninitializedParallelVariable( - parallel_device=parallel_device, **kwargs) + # Variables not placed on the handler (because of a device scope) don't + # need wrapping. + # + # TODO(allenl): Device scopes should merge with parallel devices rather + # than overriding them like this. 
+ return variable @contextlib.contextmanager diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index 766adf5eecc..cf96feb7778 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -836,6 +836,7 @@ py_library( "//tensorflow/python:util", "//tensorflow/python:variable_scope", "//tensorflow/python:while_v2", # TODO(b/118513001): Imported via control_flow_ops; remove. + "//tensorflow/python/distribute/parallel_device", "//tensorflow/python/profiler:trace", "//tensorflow/python/training/tracking:base", ], diff --git a/tensorflow/python/eager/def_function.py b/tensorflow/python/eager/def_function.py index 08908f71cec..2e667884751 100644 --- a/tensorflow/python/eager/def_function.py +++ b/tensorflow/python/eager/def_function.py @@ -27,9 +27,11 @@ import six from google.protobuf import text_format as _text_format from google.protobuf.message import DecodeError from tensorflow.core.framework import attr_value_pb2 +from tensorflow.python.distribute.parallel_device import parallel_device from tensorflow.python.eager import context from tensorflow.python.eager import function as function_lib from tensorflow.python.eager import lift_to_graph +from tensorflow.python.framework import errors from tensorflow.python.framework import func_graph as func_graph_module from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -430,6 +432,45 @@ def functions_run_eagerly(): return RUN_FUNCTIONS_EAGERLY +def _evaluate_var_is_initialized(variables): + """Compute booleans indicating whether each variable is initialized.""" + with ops.init_scope(): + var_is_initialized = [] + for v in variables: + var_is_initialized.append( + resource_variable_ops.var_is_initialized_op(v.handle)) + try: + # Stack all the var_is_initialized values into one tensor and interpret + # the numpy value. This will reduce the number of RPCs between client and + # worker in the remote case. + return array_ops.stack(var_is_initialized).numpy() + except errors.UnimplementedError: + # Some devices do not support implicit copy-off to host. Fall back to + # variable-by-variable processing. + for index, v in enumerate(variables): + try: + numpy_value = var_is_initialized[index].numpy() + except errors.UnimplementedError: + # This is a variable on a parallel device; we'll extract its value on + # each replica and assert that they're identical. + components = parallel_device.unpack(var_is_initialized[index]) + with ops.device(None): + components = array_ops.stack(components) + all_initialized = math_ops.reduce_all(components).numpy() + any_initialized = math_ops.reduce_any(components).numpy() + if all_initialized != any_initialized: + raise NotImplementedError( + ("Some but not all components of a parallel variable {} were " + "initialized between their creation in a tf.function and " + "the function's trace having completed. This is not yet " + "supported; consider initializing either all or none of the " + "components, or moving initialization out of the function." + ).format(repr(v))) + numpy_value = all_initialized + var_is_initialized[index] = numpy_value + return var_is_initialized + + class FunctionDeleter(object): __slots__ = ["func_graph"] @@ -1024,21 +1065,15 @@ class Function(object): if not initializers: return + var_is_initialized = _evaluate_var_is_initialized( + [v for v, _ in initializers]) + # Note: using defun here avoids an infinite recursion. # Most of the code in this function runs eagerly with init_scope, where # autograph is not necessary. 
@function_lib.defun(autograph=False) def initialize_variables(): op_map = object_identity.ObjectIdentityDictionary() - # Stack all the var_is_initialized values into one tensor and interpret - # the numpy value. This will reduce the number of RPCs between client and - # worker in the remote case. - with ops.init_scope(): - var_is_initialized = [] - for v, _ in initializers: - var_is_initialized.append( - resource_variable_ops.var_is_initialized_op(v.handle)) - var_is_initialized = array_ops.stack(var_is_initialized).numpy() inits = [] for (v, init), is_initialized in zip(initializers, var_is_initialized):
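
The relocated initialization check (_evaluate_var_is_initialized) is what allows variables to be created conditionally inside a tf.function while a ParallelDevice is active: for parallel variables, the per-component status booleans are unpacked and required to agree instead of being copied off the device implicitly. A condensed sketch of the new test_variable_created_in_function, which exercises this path (the `device` construction is assumed, as in _VirtualDeviceTestCase or the earlier sketch):

    from tensorflow.python.eager import def_function
    from tensorflow.python.framework import constant_op
    from tensorflow.python.module import module
    from tensorflow.python.ops import array_ops
    from tensorflow.python.ops import variables

    class M(module.Module):

      def __init__(self):
        self.v = None

      @def_function.function(autograph=False)
      def __call__(self, x):
        if self.v is None:
          # Created on the first trace only; def_function lifts the initializer
          # and, with this change, checks initialization per component.
          self.v = variables.Variable(constant_op.constant(2.))
        return x * self.v

    with device:  # a ParallelDevice
      m = M()
      outputs = device.unpack(m(array_ops.ones([])))
    # Each component computes 1. * 2., so `outputs` is close to [2.0, 2.0].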