Parallel device: fix variable initialization in tf.function
Switches ParallelDevice variables to be compatible with the tf.function variable creator scope, and adds a special case to handle conditional initialization of parallel variables. Adds TPU tests for the parallel device since that's a major constraint on the implementation (no uninitialized input to tf.cond). Rolling forward with some branching logic for Windows (may not be Windows-specific, but whatever combination of packages we test with there). PiperOrigin-RevId: 334170699 Change-Id: I541655bd8a116d013a5a3f62b645aa7242411a40
This commit is contained in:
parent
e96a7098f1
commit
d44cb28478
@ -1,3 +1,5 @@
|
||||
load("//tensorflow/core/platform/default:distribute.bzl", "distribute_py_test")
|
||||
|
||||
package(
|
||||
default_visibility = ["//tensorflow:internal"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
@ -17,6 +19,7 @@ py_library(
|
||||
":saving",
|
||||
"//tensorflow/python:_pywrap_parallel_device",
|
||||
"//tensorflow/python/distribute:device_util",
|
||||
"//tensorflow/python/tpu:tpu_ops",
|
||||
],
|
||||
)
|
||||
|
||||
@ -27,15 +30,13 @@ py_library(
|
||||
deps = ["//tensorflow/python:framework_ops"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
distribute_py_test(
|
||||
name = "parallel_device_test",
|
||||
srcs = ["parallel_device_test.py"],
|
||||
python_version = "PY3",
|
||||
tags = [
|
||||
# Dependencies aren't otherwise included in the pip package yet.
|
||||
"no_pip",
|
||||
# MRO broken; needs investigation
|
||||
"no_windows",
|
||||
],
|
||||
deps = [
|
||||
":parallel_device",
|
||||
|
@ -19,6 +19,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import threading
|
||||
import weakref
|
||||
|
||||
from tensorflow.python import _pywrap_parallel_device
|
||||
from tensorflow.python.distribute import device_util
|
||||
@ -32,6 +33,16 @@ from tensorflow.python.tpu.ops import tpu_ops
|
||||
_next_device_number = 0
|
||||
_next_device_number_lock = threading.Lock()
|
||||
|
||||
_all_parallel_devices = weakref.WeakValueDictionary()
|
||||
|
||||
|
||||
def unpack(tensor):
  """Splits a parallel-device tensor into its per-component tensors.

  The owning device is looked up in the module-level `_all_parallel_devices`
  registry under the key `tensor.device`, and the work is delegated to that
  device's own `unpack` method.

  Args:
    tensor: An object with a `device` attribute naming a registered parallel
      device.

  Returns:
    The result of calling `unpack(tensor)` on the owning parallel device.

  Raises:
    ValueError: If nothing is registered under `tensor.device`.
  """
  owner = _all_parallel_devices.get(tensor.device)
  if owner is not None:
    return owner.unpack(tensor)
  raise ValueError("{} is not a parallel device".format(tensor.device))
|
||||
|
||||
|
||||
# TODO(allenl): Expand this docstring once things like getting components on and
|
||||
# off the device are stable.
|
||||
@ -67,6 +78,7 @@ class ParallelDevice(object):
|
||||
self._device_ids = None
|
||||
self._device_scope = None
|
||||
self._saving_scope = None
|
||||
_all_parallel_devices[self._name] = self
|
||||
|
||||
def pack(self, tensors):
|
||||
"""Create a tensor on the parallel device from a sequence of tensors.
|
||||
|
@ -93,19 +93,30 @@ class _VirtualDeviceTestCase(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(_VirtualDeviceTestCase, self).setUp()
|
||||
cpus = context.context().list_physical_devices("CPU")
|
||||
# Set 4 virtual CPUs
|
||||
context.context().set_logical_device_configuration(cpus[0], [
|
||||
context.LogicalDeviceConfiguration(),
|
||||
context.LogicalDeviceConfiguration(),
|
||||
context.LogicalDeviceConfiguration(),
|
||||
context.LogicalDeviceConfiguration()
|
||||
])
|
||||
ctx = context.context()
|
||||
if ctx.list_physical_devices("TPU"):
|
||||
self.device_type = "TPU"
|
||||
elif ctx.list_physical_devices("GPU"):
|
||||
self.device_type = "GPU"
|
||||
gpus = ctx.list_physical_devices(self.device_type)
|
||||
ctx.set_logical_device_configuration(gpus[0], [
|
||||
context.LogicalDeviceConfiguration(memory_limit=100),
|
||||
context.LogicalDeviceConfiguration(memory_limit=100),
|
||||
])
|
||||
else:
|
||||
self.device_type = "CPU"
|
||||
cpus = ctx.list_physical_devices("CPU")
|
||||
ctx.set_logical_device_configuration(cpus[0], [
|
||||
context.LogicalDeviceConfiguration(),
|
||||
context.LogicalDeviceConfiguration(),
|
||||
])
|
||||
|
||||
self.device = parallel_device.ParallelDevice(
|
||||
components=["/job:localhost/device:CPU:0", "CPU:1"])
|
||||
self.assertIn("CPU:0", self.device.components[0])
|
||||
self.assertIn("CPU:1", self.device.components[1])
|
||||
self.device = parallel_device.ParallelDevice(components=[
|
||||
"/job:localhost/device:{}:0".format(self.device_type),
|
||||
self.device_type + ":1"
|
||||
])
|
||||
self.assertIn(self.device_type + ":0", self.device.components[0])
|
||||
self.assertIn(self.device_type + ":1", self.device.components[1])
|
||||
|
||||
|
||||
class ParallelDeviceTests(_VirtualDeviceTestCase):
|
||||
@ -124,10 +135,14 @@ class ParallelDeviceTests(_VirtualDeviceTestCase):
|
||||
def test_device_id(self):
|
||||
device_ids = self.device.unpack(self.device.device_ids)
|
||||
self.assertAllClose([0, 1], device_ids)
|
||||
self.assertIn(self.device.components[0], device_ids[0].backing_device)
|
||||
self.assertIn(self.device.components[1], device_ids[1].backing_device)
|
||||
# TODO(allenl): Should device IDs be int64 so they can be placed on GPUs?
|
||||
# Currently backing_device is CPU.
|
||||
self.assertIn(self.device.components[0], device_ids[0].device)
|
||||
self.assertIn(self.device.components[1], device_ids[1].device)
|
||||
|
||||
def test_collective_reduce(self):
|
||||
if self.device_type == "TPU":
|
||||
self.skipTest("ParallelDevice collectives on TPUs need work")
|
||||
with self.device:
|
||||
x = self.device.pack(
|
||||
[constant_op.constant(-1.5),
|
||||
@ -139,6 +154,8 @@ class ParallelDeviceTests(_VirtualDeviceTestCase):
|
||||
self.assertIn(self.device.components[1], outputs[1].backing_device)
|
||||
|
||||
def test_collective_reduce_async_scope(self):
|
||||
if self.device_type == "TPU":
|
||||
self.skipTest("ParallelDevice collectives on TPUs need work")
|
||||
# Note that ops on the parallel device currently don't execute
|
||||
# asynchronously. The test is just that we don't get deadlocks.
|
||||
with context.async_scope(), self.device:
|
||||
@ -152,6 +169,8 @@ class ParallelDeviceTests(_VirtualDeviceTestCase):
|
||||
self.assertIn(self.device.components[1], outputs[1].backing_device)
|
||||
|
||||
def test_collective_reduce_async_context(self):
|
||||
if self.device_type == "TPU":
|
||||
self.skipTest("ParallelDevice collectives on TPUs need work")
|
||||
previous = config.get_synchronous_execution()
|
||||
try:
|
||||
context._reset_context()
|
||||
@ -173,6 +192,8 @@ class ParallelDeviceTests(_VirtualDeviceTestCase):
|
||||
config.set_synchronous_execution(previous)
|
||||
|
||||
def test_collective_in_function(self):
|
||||
if self.device_type == "TPU":
|
||||
self.skipTest("ParallelDevice collectives on TPUs need work")
|
||||
c = constant_op.constant([2])
|
||||
|
||||
@def_function.function
|
||||
@ -313,6 +334,33 @@ class ParallelDeviceTests(_VirtualDeviceTestCase):
|
||||
return y, tape.gradient(y, x)
|
||||
self._assert_close_to_non_parallel(_test_fn)
|
||||
|
||||
def test_variable_created_in_function(self):
# Exercises variable creation from inside a `def_function.function` while the
# parallel device scope (`self.device`) is active, then checks the value
# produced on every component.
# NOTE(review): this rendering has lost indentation, so the exact extent of the
# `ops.init_scope()` block below cannot be read off the text -- confirm against
# the real file before relying on which variables are created inside it.
|
||||
|
||||
class M(module.Module):
|
||||
|
||||
def __init__(self):
# All four variables start out unset; they are created on the first call.
|
||||
self.v = None
|
||||
self.w = None
|
||||
self.x = None
|
||||
self.z = None
|
||||
|
||||
@def_function.function(autograph=False)
|
||||
def __call__(self, x):
|
||||
# `self.v is None` acts as the "variables not yet created" guard.
if self.v is None:
|
||||
with ops.init_scope():
|
||||
initial_value = constant_op.constant(2.)
|
||||
# `z` and `x` are both initialized from the same `initial_value` tensor.
self.z = variables.Variable(initial_value)
|
||||
self.x = variables.Variable(initial_value)
|
||||
# `w` uses a callable initializer rather than a tensor.
self.w = variables.Variable(lambda: constant_op.constant(2.))
|
||||
# `v` uses a constant built directly in the constructor call.
self.v = variables.Variable(constant_op.constant(2.))
|
||||
# With every variable equal to 2. the product is x * 2 * 2 * 2 * 2.
return x * self.v * self.w * self.x * self.z
|
||||
|
||||
with self.device:
|
||||
m = M()
|
||||
# Input is a scalar one, so the expected result is 16.
packed_outputs = m(array_ops.ones([]))
|
||||
outputs = self.device.unpack(packed_outputs)
|
||||
# One expected value per component of the two-component parallel device.
self.assertAllClose([16., 16.], outputs)
|
||||
|
||||
|
||||
class LayerTests(_VirtualDeviceTestCase):
|
||||
|
||||
@ -340,6 +388,8 @@ class LayerTests(_VirtualDeviceTestCase):
|
||||
self.assertIn(self.device.components[1], outputs[1].backing_device)
|
||||
|
||||
def test_layer_sync_training(self):
|
||||
if self.device_type == "TPU":
|
||||
self.skipTest("ParallelDevice collectives on TPUs need work")
|
||||
with self.device:
|
||||
layer = _Dense(5)
|
||||
|
||||
@ -389,6 +439,8 @@ class LayerTests(_VirtualDeviceTestCase):
|
||||
self.assertIn(self.device.components[1], final_kernels[1].backing_device)
|
||||
|
||||
def test_training_loop(self):
|
||||
if self.device_type == "TPU":
|
||||
self.skipTest("ParallelDevice collectives on TPUs need work")
|
||||
for _ in range(5):
|
||||
layer = _Dense(5)
|
||||
checkpoint = tracking.Checkpoint(layer=layer)
|
||||
|
@ -20,6 +20,8 @@ from __future__ import print_function
|
||||
|
||||
import contextlib
|
||||
import functools
|
||||
import six
|
||||
import wrapt
|
||||
|
||||
from tensorflow.python.ops import gen_resource_variable_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
@ -47,14 +49,32 @@ class _ParallelComponentSaveable(saveable_object.SaveableObject):
|
||||
resource=self._handle, value=restored_tensor)
|
||||
|
||||
|
||||
class ParallelSavingMixin(resource_variable_ops.BaseResourceVariable):
|
||||
"""Mixin to override variable checkpointing, saving each component."""
|
||||
_wrapt_type = type(wrapt.ObjectProxy)
|
||||
_variable_type = type(resource_variable_ops.BaseResourceVariable)
|
||||
if issubclass(_variable_type, _wrapt_type):
|
||||
# Some wrapt versions do not have a meta-class, which would create an invalid
|
||||
# MRO.
|
||||
VariableProxyMetaClass = _variable_type
|
||||
else:
|
||||
class VariableProxyMetaClass(_wrapt_type, _variable_type): # pylint: disable=duplicate-bases
|
||||
"""A combined metaclass for ParallelVariable.
|
||||
|
||||
def __init__(self, parallel_device, expected_shape=None, use_resource=None,
|
||||
**kwargs):
|
||||
del expected_shape, use_resource
|
||||
self._parallel_device = parallel_device
|
||||
super(ParallelSavingMixin, self).__init__(**kwargs)
|
||||
Satisfies the requirement "the metaclass of a derived class must be a
|
||||
(non-strict) subclass of the metaclasses of all its bases." At the time of
|
||||
writing these two MetaClasses are compatible (overriding different methods,
|
||||
both relatively trivial).
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class ParallelVariable(
|
||||
six.with_metaclass(VariableProxyMetaClass, wrapt.ObjectProxy,
|
||||
resource_variable_ops.BaseResourceVariable)):
|
||||
"""Overrides variable checkpointing, saving each component."""
|
||||
|
||||
def __init__(self, parallel_device, wrapped_variable):
|
||||
self._self_parallel_device = parallel_device
|
||||
super(ParallelVariable, self).__init__(wrapped_variable)
|
||||
|
||||
# TODO(allenl): Consider either adding a boolean argument for
|
||||
# save-primary-only or looking at synchronization/aggregation properties.
|
||||
@ -63,7 +83,8 @@ class ParallelSavingMixin(resource_variable_ops.BaseResourceVariable):
|
||||
component_saveables = {}
|
||||
# Create one SaveableObject per device, each one of which looks like a
|
||||
# regular ResourceVariable saveable.
|
||||
for index, handle in enumerate(self._parallel_device.unpack(self.handle)):
|
||||
for index, handle in enumerate(
|
||||
self._self_parallel_device.unpack(self.handle)):
|
||||
if index == 0:
|
||||
# This is the name regular tf.Variables use to save. Using it for the
|
||||
# component on the first device means non-parallel tf.Variable objects
|
||||
@ -80,26 +101,24 @@ class ParallelSavingMixin(resource_variable_ops.BaseResourceVariable):
|
||||
return component_saveables
|
||||
|
||||
|
||||
class ParallelVariable(
|
||||
ParallelSavingMixin, resource_variable_ops.ResourceVariable):
|
||||
pass
|
||||
|
||||
|
||||
class UninitializedParallelVariable(
|
||||
ParallelSavingMixin, resource_variable_ops.UninitializedVariable):
|
||||
pass
|
||||
|
||||
|
||||
def _variable_creator(next_creator, parallel_device, initial_value=None,
|
||||
**kwargs):
|
||||
del next_creator
|
||||
if initial_value is not None:
|
||||
def _variable_creator(next_creator, parallel_device, **kwargs):
|
||||
"""Wraps intercepted variables to add parallel saving."""
|
||||
# Depending on the context (SavedModel loading, tf.function, etc.) we may get
|
||||
# one of several different variable types. For variables placed on the
|
||||
# parallel device we only want to affect saving and otherwise preserve
|
||||
# behavior. This wrapping to override behavior is similar to tf.distribute's
|
||||
# DistributedVariable, but much more limited.
|
||||
variable = next_creator(**kwargs)
|
||||
if variable.device == parallel_device._name: # Friend access; pylint: disable=protected-access
|
||||
return ParallelVariable(
|
||||
parallel_device=parallel_device, initial_value=initial_value, **kwargs)
|
||||
parallel_device=parallel_device, wrapped_variable=variable)
|
||||
else:
|
||||
# SavedModel loading does not pass an initial value.
|
||||
return UninitializedParallelVariable(
|
||||
parallel_device=parallel_device, **kwargs)
|
||||
# Variables not placed on the handler (because of a device scope) don't
|
||||
# need wrapping.
|
||||
#
|
||||
# TODO(allenl): Device scopes should merge with parallel devices rather
|
||||
# than overriding them like this.
|
||||
return variable
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
|
@ -836,6 +836,7 @@ py_library(
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python:variable_scope",
|
||||
"//tensorflow/python:while_v2", # TODO(b/118513001): Imported via control_flow_ops; remove.
|
||||
"//tensorflow/python/distribute/parallel_device",
|
||||
"//tensorflow/python/profiler:trace",
|
||||
"//tensorflow/python/training/tracking:base",
|
||||
],
|
||||
|
@ -27,9 +27,11 @@ import six
|
||||
from google.protobuf import text_format as _text_format
|
||||
from google.protobuf.message import DecodeError
|
||||
from tensorflow.core.framework import attr_value_pb2
|
||||
from tensorflow.python.distribute.parallel_device import parallel_device
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import function as function_lib
|
||||
from tensorflow.python.eager import lift_to_graph
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import func_graph as func_graph_module
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -430,6 +432,45 @@ def functions_run_eagerly():
|
||||
return RUN_FUNCTIONS_EAGERLY
|
||||
|
||||
|
||||
def _parallel_var_is_initialized(is_initialized, variable):
  """Resolves the initialization flag of a variable on a parallel device.

  Args:
    is_initialized: The `var_is_initialized_op` result for `variable`, which
      could not be converted with `.numpy()` directly.
    variable: The variable being inspected; used only in the error message.

  Returns:
    The value shared by all components (the `reduce_all` result).

  Raises:
    NotImplementedError: If some components are initialized and others are not.
  """
  replicas = parallel_device.unpack(is_initialized)
  # Leave any ambient device scope so the reductions can be read back with
  # `.numpy()`.
  with ops.device(None):
    stacked = array_ops.stack(replicas)
    every_component = math_ops.reduce_all(stacked).numpy()
    some_component = math_ops.reduce_any(stacked).numpy()
  if every_component != some_component:
    raise NotImplementedError(
        ("Some but not all components of a parallel variable {} were "
         "initialized between their creation in a tf.function and "
         "the function's trace having completed. This is not yet "
         "supported; consider initializing either all or none of the "
         "components, or moving initialization out of the function."
        ).format(repr(variable)))
  return every_component


def _evaluate_var_is_initialized(variables):
  """Computes, for each variable, whether it has been initialized.

  Args:
    variables: A sequence of objects exposing a resource `handle`.

  Returns:
    On the fast path, the `.numpy()` value of all flags stacked into one
    tensor. On the fallback path, a list with one entry per variable.

  Raises:
    NotImplementedError: If a parallel variable is only partially initialized.
  """
  with ops.init_scope():
    flags = [
        resource_variable_ops.var_is_initialized_op(v.handle)
        for v in variables
    ]
    try:
      # One stacked tensor means one readback, which keeps the number of RPCs
      # between client and worker low in the remote case.
      return array_ops.stack(flags).numpy()
    except errors.UnimplementedError:
      # Some devices do not support implicit copy-off to host, so resolve the
      # flags one variable at a time instead.
      for position, v in enumerate(variables):
        try:
          resolved = flags[position].numpy()
        except errors.UnimplementedError:
          # The variable lives on a parallel device; its components must agree.
          resolved = _parallel_var_is_initialized(flags[position], v)
        flags[position] = resolved
    return flags
|
||||
|
||||
|
||||
class FunctionDeleter(object):
|
||||
|
||||
__slots__ = ["func_graph"]
|
||||
@ -1024,21 +1065,15 @@ class Function(object):
|
||||
if not initializers:
|
||||
return
|
||||
|
||||
var_is_initialized = _evaluate_var_is_initialized(
|
||||
[v for v, _ in initializers])
|
||||
|
||||
# Note: using defun here avoids an infinite recursion.
|
||||
# Most of the code in this function runs eagerly with init_scope, where
|
||||
# autograph is not necessary.
|
||||
@function_lib.defun(autograph=False)
|
||||
def initialize_variables():
|
||||
op_map = object_identity.ObjectIdentityDictionary()
|
||||
# Stack all the var_is_initialized values into one tensor and interpret
|
||||
# the numpy value. This will reduce the number of RPCs between client and
|
||||
# worker in the remote case.
|
||||
with ops.init_scope():
|
||||
var_is_initialized = []
|
||||
for v, _ in initializers:
|
||||
var_is_initialized.append(
|
||||
resource_variable_ops.var_is_initialized_op(v.handle))
|
||||
var_is_initialized = array_ops.stack(var_is_initialized).numpy()
|
||||
|
||||
inits = []
|
||||
for (v, init), is_initialized in zip(initializers, var_is_initialized):
|
||||
|
Loading…
Reference in New Issue
Block a user