Switches ParallelDevice variables to be compatible with the tf.function variable creator scope, and adds a special case to handle conditional initialization of parallel variables. Adds TPU tests for the parallel device since that's a major constraint on the implementation (no uninitialized input to tf.cond). Rolling forward with some branching logic for Windows (may not be Windows-specific, but whatever combination of packages we test with there). PiperOrigin-RevId: 334170699 Change-Id: I541655bd8a116d013a5a3f62b645aa7242411a40
172 lines
6.8 KiB
Python
172 lines
6.8 KiB
Python
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Utility for eagerly executing operations in parallel on multiple devices."""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import threading
|
|
import weakref
|
|
|
|
from tensorflow.python import _pywrap_parallel_device
|
|
from tensorflow.python.distribute import device_util
|
|
from tensorflow.python.distribute.parallel_device import saving
|
|
from tensorflow.python.eager import context
|
|
from tensorflow.python.framework import constant_op
|
|
from tensorflow.python.framework import ops
|
|
from tensorflow.python.ops import array_ops
|
|
from tensorflow.python.tpu.ops import tpu_ops
|
|
|
|
_next_device_number = 0
|
|
_next_device_number_lock = threading.Lock()
|
|
|
|
_all_parallel_devices = weakref.WeakValueDictionary()
|
|
|
|
|
|
def unpack(tensor):
|
|
"""Finds `tensor`'s parallel device and unpacks its components."""
|
|
parallel_device = _all_parallel_devices.get(tensor.device, None)
|
|
if parallel_device is None:
|
|
raise ValueError("{} is not a parallel device".format(tensor.device))
|
|
return parallel_device.unpack(tensor)
|
|
|
|
|
|
# TODO(allenl): Expand this docstring once things like getting components on and
|
|
# off the device are stable.
|
|
#
|
|
# TODO(allenl): Make multi-client work; we need an offset for device IDs, and an
|
|
# indication of how many other devices there are total for collectives which
|
|
# don't have a number of participants hard-coded in their attributes.
|
|
class ParallelDevice(object):
|
|
"""A device which executes operations in parallel."""
|
|
|
|
def __init__(self, components):
|
|
"""Creates a device which executes operations in parallel on `components`.
|
|
|
|
Args:
|
|
components: A list of device names. Each operation executed on the
|
|
returned device executes on these component devices.
|
|
|
|
Returns:
|
|
A string with the name of the newly created device.
|
|
"""
|
|
global _next_device_number, _next_device_number_lock
|
|
self.components = tuple(device_util.canonicalize(d) for d in components)
|
|
ctx = context.context()
|
|
with _next_device_number_lock:
|
|
# TODO(allenl): Better names for parallel devices (right now "CUSTOM" is
|
|
# special-cased).
|
|
self._name = "{}/device:CUSTOM:{}".format(ctx.host_address_space(),
|
|
_next_device_number)
|
|
_next_device_number += 1
|
|
device, device_info = _pywrap_parallel_device.GetParallelDeviceCapsules(
|
|
self._name, self.components)
|
|
context.register_custom_device(device, self._name, device_info)
|
|
self._device_ids = None
|
|
self._device_scope = None
|
|
self._saving_scope = None
|
|
_all_parallel_devices[self._name] = self
|
|
|
|
def pack(self, tensors):
|
|
"""Create a tensor on the parallel device from a sequence of tensors.
|
|
|
|
Args:
|
|
tensors: A flat list of tensors, one per device in `self.components`.
|
|
|
|
Returns:
|
|
A single tensor placed on the ParallelDevice.
|
|
"""
|
|
self._assert_eager()
|
|
with ops.device(self._name):
|
|
return tpu_ops.tpu_replicated_input(inputs=tensors)
|
|
|
|
def unpack(self, parallel_tensor):
|
|
"""Unpack a parallel tensor into its components.
|
|
|
|
Args:
|
|
parallel_tensor: A tensor placed on the ParallelDevice.
|
|
|
|
Returns:
|
|
A flat list of tensors, one per `self.components`.
|
|
"""
|
|
self._assert_eager()
|
|
with ops.device(self._name):
|
|
return tpu_ops.tpu_replicated_output(
|
|
parallel_tensor, num_replicas=len(self.components))
|
|
|
|
@property
|
|
def device_ids(self):
|
|
"""A parallel tensor with scalar integers numbering component devices.
|
|
|
|
Each device ID is placed on its corresponding device, in the same order as
|
|
the `components` constructor argument.
|
|
|
|
Returns:
|
|
A parallel tensor containing 0 on the first device, 1 on the second, etc.
|
|
"""
|
|
if self._device_ids is None:
|
|
# device_ids may be called from inside a tf.function, in which case the
|
|
# function captures the eager tensor. We can't pack tensors in a function
|
|
# at the moment, and even if we could we don't want to hold on to a
|
|
# symbolic tensor, so we need to init_scope out of the function
|
|
# temporarily.
|
|
with ops.init_scope():
|
|
# TODO(allenl): Functions which capture eager device ID tensors won't be
|
|
# saveable in SavedModels. Ideally we'd run a DeviceID op every time
|
|
# device IDs are required, with functions using the op in their bodies
|
|
# but not hard-coding a fixed number of devices (so they can be re-used
|
|
# with a different replica count).
|
|
device_ids_list = []
|
|
for index, device in enumerate(self.components):
|
|
with ops.device(device):
|
|
# The identity op ensures each device ID tensor is placed on its
|
|
# device.
|
|
device_ids_list.append(
|
|
array_ops.identity(constant_op.constant(index)))
|
|
self._device_ids = self.pack(device_ids_list)
|
|
|
|
return self._device_ids
|
|
|
|
def _assert_eager(self):
|
|
"""Verifies that tracing is not active."""
|
|
if not context.executing_eagerly():
|
|
raise NotImplementedError(
|
|
"ParallelDevice is currently not supported inside `tf.function`. It "
|
|
"can however run calls to a `tf.function` in parallel:\n\n"
|
|
"with ParallelDevice() as p:\n f()")
|
|
|
|
def __enter__(self):
|
|
"""Runs ops in parallel, makes variables which save independent buffers."""
|
|
if (self._device_scope is not None or self._saving_scope is not None):
|
|
raise AssertionError(
|
|
"Re-entered a ParallelDevice scope without first exiting it.")
|
|
self._assert_eager()
|
|
self._device_scope = ops.device(self._name)
|
|
self._saving_scope = saving.independent_buffers(self)
|
|
self._device_scope.__enter__()
|
|
# TODO(allenl): Fixing saving in Python is a bit odd. One alternative would
|
|
# be to provide a hook for the custom device to create save specs/etc., then
|
|
# call that hook from the default variable implementation if the variable is
|
|
# on a custom device. We'll likely want similar hooks for repr() and such.
|
|
self._saving_scope.__enter__()
|
|
return self
|
|
|
|
def __exit__(self, typ, exc, tb):
|
|
self._device_scope.__exit__(typ, exc, tb)
|
|
self._saving_scope.__exit__(typ, exc, tb)
|
|
self._device_scope = None
|
|
self._saving_scope = None
|