Add Python bindings and some tests for the parallel device
Still quite experimental; no public API and lots of TODOs.

Since the parallel device uses the C API, the pybind extension needs to
re-include the parallel device sources so it can use the copy of the C API in
pywrap_tensorflow. This is pretty ugly, but I don't see a way around it until
pywrap_tensorflow relies on libtensorflow.so as the single source of the C API.

PiperOrigin-RevId: 307840967
Change-Id: Id8e7e72b14e2e5a2886c6025c7ef6f92f71a156c
tensorflow/c/eager/parallel_device/BUILD
@@ -7,10 +7,26 @@ package(
     licenses = ["notice"],  # Apache 2.0
 )
 
+# Currently pybind extension shared objects must use only C API headers since
+# the C API has static initializers duplicated in the Python bindings. So we
+# need a second rule that omits .cc files, in
+# tensorflow/python:_pywrap_parallel_device.
+filegroup(
+    name = "headers",
+    srcs = ["parallel_device.h"],
+    visibility = ["//tensorflow/python:__pkg__"],
+)
+
+filegroup(
+    name = "sources",
+    srcs = ["parallel_device.cc"],
+    visibility = ["//tensorflow/python:__pkg__"],
+)
+
 cc_library(
     name = "parallel_device",
-    srcs = ["parallel_device.cc"],
-    hdrs = ["parallel_device.h"],
+    srcs = [":sources"],
+    hdrs = [":headers"],
     deps = [
         "//tensorflow/c:c_api",
         "//tensorflow/c/eager:c_api",
tensorflow/c/eager/parallel_device/parallel_device.cc
@@ -574,23 +574,21 @@ void DeleteParallelDevice(void* device_info) {
 
 }  // namespace
 
-void RegisterParallelDevice(TFE_Context* context, const char* device_name,
-                            const char** underlying_devices,
-                            int num_underlying_devices, TF_Status* status) {
-  TFE_CustomDevice custom_device;
-  custom_device.copy_tensor_to_device = &CopyToParallelDevice;
-  custom_device.copy_tensor_from_device = &CopyTensorFromParallelDevice;
-  custom_device.delete_device = &DeleteParallelDevice;
-  custom_device.execute = &ParallelDeviceExecute;
+void AllocateParallelDevice(const char* device_name,
+                            const char* const* underlying_devices,
+                            int num_underlying_devices,
+                            TFE_CustomDevice* device, void** device_info) {
+  device->copy_tensor_to_device = &CopyToParallelDevice;
+  device->copy_tensor_from_device = &CopyTensorFromParallelDevice;
+  device->delete_device = &DeleteParallelDevice;
+  device->execute = &ParallelDeviceExecute;
   std::vector<std::string> underlying_devices_vector;
   underlying_devices_vector.reserve(num_underlying_devices);
   for (int device_index = 0; device_index < num_underlying_devices;
        ++device_index) {
     underlying_devices_vector.push_back(underlying_devices[device_index]);
   }
-  ParallelDevice* d =
-      new ParallelDevice(device_name, underlying_devices_vector);
-  TFE_RegisterCustomDevice(context, custom_device, device_name, d, status);
+  *device_info = new ParallelDevice(device_name, underlying_devices_vector);
 }
 
 }  // namespace eager
tensorflow/c/eager/parallel_device/parallel_device.h
@@ -16,12 +16,14 @@ limitations under the License.
 #ifndef TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_H_
 #define TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_H_
 
+#include "tensorflow/c/c_api.h"
 #include "tensorflow/c/eager/c_api.h"
+#include "tensorflow/c/eager/c_api_experimental.h"
 
 namespace tensorflow {
 namespace eager {
 
-// Register a parallel device named `device_name` which forwards operations to
+// Allocate a parallel device named `device_name` which forwards operations to
 // `underlying_devices`, maintaining "parallel tensors" with components placed
 // on each underlying device.
 //
@@ -50,11 +52,12 @@ namespace eager {
 // TPUReplicatedOutput(input=x, num_replicas=2)` un-packs the parallel tensor
 // into its components.
 //
-// `context` owns the parallel device. `underlying_devices` must stay valid
-// while the parallel device is in use.
-void RegisterParallelDevice(TFE_Context* context, const char* device_name,
-                            const char** underlying_devices,
-                            int num_underlying_devices, TF_Status* status);
+// The filled `device` struct and the allocated `device_info` struct may be
+// passed to TFE_RegisterCustomDevice. The `device_name` arguments must match.
+void AllocateParallelDevice(const char* device_name,
+                            const char* const* underlying_devices,
+                            int num_underlying_devices,
+                            TFE_CustomDevice* device, void** device_info);
 
 }  // namespace eager
 }  // namespace tensorflow
tensorflow/c/eager/parallel_device/parallel_device_test.cc
@@ -288,6 +288,19 @@ void AssertScalarFloatEq(TFE_TensorHandle* handle, float expected_value) {
       *static_cast<float*>(TF_TensorData(value_zero.get())));
 }
 
+template <std::size_t num_devices>
+void RegisterParallelDevice(
+    TFE_Context* context, const char* device_name,
+    const std::array<const char*, num_devices>& underlying_devices,
+    TF_Status* status) {
+  TFE_CustomDevice device;
+  void* device_info;
+  tensorflow::eager::AllocateParallelDevice(
+      device_name, underlying_devices.data(), underlying_devices.size(),
+      &device, &device_info);
+  TFE_RegisterCustomDevice(context, device, device_name, device_info, status);
+}
+
 // Create and modify a variable placed on a parallel device which composes
 // `first_device` and `second_device`.
 void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
@@ -297,9 +310,8 @@ void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
       TF_NewStatus(), TF_DeleteStatus);
   const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
   std::array<const char*, 2> underlying_devices{first_device, second_device};
-  tensorflow::eager::RegisterParallelDevice(
-      context, device_name, underlying_devices.data(),
-      underlying_devices.size(), status.get());
+  RegisterParallelDevice(context, device_name, underlying_devices,
+                         status.get());
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
 
   // Create a variable handle (uninitialized to start) placed on the parallel
@@ -456,16 +468,14 @@ TEST(PARALLEL_DEVICE, TestExplicitCopies) {
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
 
   const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
-  std::vector<const char*> underlying_devices;
   const char* first_device_name =
       "/job:localhost/replica:0/task:0/device:CPU:0";
-  underlying_devices.push_back(first_device_name);
   const char* second_device_name =
       "/job:localhost/replica:0/task:0/device:CPU:1";
-  underlying_devices.push_back(second_device_name);
-  tensorflow::eager::RegisterParallelDevice(
-      context.get(), device_name, underlying_devices.data(),
-      underlying_devices.size(), status.get());
+  std::array<const char*, 2> underlying_devices{first_device_name,
+                                                second_device_name};
+  RegisterParallelDevice(context.get(), device_name, underlying_devices,
+                         status.get());
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
 
   TensorHandlePtr cpu_value(FloatTensorHandle(3., status.get()));
@@ -524,12 +534,11 @@ TEST(PARALLEL_DEVICE, TestDifferentShapes) {
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
 
   const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
-  std::vector<const char*> underlying_devices;
-  underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:0");
-  underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:1");
-  tensorflow::eager::RegisterParallelDevice(
-      context.get(), device_name, underlying_devices.data(),
-      underlying_devices.size(), status.get());
+  std::array<const char*, 2> underlying_devices{
+      "/job:localhost/replica:0/task:0/device:CPU:0",
+      "/job:localhost/replica:0/task:0/device:CPU:1"};
+  RegisterParallelDevice(context.get(), device_name, underlying_devices,
+                         status.get());
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
 
   // Create two vectors with different lengths
@@ -570,24 +579,22 @@ TEST(PARALLEL_DEVICE, TestNestedParallelDevices) {
   // Create a parallel device with two CPUs
   const char* first_device_name =
      "/job:localhost/replica:0/task:0/device:CUSTOM:0";
-  std::vector<const char*> first_underlying_devices{
+  std::array<const char*, 2> first_underlying_devices{
       "/job:localhost/replica:0/task:0/device:CPU:0",
       "/job:localhost/replica:0/task:0/device:CPU:1"};
-  tensorflow::eager::RegisterParallelDevice(
-      context.get(), first_device_name, first_underlying_devices.data(),
-      first_underlying_devices.size(), status.get());
+  RegisterParallelDevice(context.get(), first_device_name,
+                         first_underlying_devices, status.get());
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
 
   // Create a second parallel device with the first parallel device and one
   // additional CPU.
   const char* second_device_name =
      "/job:localhost/replica:0/task:0/device:CUSTOM:1";
-  std::vector<const char*> second_underlying_devices{
+  std::array<const char*, 2> second_underlying_devices{
       "/job:localhost/replica:0/task:0/device:CUSTOM:0",
       "/job:localhost/replica:0/task:0/device:CPU:2"};
-  tensorflow::eager::RegisterParallelDevice(
-      context.get(), second_device_name, second_underlying_devices.data(),
-      second_underlying_devices.size(), status.get());
+  RegisterParallelDevice(context.get(), second_device_name,
+                         second_underlying_devices, status.get());
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
 
   // Create a tensor on the first parallel device
@@ -656,11 +663,10 @@ TEST(PARALLEL_DEVICE, TestInvalidPacking) {
   std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
       TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
   const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
-  std::vector<const char*> underlying_devices;
-  underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:0");
-  tensorflow::eager::RegisterParallelDevice(
-      context.get(), device_name, underlying_devices.data(),
-      underlying_devices.size(), status.get());
+  std::array<const char*, 1> underlying_devices{
+      "/job:localhost/replica:0/task:0/device:CPU:0"};
+  RegisterParallelDevice(context.get(), device_name, underlying_devices,
+                         status.get());
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
 
   TensorHandlePtr value_one(FloatTensorHandle(1., status.get()));
@@ -775,12 +781,11 @@ TEST(PARALLEL_DEVICE, TestCollective) {
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
 
   const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
-  std::vector<const char*> underlying_devices;
-  underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:0");
-  underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:1");
-  tensorflow::eager::RegisterParallelDevice(
-      context.get(), device_name, underlying_devices.data(),
-      underlying_devices.size(), status.get());
+  std::array<const char*, 2> underlying_devices{
+      "/job:localhost/replica:0/task:0/device:CPU:0",
+      "/job:localhost/replica:0/task:0/device:CPU:1"};
+  RegisterParallelDevice(context.get(), device_name, underlying_devices,
+                         status.get());
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
 
   // Create a tensor on the parallel device
@@ -867,12 +872,11 @@ TEST(PARALLEL_DEVICE, TestFunction) {
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
 
   const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
-  std::vector<const char*> underlying_devices;
-  underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:0");
-  underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:1");
-  tensorflow::eager::RegisterParallelDevice(
-      context.get(), device_name, underlying_devices.data(),
-      underlying_devices.size(), status.get());
+  std::array<const char*, 2> underlying_devices{
+      "/job:localhost/replica:0/task:0/device:CPU:0",
+      "/job:localhost/replica:0/task:0/device:CPU:1"};
+  RegisterParallelDevice(context.get(), device_name, underlying_devices,
+                         status.get());
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
 
   const char* function_name = "test_reduce_mul";
tensorflow/python/BUILD
@@ -8032,6 +8032,29 @@ py_binary(
     ],
 )
 
+tf_python_pybind_extension(
+    name = "_pywrap_parallel_device",
+    srcs = [
+        "lib/core/safe_ptr.h",
+        "//tensorflow/c:headers",
+        "//tensorflow/c/eager:headers",
+        "//tensorflow/c/eager/parallel_device:headers",
+        "//tensorflow/c/eager/parallel_device:sources",
+        "//tensorflow/python/distribute/parallel_device:pywrap_parallel_device.cc",
+    ],
+    module_name = "_pywrap_parallel_device",
+    visibility = ["//tensorflow/python/distribute/parallel_device:__pkg__"],
+    deps = [
+        "//tensorflow/core:framework_headers_lib",
+        "//tensorflow/core:lib_headers_for_pybind",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/python:pybind11_lib",
+        "//tensorflow/python:pybind11_status",
+        "//third_party/python_runtime:headers",
+        "@pybind11",
+    ],
+)
+
 pyx_library(
     name = "framework_fast_tensor_util",
     srcs = ["framework/fast_tensor_util.pyx"],
tensorflow/python/distribute/parallel_device/BUILD (new file)
@@ -0,0 +1,45 @@
package(
    licenses = ["notice"],  # Apache 2.0
)

# Pybind rules must live in tensorflow/python due to header rule visibility.
exports_files(
    ["pywrap_parallel_device.cc"],
    visibility = ["//tensorflow/python:__pkg__"],
)

py_library(
    name = "parallel_device",
    srcs = ["parallel_device.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":saving",
        "//tensorflow/python:_pywrap_parallel_device",
    ],
)

py_library(
    name = "saving",
    srcs = ["saving.py"],
    srcs_version = "PY2AND3",
    deps = ["//tensorflow/python:framework_ops"],
)

py_test(
    name = "parallel_device_test",
    srcs = ["parallel_device_test.py"],
    python_version = "PY3",
    tags = [
        # Dependencies aren't otherwise included in the pip package yet.
        "no_pip",
    ],
    deps = [
        ":parallel_device",
        ":saving",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:collective_ops",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python/module",
        "//tensorflow/python/tpu",
    ],
)
tensorflow/python/distribute/parallel_device/parallel_device.py (new file)
@@ -0,0 +1,95 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility for eagerly executing operations in parallel on multiple devices."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import contextlib
import threading

from tensorflow.python import _pywrap_parallel_device
from tensorflow.python.distribute.parallel_device import saving
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.tpu.ops import tpu_ops

_next_device_number = 0
_next_device_number_lock = threading.Lock()


# TODO(allenl): Expand this docstring once things like getting components on and
# off the device are stable.
class ParallelDevice(object):
  """A device which executes operations in parallel."""

  def __init__(self, components):
    """Creates a device which executes operations in parallel on `components`.

    Args:
      components: A list of device names. Each operation executed on the
        returned device executes on these component devices.

    Returns:
      A string with the name of the newly created device.
    """
    global _next_device_number, _next_device_number_lock
    self.components = tuple(components)
    ctx = context.context()
    with _next_device_number_lock:
      # TODO(allenl): Better names for parallel devices (right now "CUSTOM" is
      # special-cased).
      self.name = "{}/device:CUSTOM:{}".format(
          ctx.host_address_space(), _next_device_number)
      _next_device_number += 1
    device, device_info = _pywrap_parallel_device.GetParallelDeviceCapsules(
        self.name, self.components)
    context.register_custom_device(device, self.name, device_info)

  def pack(self, tensors):
    """Create a tensor on the parallel device from a sequence of tensors.

    Args:
      tensors: A flat list of tensors, one per device in `self.components`.

    Returns:
      A single tensor placed on `self.name`.
    """
    with ops.device(self.name):
      return tpu_ops.tpu_replicated_input(inputs=tensors)

  def unpack(self, parallel_tensor):
    """Unpack a parallel tensor into its components.

    Args:
      parallel_tensor: A tensor placed on `self.name`.

    Returns:
      A flat list of tensors, one per `self.components`.
    """
    with ops.device(self.name):
      return tpu_ops.tpu_replicated_output(
          parallel_tensor, num_replicas=len(self.components))

  # TODO(allenl): Fixing saving in Python is a bit odd. One alternative would be
  # to provide a hook for the custom device to create save specs/etc., then call
  # that hook from the default variable implementation if the variable is on a
  # custom device. We'll likely want similar hooks for repr() and such.
  @contextlib.contextmanager
  def scope(self):
    """Runs ops in parallel, makes variables which save independent buffers."""
    with ops.device(self.name), saving.independent_buffers(self):
      yield
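
For orientation, here is a minimal usage sketch of this still-private class. It mirrors test_register_parallel_device in the new test file below, and assumes two CPU component devices are available:

    # Sketch only: the API above is experimental and has no public entry point.
    from tensorflow.python.distribute.parallel_device import parallel_device
    from tensorflow.python.framework import constant_op
    from tensorflow.python.framework import ops

    device = parallel_device.ParallelDevice(
        components=["/job:localhost/replica:0/task:0/device:CPU:0",
                    "/job:localhost/replica:0/task:0/device:CPU:1"])
    with ops.device(device.name):
      # pack() places one tensor on each component; the add then runs once
      # per component device.
      x = device.pack([constant_op.constant(1.), constant_op.constant(2.)])
      y = x + x
    # unpack() recovers one tensor per component device: [2.0, 4.0].
    outputs = device.unpack(y)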
tensorflow/python/distribute/parallel_device/parallel_device_test.py (new file)
@@ -0,0 +1,254 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import threading

from tensorflow.python.distribute.parallel_device import parallel_device
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import collective_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training.tracking import util as tracking
from tensorflow.python.util import nest

# When running collectives asynchronously, we need to give each parallel device
# execution a unique ID so the collectives don't interfere. Since the op is
# replicated with group/instance key intact, the replicated nodes will
# communicate.
# TODO(allenl): Switch to using a collective manager.
_COUNTER_LOCK = threading.Lock()
_COUNTER = 0


def _collective_reduce(inputs, operation, num_replicas):

  def _reduce_tensor(tensor):
    with _COUNTER_LOCK:
      global _COUNTER
      keys = _COUNTER
      _COUNTER += 1
    return collective_ops.all_reduce(
        t=tensor,
        group_size=num_replicas,
        merge_op=operation,
        group_key=keys,
        instance_key=keys,
        final_op="Id")

  return nest.map_structure(_reduce_tensor, inputs)


def _collective_sum(inputs, num_replicas):
  return _collective_reduce(
      inputs=inputs, operation="Add", num_replicas=num_replicas)


class _Dense(module.Module):

  def __init__(self, output_size):
    self.output_size = output_size
    self.kernel = None
    self.bias = None

  def __call__(self, x):
    if self.kernel is None:
      self.kernel = variables.Variable(
          array_ops.ones(
              array_ops.stack([self.output_size,
                               array_ops.shape(x)[-1]])))
      self.bias = variables.Variable(array_ops.ones([self.output_size]))
    return math_ops.matmul(x, self.kernel, transpose_b=True) + self.bias


class _VirtualDeviceTestCase(test.TestCase):

  def setUp(self):
    super(_VirtualDeviceTestCase, self).setUp()
    cpus = context.context().list_physical_devices("CPU")
    # Set 4 virtual CPUs
    context.context().set_logical_device_configuration(cpus[0], [
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration()
    ])

    # TODO(allenl): Make CPU:0 and CPU:1 work (right now "CPU:1" soft-places
    # onto CPU:0, which seems wrong).
    components = [
        "/job:localhost/replica:0/task:0/device:CPU:0",
        "/job:localhost/replica:0/task:0/device:CPU:1"
    ]
    self.device = parallel_device.ParallelDevice(components)


class ParallelDeviceTests(_VirtualDeviceTestCase):

  def test_register_parallel_device(self):
    with ops.device(self.device.name):
      c = constant_op.constant(1.)
      d = constant_op.constant(2.)
      e = c + d
      outputs = self.device.unpack(e)
    self.assertAllClose([3., 3.], outputs)

    self.assertIn(self.device.components[0], outputs[0].backing_device)
    self.assertIn(self.device.components[1], outputs[1].backing_device)

  def test_collective_reduce(self):
    with ops.device(self.device.name):
      x = self.device.pack(
          [constant_op.constant(-1.5),
           constant_op.constant(3.5)])
      reduced = _collective_sum(x, num_replicas=2)
      outputs = self.device.unpack(reduced)
    self.assertAllClose([2., 2.], outputs)
    self.assertIn(self.device.components[0], outputs[0].backing_device)
    self.assertIn(self.device.components[1], outputs[1].backing_device)

  def test_checkpointing(self):
    prefix = os.path.join(self.get_temp_dir(), "ckpt")
    with self.device.scope():
      different_values = self.device.pack(
          [constant_op.constant(-1.),
           constant_op.constant(3.)])
      v = variables.Variable(different_values)
    checkpoint = tracking.Checkpoint(v=v)
    save_path = checkpoint.save(prefix)
    with ops.device(self.device.name):
      v.assign(constant_op.constant(0.))
    # Make sure the checkpoint is actually written before we try to read it
    context.async_wait()
    checkpoint.restore(save_path).assert_consumed()
    with ops.device(self.device.name):
      outputs = self.device.unpack(v)
    self.assertAllClose([-1., 3.], outputs)


class LayerTests(_VirtualDeviceTestCase):

  def test_layer_forward(self):
    with ops.device(self.device.name):
      layer = _Dense(5)
      x = constant_op.constant([[2.]])
      y = layer(x)
      outputs = self.device.unpack(y)
    self.assertAllClose([[3.] * 5], outputs[0])
    self.assertAllClose([[3.] * 5], outputs[1])
    self.assertIn(self.device.components[0], outputs[0].backing_device)
    self.assertIn(self.device.components[1], outputs[1].backing_device)

    # With different Layer inputs we get different outputs
    with ops.device(self.device.name):
      x = self.device.pack(
          [constant_op.constant([[-0.5]]),
           constant_op.constant([[0.5]])])
      y = layer(x)
      outputs = self.device.unpack(y)
    self.assertGreater(
        math_ops.reduce_max(math_ops.abs(outputs[0] - outputs[1])), 1e-5)
    self.assertIn(self.device.components[0], outputs[0].backing_device)
    self.assertIn(self.device.components[1], outputs[1].backing_device)

  def test_layer_sync_training(self):
    with ops.device(self.device.name):
      layer = _Dense(5)

      with backprop.GradientTape() as tape:
        x = self.device.pack(
            [constant_op.constant([[-0.5]]),
             constant_op.constant([[0.5]])])
        y = layer(x)
        loss = (y - math_ops.range(5.))**2.
      parameters = layer.trainable_variables
      unreduced_gradients = tape.gradient(loss, parameters)
      reduced_gradients = _collective_sum(unreduced_gradients, num_replicas=2)
      for grad, param in zip(reduced_gradients, parameters):
        param.assign_sub(0.01 * grad)
    final_kernels = self.device.unpack(layer.kernel)
    self.assertAllClose(final_kernels[0], final_kernels[1])
    final_bias = self.device.unpack(layer.bias)
    expected_bias = (1. - 0.01 * 2. * (1. + .5 - math_ops.range(5.)) -
                     0.01 * 2. * (1. - .5 - math_ops.range(5.)))
    self.assertAllClose(expected_bias, final_bias[0])
    self.assertAllClose(expected_bias, final_bias[1])
    self.assertIn(self.device.components[0], final_kernels[0].backing_device)
    self.assertIn(self.device.components[1], final_kernels[1].backing_device)

  def test_layer_divergent_buffer_training(self):
    with ops.device(self.device.name):
      layer = _Dense(5)

      with backprop.GradientTape() as tape:
        x = self.device.pack(
            [constant_op.constant([[-0.5]]),
             constant_op.constant([[0.5]])])
        y = layer(x)
        loss = (y - math_ops.range(5.))**2.
      parameters = layer.trainable_variables
      unreduced_gradients = tape.gradient(loss, parameters)
      for grad, param in zip(unreduced_gradients, parameters):
        param.assign_sub(0.01 * grad)
    final_kernels = self.device.unpack(layer.kernel)
    self.assertNotAllClose(final_kernels[0], final_kernels[1])
    final_bias = self.device.unpack(layer.bias)
    self.assertAllClose(1. - 0.01 * 2. * (1. - .5 - math_ops.range(5.)),
                        final_bias[0])
    self.assertAllClose(1. - 0.01 * 2. * (1. + .5 - math_ops.range(5.)),
                        final_bias[1])
    self.assertIn(self.device.components[0], final_kernels[0].backing_device)
    self.assertIn(self.device.components[1], final_kernels[1].backing_device)

  def test_training_loop(self):
    for _ in range(5):
      layer = _Dense(5)
      checkpoint = tracking.Checkpoint(layer=layer)
      manager = checkpoint_management.CheckpointManager(
          checkpoint, directory=self.get_temp_dir(), max_to_keep=5)
      manager.restore_or_initialize()

      for _ in range(10):
        with self.device.scope():
          with backprop.GradientTape() as tape:
            x = self.device.pack(
                [constant_op.constant([[-0.5]]),
                 constant_op.constant([[0.5]])])
            y = layer(x)
            loss = (y - math_ops.range(5.))**2.
          parameters = layer.trainable_variables
          unreduced_gradients = tape.gradient(loss, parameters)
          reduced_gradients = _collective_sum(
              unreduced_gradients, num_replicas=len(self.device.components))
          for grad, param in zip(reduced_gradients, parameters):
            param.assign_sub(0.01 * grad)

      manager.save()


if __name__ == "__main__":
  ops.enable_eager_execution()
  test.main()
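
As a sanity check on the closed-form assertion in test_layer_sync_training: with the kernel and bias initialized to ones and a batch of one, each component computes y = x + 1 per output unit, so the bias gradient of (y - range(5))**2 is 2 * (y - range(5)). A small sketch (plain Python with NumPy, no TensorFlow needed) reproducing expected_bias under those assumptions:

    import numpy as np

    target = np.arange(5.)  # math_ops.range(5.)
    y0 = -0.5 + 1.          # component 0 output per unit, x = -0.5
    y1 = 0.5 + 1.           # component 1 output per unit, x = 0.5
    # The collective sums the two per-component bias gradients before the
    # SGD step with learning rate 0.01.
    grad = 2. * (y0 - target) + 2. * (y1 - target)
    expected_bias = 1. - 0.01 * grad
    # Matches: 1. - 0.01*2.*(1.+.5-range(5.)) - 0.01*2.*(1.-.5-range(5.))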
tensorflow/python/distribute/parallel_device/pywrap_parallel_device.cc (new file)
@@ -0,0 +1,70 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "Python.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/parallel_device/parallel_device.h"
#include "tensorflow/python/lib/core/py_exception_registry.h"
#include "tensorflow/python/lib/core/pybind11_lib.h"
#include "tensorflow/python/lib/core/pybind11_status.h"
#include "tensorflow/python/lib/core/safe_ptr.h"

namespace py = pybind11;

void CallDelete_Device(PyObject* capsule) {
  delete reinterpret_cast<TFE_CustomDevice*>(
      PyCapsule_GetPointer(capsule, "TFE_CustomDevice"));
}

void CallDelete_DeviceInfo(PyObject* capsule) {
  void (*destructor)(void*) =
      reinterpret_cast<void (*)(void*)>(PyCapsule_GetContext(capsule));
  destructor(PyCapsule_GetPointer(capsule, "TFE_CustomDevice_DeviceInfo"));
}

PYBIND11_MODULE(_pywrap_parallel_device, m) {
  m.def("GetParallelDeviceCapsules",
        [](const char* name, std::vector<std::string> underlying_devices) {
          std::vector<const char*> underlying_devices_c;
          underlying_devices_c.reserve(underlying_devices.size());
          for (const std::string& element : underlying_devices) {
            underlying_devices_c.push_back(element.c_str());
          }
          // `device` is owned by `device_capsule`.
          TFE_CustomDevice* device = new TFE_CustomDevice;
          tensorflow::Safe_PyObjectPtr device_capsule(
              PyCapsule_New(device, "TFE_CustomDevice", &CallDelete_Device));
          void* device_info;
          tensorflow::eager::AllocateParallelDevice(
              name, underlying_devices_c.data(), underlying_devices_c.size(),
              device, &device_info);
          if (PyErr_Occurred()) throw py::error_already_set();
          tensorflow::Safe_PyObjectPtr device_info_capsule(
              PyCapsule_New(device_info, "TFE_CustomDevice_DeviceInfo",
                            &CallDelete_DeviceInfo));
          if (PyErr_Occurred()) throw py::error_already_set();
          // The PyCapsule destructor needs a pointer to the destructor for
          // DeviceInfo.
          PyCapsule_SetContext(device_info_capsule.get(),
                               reinterpret_cast<void*>(device->delete_device));
          return tensorflow::PyoOrThrow(
              PyTuple_Pack(2, device_capsule.get(), device_info_capsule.get()));
        });
}
tensorflow/python/distribute/parallel_device/saving.py (new file)
@@ -0,0 +1,131 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Special-cased checkpointing for variables on a parallel device."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import contextlib
import functools

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training.saving import saveable_object


def _read_component(handle, dtype, replica_id, parallel_device):
  """Read one component of a parallel variable and discard the rest."""
  with ops.device(handle.device):
    read = gen_resource_variable_ops.read_variable_op(
        resource=handle, dtype=dtype)
  all_components = parallel_device.unpack(read)
  # We're pretending that parallel variables have a first axis with length
  # num_components, so we need to add a dummy first axis to the shape that gets
  # saved.
  return all_components[replica_id][None, ...]


class _ParallelDeviceSaveable(saveable_object.SaveableObject):
  """Saves and restores a parallel variable."""

  def __init__(self, name, handle, dtype, component_shape, parallel_device):
    # Each component device gets one spec with a tensor to save.
    specs = []
    for replica_id, device_name in enumerate(parallel_device.components):
      # TODO(b/151773535): SaveableObjects with SaveSpecs on different devices
      # will cause extra copying at the moment. We should fix that before doing
      # anything serious with this code.
      specs.append(
          saveable_object.SaveSpec(
              tensor=functools.partial(
                  _read_component,
                  handle=handle,
                  dtype=dtype,
                  replica_id=replica_id,
                  parallel_device=parallel_device),
              slice_spec=variables.Variable.SaveSliceInfo(
                  full_shape=([len(parallel_device.components)] +
                              component_shape),
                  var_offset=[replica_id] + [0] * len(component_shape),
                  var_shape=[1] + component_shape).spec,
              device=device_name,
              dtype=dtype,
              name=name))
    self._handle = handle
    self._parallel_device = parallel_device
    self._component_shape = component_shape
    super(_ParallelDeviceSaveable, self).__init__(None, specs, name)

  def restore(self, tensors, restored_shapes=None):
    with ops.device(self._handle.device):
      # Combine the restored tensors into one parallel tensor to assign.
      bundled = self._parallel_device.pack(tensors)
      gen_resource_variable_ops.assign_variable_op(
          resource=self._handle,
          # Squeeze out the dummy first axis we added when saving.
          value=array_ops.squeeze(bundled, axis=0))


class VariableWithFixedCheckpointing(resource_variable_ops.ResourceVariable):
  """Overrides checkpointing behavior to save like a partitioned variable."""

  def __init__(self, parallel_device, **kwargs):
    self._parallel_device = parallel_device
    kwargs = {k: v for k, v in kwargs.items()
              if k not in ["use_resource", "expected_shape"]}
    super(VariableWithFixedCheckpointing, self).__init__(**kwargs)

  def _gather_saveables_for_checkpoint(self):
    # Note VARIABLE_VALUE is the usual attribute name for variables. Using
    # something different means (a) the checkpointing infrastructure won't try
    # doing restore-on-create (which has shape issues), and (b) the saved
    # variables won't be compatible with regular variables. Both of those are
    # good in this case.
    return dict(
        PARALLEL_VARIABLE_VALUE=functools.partial(
            _ParallelDeviceSaveable,
            handle=self.handle,
            dtype=self.dtype,
            component_shape=self.shape,
            parallel_device=self._parallel_device))


def _variable_creator(next_creator, parallel_device, **kwargs):
  del next_creator
  return VariableWithFixedCheckpointing(
      parallel_device=parallel_device, **kwargs)


@contextlib.contextmanager
def independent_buffers(parallel_device):
  """Context manager which saves parallel buffers independently.

  Creates a ParallelDevice-aware variable subclass which saves buffers for each
  device separately.

  Args:
    parallel_device: A ParallelDevice object on which variables are placed.

  Yields:
    Nothing.
  """
  with variable_scope.variable_creator_scope(
      functools.partial(_variable_creator, parallel_device=parallel_device)):
    yield
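
For intuition about the slice specs built in _ParallelDeviceSaveable above: each component is saved as one row of a virtual [num_components] + component_shape variable, using the same spec-string format as partitioned variables. A small sketch (plain Python, assuming two components with component shape [3]):

    num_components = 2
    component_shape = [3]
    full_shape = [num_components] + component_shape            # [2, 3]
    for replica_id in range(num_components):
      var_offset = [replica_id] + [0] * len(component_shape)   # which row
      var_shape = [1] + component_shape                        # one row
      # SaveSliceInfo.spec renders "<full shape> <offset,length per axis>":
      spec = (" ".join(str(d) for d in full_shape) + " " +
              ":".join("%d,%d" % pair for pair in zip(var_offset, var_shape)))
      print(spec)  # "2 3 0,1:0,3", then "2 3 1,1:0,3"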