Parallel device: fix variable initialization in tf.function

Switches ParallelDevice variables to be compatible with the tf.function variable creator scope, and adds a special case to handle conditional initialization of parallel variables.

Adds TPU tests for the parallel device since that's a major constraint on the implementation (no uninitialized input to tf.cond).

Rolling forward with some branching logic for Windows (the issue may not be Windows-specific, but it shows up with whatever combination of packages we test with there).

PiperOrigin-RevId: 334170699
Change-Id: I541655bd8a116d013a5a3f62b645aa7242411a40
Allen Lavoie 2020-09-28 09:11:49 -07:00 committed by TensorFlower Gardener
parent e96a7098f1
commit d44cb28478
6 changed files with 172 additions and 52 deletions
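
As an illustration of the pattern this change enables (not part of the diff), here is a minimal sketch modeled on the new test_variable_created_in_function test further down. It assumes the internal tensorflow.python.distribute.parallel_device package and a two-virtual-CPU setup mirroring the test fixture; the names M, device, and packed exist only for this example.

# Minimal sketch (not part of the diff): conditional variable creation on a
# parallel device inside tf.function, mirroring the new test below.
import tensorflow as tf
from tensorflow.python.distribute.parallel_device import parallel_device  # internal API

# Two virtual CPUs so the parallel device has two components, as in the tests.
cpus = tf.config.list_physical_devices("CPU")
tf.config.set_logical_device_configuration(
    cpus[0],
    [tf.config.LogicalDeviceConfiguration(),
     tf.config.LogicalDeviceConfiguration()])

device = parallel_device.ParallelDevice(
    components=["/job:localhost/device:CPU:0", "CPU:1"])


class M(tf.Module):

  def __init__(self):
    super().__init__()
    self.v = None

  @tf.function(autograph=False)
  def __call__(self, x):
    if self.v is None:  # Runs only on the first trace.
      self.v = tf.Variable(tf.constant(2.))
    return x * self.v


with device:
  m = M()
  packed = m(tf.ones([]))

print(device.unpack(packed))  # One component tensor (value 2.0) per device.

Unpacking the result yields one component per underlying device; the point of the change is that creating the variable on the first trace under the parallel device scope now plays well with tf.function's variable creator scope.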


@@ -1,3 +1,5 @@
load("//tensorflow/core/platform/default:distribute.bzl", "distribute_py_test")
package(
default_visibility = ["//tensorflow:internal"],
licenses = ["notice"], # Apache 2.0
@@ -17,6 +19,7 @@ py_library(
":saving",
"//tensorflow/python:_pywrap_parallel_device",
"//tensorflow/python/distribute:device_util",
"//tensorflow/python/tpu:tpu_ops",
],
)
@@ -27,15 +30,13 @@ py_library(
deps = ["//tensorflow/python:framework_ops"],
)
py_test(
distribute_py_test(
name = "parallel_device_test",
srcs = ["parallel_device_test.py"],
python_version = "PY3",
tags = [
# Dependencies aren't otherwise included in the pip package yet.
"no_pip",
# MRO broken; needs investigation
"no_windows",
],
deps = [
":parallel_device",


@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import threading
import weakref
from tensorflow.python import _pywrap_parallel_device
from tensorflow.python.distribute import device_util
@@ -32,6 +33,16 @@ from tensorflow.python.tpu.ops import tpu_ops
_next_device_number = 0
_next_device_number_lock = threading.Lock()
_all_parallel_devices = weakref.WeakValueDictionary()
def unpack(tensor):
"""Finds `tensor`'s parallel device and unpacks its components."""
parallel_device = _all_parallel_devices.get(tensor.device, None)
if parallel_device is None:
raise ValueError("{} is not a parallel device".format(tensor.device))
return parallel_device.unpack(tensor)
# TODO(allenl): Expand this docstring once things like getting components on and
# off the device are stable.
@@ -67,6 +78,7 @@ class ParallelDevice(object):
self._device_ids = None
self._device_scope = None
self._saving_scope = None
_all_parallel_devices[self._name] = self
def pack(self, tensors):
"""Create a tensor on the parallel device from a sequence of tensors.


@@ -93,19 +93,30 @@ class _VirtualDeviceTestCase(test.TestCase):
def setUp(self):
super(_VirtualDeviceTestCase, self).setUp()
cpus = context.context().list_physical_devices("CPU")
# Set 4 virtual CPUs
context.context().set_logical_device_configuration(cpus[0], [
context.LogicalDeviceConfiguration(),
context.LogicalDeviceConfiguration(),
context.LogicalDeviceConfiguration(),
context.LogicalDeviceConfiguration()
])
ctx = context.context()
if ctx.list_physical_devices("TPU"):
self.device_type = "TPU"
elif ctx.list_physical_devices("GPU"):
self.device_type = "GPU"
gpus = ctx.list_physical_devices(self.device_type)
ctx.set_logical_device_configuration(gpus[0], [
context.LogicalDeviceConfiguration(memory_limit=100),
context.LogicalDeviceConfiguration(memory_limit=100),
])
else:
self.device_type = "CPU"
cpus = ctx.list_physical_devices("CPU")
ctx.set_logical_device_configuration(cpus[0], [
context.LogicalDeviceConfiguration(),
context.LogicalDeviceConfiguration(),
])
self.device = parallel_device.ParallelDevice(
components=["/job:localhost/device:CPU:0", "CPU:1"])
self.assertIn("CPU:0", self.device.components[0])
self.assertIn("CPU:1", self.device.components[1])
self.device = parallel_device.ParallelDevice(components=[
"/job:localhost/device:{}:0".format(self.device_type),
self.device_type + ":1"
])
self.assertIn(self.device_type + ":0", self.device.components[0])
self.assertIn(self.device_type + ":1", self.device.components[1])
class ParallelDeviceTests(_VirtualDeviceTestCase):
@@ -124,10 +135,14 @@ class ParallelDeviceTests(_VirtualDeviceTestCase):
def test_device_id(self):
device_ids = self.device.unpack(self.device.device_ids)
self.assertAllClose([0, 1], device_ids)
self.assertIn(self.device.components[0], device_ids[0].backing_device)
self.assertIn(self.device.components[1], device_ids[1].backing_device)
# TODO(allenl): Should device IDs be int64 so they can be placed on GPUs?
# Currently backing_device is CPU.
self.assertIn(self.device.components[0], device_ids[0].device)
self.assertIn(self.device.components[1], device_ids[1].device)
def test_collective_reduce(self):
if self.device_type == "TPU":
self.skipTest("ParallelDevice collectives on TPUs need work")
with self.device:
x = self.device.pack(
[constant_op.constant(-1.5),
@@ -139,6 +154,8 @@ class ParallelDeviceTests(_VirtualDeviceTestCase):
self.assertIn(self.device.components[1], outputs[1].backing_device)
def test_collective_reduce_async_scope(self):
if self.device_type == "TPU":
self.skipTest("ParallelDevice collectives on TPUs need work")
# Note that ops on the parallel device currently don't execute
# asynchronously. The test is just that we don't get deadlocks.
with context.async_scope(), self.device:
@@ -152,6 +169,8 @@ class ParallelDeviceTests(_VirtualDeviceTestCase):
self.assertIn(self.device.components[1], outputs[1].backing_device)
def test_collective_reduce_async_context(self):
if self.device_type == "TPU":
self.skipTest("ParallelDevice collectives on TPUs need work")
previous = config.get_synchronous_execution()
try:
context._reset_context()
@@ -173,6 +192,8 @@ class ParallelDeviceTests(_VirtualDeviceTestCase):
config.set_synchronous_execution(previous)
def test_collective_in_function(self):
if self.device_type == "TPU":
self.skipTest("ParallelDevice collectives on TPUs need work")
c = constant_op.constant([2])
@def_function.function
@@ -313,6 +334,33 @@ class ParallelDeviceTests(_VirtualDeviceTestCase):
return y, tape.gradient(y, x)
self._assert_close_to_non_parallel(_test_fn)
def test_variable_created_in_function(self):
class M(module.Module):
def __init__(self):
self.v = None
self.w = None
self.x = None
self.z = None
@def_function.function(autograph=False)
def __call__(self, x):
if self.v is None:
with ops.init_scope():
initial_value = constant_op.constant(2.)
self.z = variables.Variable(initial_value)
self.x = variables.Variable(initial_value)
self.w = variables.Variable(lambda: constant_op.constant(2.))
self.v = variables.Variable(constant_op.constant(2.))
return x * self.v * self.w * self.x * self.z
with self.device:
m = M()
packed_outputs = m(array_ops.ones([]))
outputs = self.device.unpack(packed_outputs)
self.assertAllClose([16., 16.], outputs)
class LayerTests(_VirtualDeviceTestCase):
@@ -340,6 +388,8 @@ class LayerTests(_VirtualDeviceTestCase):
self.assertIn(self.device.components[1], outputs[1].backing_device)
def test_layer_sync_training(self):
if self.device_type == "TPU":
self.skipTest("ParallelDevice collectives on TPUs need work")
with self.device:
layer = _Dense(5)
@@ -389,6 +439,8 @@ class LayerTests(_VirtualDeviceTestCase):
self.assertIn(self.device.components[1], final_kernels[1].backing_device)
def test_training_loop(self):
if self.device_type == "TPU":
self.skipTest("ParallelDevice collectives on TPUs need work")
for _ in range(5):
layer = _Dense(5)
checkpoint = tracking.Checkpoint(layer=layer)


@@ -20,6 +20,8 @@ from __future__ import print_function
import contextlib
import functools
import six
import wrapt
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import resource_variable_ops
@@ -47,14 +49,32 @@ class _ParallelComponentSaveable(saveable_object.SaveableObject):
resource=self._handle, value=restored_tensor)
class ParallelSavingMixin(resource_variable_ops.BaseResourceVariable):
"""Mixin to to override variable checkpointing, saving each component."""
_wrapt_type = type(wrapt.ObjectProxy)
_variable_type = type(resource_variable_ops.BaseResourceVariable)
if issubclass(_variable_type, _wrapt_type):
# Some wrapt versions do not have a meta-class, which would create an invalid
# MRO.
VariableProxyMetaClass = _variable_type
else:
class VariableProxyMetaClass(_wrapt_type, _variable_type): # pylint: disable=duplicate-bases
"""A combined MetaClasses for ParallelVariable.
def __init__(self, parallel_device, expected_shape=None, use_resource=None,
**kwargs):
del expected_shape, use_resource
self._parallel_device = parallel_device
super(ParallelSavingMixin, self).__init__(**kwargs)
Satisfies the requirement "the metaclass of a derived class must be a
(non-strict) subclass of the metaclasses of all its bases." At the time of
writing these two MetaClasses are compatible (overriding different methods,
both relatively trivial).
"""
pass
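# --- Illustrative aside, not part of the patch ---
# Why a combined metaclass is needed at all: Python requires that the
# metaclass of a derived class be a (non-strict) subclass of the metaclasses
# of all its bases. A hypothetical minimal reproduction (MetaA, MetaB, and
# MetaAB are made-up names for this sketch):
#
#   class MetaA(type): pass
#   class MetaB(type): pass
#   class A(metaclass=MetaA): pass
#   class B(metaclass=MetaB): pass
#
#   class C(A, B): pass                      # TypeError: metaclass conflict
#
#   class MetaAB(MetaA, MetaB): pass
#   class C(A, B, metaclass=MetaAB): pass    # OK
#
# VariableProxyMetaClass plays the role of MetaAB for wrapt.ObjectProxy and
# BaseResourceVariable; the issubclass check above reuses the variable
# metaclass directly when it already subclasses wrapt's (e.g. when wrapt
# ships without a custom metaclass), avoiding an invalid MRO.
# --- End aside ---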
class ParallelVariable(
six.with_metaclass(VariableProxyMetaClass, wrapt.ObjectProxy,
resource_variable_ops.BaseResourceVariable)):
"""Overrides variable checkpointing, saving each component."""
def __init__(self, parallel_device, wrapped_variable):
self._self_parallel_device = parallel_device
super(ParallelVariable, self).__init__(wrapped_variable)
# TODO(allenl): Consider either adding a boolean argument for
# save-primary-only or looking at synchronization/aggregation properties.
@@ -63,7 +83,8 @@ class ParallelSavingMixin(resource_variable_ops.BaseResourceVariable):
component_saveables = {}
# Create one SaveableObject per device, each one of which looks like a
# regular ResourceVariable saveable.
for index, handle in enumerate(self._parallel_device.unpack(self.handle)):
for index, handle in enumerate(
self._self_parallel_device.unpack(self.handle)):
if index == 0:
# This is the name regular tf.Variables use to save. Using it for the
# component on the first device means non-parallel tf.Variable objects
@@ -80,26 +101,24 @@ class ParallelSavingMixin(resource_variable_ops.BaseResourceVariable):
return component_saveables
class ParallelVariable(
ParallelSavingMixin, resource_variable_ops.ResourceVariable):
pass
class UninitializedParallelVariable(
ParallelSavingMixin, resource_variable_ops.UninitializedVariable):
pass
def _variable_creator(next_creator, parallel_device, initial_value=None,
**kwargs):
del next_creator
if initial_value is not None:
def _variable_creator(next_creator, parallel_device, **kwargs):
"""Wraps intercepted variables to add parallel saving."""
# Depending on the context (SavedModel loading, tf.function, etc.) we may get
# one of several different variable types. For variables placed on the
# parallel device we only want to affect saving and otherwise preserve
# behavior. This wrapping to override behavior is similar to tf.distribute's
# DistributedVariable, but much more limited.
variable = next_creator(**kwargs)
if variable.device == parallel_device._name: # Friend access; pylint: disable=protected-access
return ParallelVariable(
parallel_device=parallel_device, initial_value=initial_value, **kwargs)
parallel_device=parallel_device, wrapped_variable=variable)
else:
# SavedModel loading does not pass an initial value.
return UninitializedParallelVariable(
parallel_device=parallel_device, **kwargs)
# Variables not placed on the handler (because of a device scope) don't
# need wrapping.
#
# TODO(allenl): Device scopes should merge with parallel devices rather
# than overriding them like this.
return variable
@contextlib.contextmanager


@@ -836,6 +836,7 @@ py_library(
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"//tensorflow/python:while_v2", # TODO(b/118513001): Imported via control_flow_ops; remove.
"//tensorflow/python/distribute/parallel_device",
"//tensorflow/python/profiler:trace",
"//tensorflow/python/training/tracking:base",
],


@@ -27,9 +27,11 @@ import six
from google.protobuf import text_format as _text_format
from google.protobuf.message import DecodeError
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.distribute.parallel_device import parallel_device
from tensorflow.python.eager import context
from tensorflow.python.eager import function as function_lib
from tensorflow.python.eager import lift_to_graph
from tensorflow.python.framework import errors
from tensorflow.python.framework import func_graph as func_graph_module
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -430,6 +432,45 @@ def functions_run_eagerly():
return RUN_FUNCTIONS_EAGERLY
def _evaluate_var_is_initialized(variables):
"""Compute booleans indicating whether each variable is initialized."""
with ops.init_scope():
var_is_initialized = []
for v in variables:
var_is_initialized.append(
resource_variable_ops.var_is_initialized_op(v.handle))
try:
# Stack all the var_is_initialized values into one tensor and interpret
# the numpy value. This will reduce the number of RPCs between client and
# worker in the remote case.
return array_ops.stack(var_is_initialized).numpy()
except errors.UnimplementedError:
# Some devices do not support implicit copy-off to host. Fall back to
# variable-by-variable processing.
for index, v in enumerate(variables):
try:
numpy_value = var_is_initialized[index].numpy()
except errors.UnimplementedError:
# This is a variable on a parallel device; we'll extract its value on
# each replica and assert that they're identical.
components = parallel_device.unpack(var_is_initialized[index])
with ops.device(None):
components = array_ops.stack(components)
all_initialized = math_ops.reduce_all(components).numpy()
any_initialized = math_ops.reduce_any(components).numpy()
if all_initialized != any_initialized:
raise NotImplementedError(
("Some but not all components of a parallel variable {} were "
"initialized between their creation in a tf.function and "
"the function's trace having completed. This is not yet "
"supported; consider initializing either all or none of the "
"components, or moving initialization out of the function."
).format(repr(v)))
numpy_value = all_initialized
var_is_initialized[index] = numpy_value
return var_is_initialized
class FunctionDeleter(object):
__slots__ = ["func_graph"]
@@ -1024,21 +1065,15 @@ class Function(object):
if not initializers:
return
var_is_initialized = _evaluate_var_is_initialized(
[v for v, _ in initializers])
# Note: using defun here avoids an infinite recursion.
# Most of the code in this function runs eagerly with init_scope, where
# autograph is not necessary.
@function_lib.defun(autograph=False)
def initialize_variables():
op_map = object_identity.ObjectIdentityDictionary()
# Stack all the var_is_initialized values into one tensor and interpret
# the numpy value. This will reduce the number of RPCs between client and
# worker in the remote case.
with ops.init_scope():
var_is_initialized = []
for v, _ in initializers:
var_is_initialized.append(
resource_variable_ops.var_is_initialized_op(v.handle))
var_is_initialized = array_ops.stack(var_is_initialized).numpy()
inits = []
for (v, init), is_initialized in zip(initializers, var_is_initialized):