Add Python bindings and some tests for the parallel device

Still quite experimental; no public API and lots of TODOs.

Since the parallel device uses the C API, the pybind extension needs to re-include the parallel device sources so it can use the copy of the C API in pywrap_tensorflow. This is pretty ugly, but I don't see a way around it until pywrap_tensorflow relies on libtensorflow.so as the single source of the C API.

PiperOrigin-RevId: 307840967
Change-Id: Id8e7e72b14e2e5a2886c6025c7ef6f92f71a156c
This commit is contained in:
Allen Lavoie 2020-04-22 10:10:05 -07:00 committed by TensorFlower Gardener
parent 65fd3b702b
commit 1e4ccc88c0
10 changed files with 698 additions and 59 deletions

View File

@ -7,10 +7,26 @@ package(
licenses = ["notice"], # Apache 2.0
)
# Currently pybind extension shared objects must use only C API headers since
# the C API has static initializers duplicated in the Python bindings. So we
# need a second rule that omits .cc files, in
# tensorflow/python:_pywrap_parallel_device.
filegroup(
name = "headers",
srcs = ["parallel_device.h"],
visibility = ["//tensorflow/python:__pkg__"],
)
filegroup(
name = "sources",
srcs = ["parallel_device.cc"],
visibility = ["//tensorflow/python:__pkg__"],
)
cc_library(
name = "parallel_device",
srcs = ["parallel_device.cc"],
hdrs = ["parallel_device.h"],
srcs = [":sources"],
hdrs = [":headers"],
deps = [
"//tensorflow/c:c_api",
"//tensorflow/c/eager:c_api",

View File

@ -574,23 +574,21 @@ void DeleteParallelDevice(void* device_info) {
} // namespace
void RegisterParallelDevice(TFE_Context* context, const char* device_name,
const char** underlying_devices,
int num_underlying_devices, TF_Status* status) {
TFE_CustomDevice custom_device;
custom_device.copy_tensor_to_device = &CopyToParallelDevice;
custom_device.copy_tensor_from_device = &CopyTensorFromParallelDevice;
custom_device.delete_device = &DeleteParallelDevice;
custom_device.execute = &ParallelDeviceExecute;
void AllocateParallelDevice(const char* device_name,
const char* const* underlying_devices,
int num_underlying_devices,
TFE_CustomDevice* device, void** device_info) {
device->copy_tensor_to_device = &CopyToParallelDevice;
device->copy_tensor_from_device = &CopyTensorFromParallelDevice;
device->delete_device = &DeleteParallelDevice;
device->execute = &ParallelDeviceExecute;
std::vector<std::string> underlying_devices_vector;
underlying_devices_vector.reserve(num_underlying_devices);
for (int device_index = 0; device_index < num_underlying_devices;
++device_index) {
underlying_devices_vector.push_back(underlying_devices[device_index]);
}
ParallelDevice* d =
new ParallelDevice(device_name, underlying_devices_vector);
TFE_RegisterCustomDevice(context, custom_device, device_name, d, status);
*device_info = new ParallelDevice(device_name, underlying_devices_vector);
}
} // namespace eager

View File

@ -16,12 +16,14 @@ limitations under the License.
#ifndef TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_H_
#define TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_H_
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
namespace tensorflow {
namespace eager {
// Register a parallel device named `device_name` which forwards operations to
// Allocate a parallel device named `device_name` which forwards operations to
// `underlying_devices`, maintaining "parallel tensors" with components placed
// on each underlying device.
//
@ -50,11 +52,12 @@ namespace eager {
// TPUReplicatedOutput(input=x, num_replicas=2)` un-packs the parallel tensor
// into its components.
//
// `context` owns the parallel device. `underlying_devices` must stay valid
// while the parallel device is in use.
void RegisterParallelDevice(TFE_Context* context, const char* device_name,
const char** underlying_devices,
int num_underlying_devices, TF_Status* status);
// The filled `device` struct and the allocated `device_info` struct may be
// passed to TFE_RegisterCustomDevice. The `device_name` arguments must match.
void AllocateParallelDevice(const char* device_name,
const char* const* underlying_devices,
int num_underlying_devices,
TFE_CustomDevice* device, void** device_info);
} // namespace eager
} // namespace tensorflow

View File

@ -288,6 +288,19 @@ void AssertScalarFloatEq(TFE_TensorHandle* handle, float expected_value) {
*static_cast<float*>(TF_TensorData(value_zero.get())));
}
// Test helper: asks AllocateParallelDevice to fill in a custom device (and its
// device info) for `underlying_devices`, then hands both to
// TFE_RegisterCustomDevice for `context` under `device_name`. `status` is
// forwarded to TFE_RegisterCustomDevice.
template <std::size_t num_devices>
void RegisterParallelDevice(
    TFE_Context* context, const char* device_name,
    const std::array<const char*, num_devices>& underlying_devices,
    TF_Status* status) {
  const char* const* component_names = underlying_devices.data();
  const int component_count = static_cast<int>(underlying_devices.size());
  // Both of these are outputs of AllocateParallelDevice.
  TFE_CustomDevice parallel_device;
  void* parallel_device_info;
  tensorflow::eager::AllocateParallelDevice(device_name, component_names,
                                            component_count, &parallel_device,
                                            &parallel_device_info);
  TFE_RegisterCustomDevice(context, parallel_device, device_name,
                           parallel_device_info, status);
}
// Create and modify a variable placed on a parallel device which composes
// `first_device` and `second_device`.
void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
@ -297,9 +310,8 @@ void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
TF_NewStatus(), TF_DeleteStatus);
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
std::array<const char*, 2> underlying_devices{first_device, second_device};
tensorflow::eager::RegisterParallelDevice(
context, device_name, underlying_devices.data(),
underlying_devices.size(), status.get());
RegisterParallelDevice(context, device_name, underlying_devices,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Create a variable handle (uninitialized to start) placed on the parallel
@ -456,16 +468,14 @@ TEST(PARALLEL_DEVICE, TestExplicitCopies) {
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
std::vector<const char*> underlying_devices;
const char* first_device_name =
"/job:localhost/replica:0/task:0/device:CPU:0";
underlying_devices.push_back(first_device_name);
const char* second_device_name =
"/job:localhost/replica:0/task:0/device:CPU:1";
underlying_devices.push_back(second_device_name);
tensorflow::eager::RegisterParallelDevice(
context.get(), device_name, underlying_devices.data(),
underlying_devices.size(), status.get());
std::array<const char*, 2> underlying_devices{first_device_name,
second_device_name};
RegisterParallelDevice(context.get(), device_name, underlying_devices,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TensorHandlePtr cpu_value(FloatTensorHandle(3., status.get()));
@ -524,12 +534,11 @@ TEST(PARALLEL_DEVICE, TestDifferentShapes) {
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
std::vector<const char*> underlying_devices;
underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:0");
underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:1");
tensorflow::eager::RegisterParallelDevice(
context.get(), device_name, underlying_devices.data(),
underlying_devices.size(), status.get());
std::array<const char*, 2> underlying_devices{
"/job:localhost/replica:0/task:0/device:CPU:0",
"/job:localhost/replica:0/task:0/device:CPU:1"};
RegisterParallelDevice(context.get(), device_name, underlying_devices,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Create two vectors with different lengths
@ -570,24 +579,22 @@ TEST(PARALLEL_DEVICE, TestNestedParallelDevices) {
// Create a parallel device with two CPUs
const char* first_device_name =
"/job:localhost/replica:0/task:0/device:CUSTOM:0";
std::vector<const char*> first_underlying_devices{
std::array<const char*, 2> first_underlying_devices{
"/job:localhost/replica:0/task:0/device:CPU:0",
"/job:localhost/replica:0/task:0/device:CPU:1"};
tensorflow::eager::RegisterParallelDevice(
context.get(), first_device_name, first_underlying_devices.data(),
first_underlying_devices.size(), status.get());
RegisterParallelDevice(context.get(), first_device_name,
first_underlying_devices, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Create a second parallel device with the first parallel device and one
// additional CPU.
const char* second_device_name =
"/job:localhost/replica:0/task:0/device:CUSTOM:1";
std::vector<const char*> second_underlying_devices{
std::array<const char*, 2> second_underlying_devices{
"/job:localhost/replica:0/task:0/device:CUSTOM:0",
"/job:localhost/replica:0/task:0/device:CPU:2"};
tensorflow::eager::RegisterParallelDevice(
context.get(), second_device_name, second_underlying_devices.data(),
second_underlying_devices.size(), status.get());
RegisterParallelDevice(context.get(), second_device_name,
second_underlying_devices, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Create a tensor on the first parallel device
@ -656,11 +663,10 @@ TEST(PARALLEL_DEVICE, TestInvalidPacking) {
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
std::vector<const char*> underlying_devices;
underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:0");
tensorflow::eager::RegisterParallelDevice(
context.get(), device_name, underlying_devices.data(),
underlying_devices.size(), status.get());
std::array<const char*, 1> underlying_devices{
"/job:localhost/replica:0/task:0/device:CPU:0"};
RegisterParallelDevice(context.get(), device_name, underlying_devices,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TensorHandlePtr value_one(FloatTensorHandle(1., status.get()));
@ -775,12 +781,11 @@ TEST(PARALLEL_DEVICE, TestCollective) {
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
std::vector<const char*> underlying_devices;
underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:0");
underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:1");
tensorflow::eager::RegisterParallelDevice(
context.get(), device_name, underlying_devices.data(),
underlying_devices.size(), status.get());
std::array<const char*, 2> underlying_devices{
"/job:localhost/replica:0/task:0/device:CPU:0",
"/job:localhost/replica:0/task:0/device:CPU:1"};
RegisterParallelDevice(context.get(), device_name, underlying_devices,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Create a tensor on the parallel device
@ -867,12 +872,11 @@ TEST(PARALLEL_DEVICE, TestFunction) {
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
std::vector<const char*> underlying_devices;
underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:0");
underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:1");
tensorflow::eager::RegisterParallelDevice(
context.get(), device_name, underlying_devices.data(),
underlying_devices.size(), status.get());
std::array<const char*, 2> underlying_devices{
"/job:localhost/replica:0/task:0/device:CPU:0",
"/job:localhost/replica:0/task:0/device:CPU:1"};
RegisterParallelDevice(context.get(), device_name, underlying_devices,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
const char* function_name = "test_reduce_mul";

View File

@ -8032,6 +8032,29 @@ py_binary(
],
)
# Pybind11 extension module exposing the parallel device to Python. The
# parallel device's own headers and sources are listed directly in `srcs`
# (rather than depending on the parallel_device cc_library), so they are
# compiled into this extension.
# NOTE(review): per the change description this is so the extension uses the
# copy of the C API in pywrap_tensorflow -- confirm before restructuring.
tf_python_pybind_extension(
name = "_pywrap_parallel_device",
srcs = [
"lib/core/safe_ptr.h",
"//tensorflow/c:headers",
"//tensorflow/c/eager:headers",
"//tensorflow/c/eager/parallel_device:headers",
"//tensorflow/c/eager/parallel_device:sources",
"//tensorflow/python/distribute/parallel_device:pywrap_parallel_device.cc",
],
module_name = "_pywrap_parallel_device",
visibility = ["//tensorflow/python/distribute/parallel_device:__pkg__"],
deps = [
"//tensorflow/core:framework_headers_lib",
"//tensorflow/core:lib_headers_for_pybind",
"//tensorflow/core:protos_all_cc",
"//tensorflow/python:pybind11_lib",
"//tensorflow/python:pybind11_status",
"//third_party/python_runtime:headers",
"@pybind11",
],
)
pyx_library(
name = "framework_fast_tensor_util",
srcs = ["framework/fast_tensor_util.pyx"],

View File

@ -0,0 +1,45 @@
package(
licenses = ["notice"], # Apache 2.0
)
# Pybind rules must live in tensorflow/python due to header rule visibility.
exports_files(
["pywrap_parallel_device.cc"],
visibility = ["//tensorflow/python:__pkg__"],
)
# Python wrapper (parallel_device.py) around the _pywrap_parallel_device
# extension.
# NOTE(review): parallel_device.py also imports the eager `context` module,
# framework `ops` and `tpu_ops`; presumably those arrive transitively --
# confirm whether they should be listed in `deps`.
py_library(
name = "parallel_device",
srcs = ["parallel_device.py"],
srcs_version = "PY2AND3",
deps = [
":saving",
"//tensorflow/python:_pywrap_parallel_device",
],
)
# Checkpointing support (saving.py) for variables placed on a parallel device.
# NOTE(review): saving.py imports several ops modules (array_ops,
# resource_variable_ops, variable_scope, variables, ...) beyond framework_ops;
# presumably those arrive transitively -- confirm whether they should be
# listed in `deps`.
py_library(
name = "saving",
srcs = ["saving.py"],
srcs_version = "PY2AND3",
deps = ["//tensorflow/python:framework_ops"],
)
# Python tests for the :parallel_device and :saving libraries. Runs under
# Python 3 and is excluded from pip-package testing via the `no_pip` tag.
py_test(
name = "parallel_device_test",
srcs = ["parallel_device_test.py"],
python_version = "PY3",
tags = [
# Dependencies aren't otherwise included in the pip package yet.
"no_pip",
],
deps = [
":parallel_device",
":saving",
"//tensorflow/python:client_testlib",
"//tensorflow/python:collective_ops",
"//tensorflow/python:framework_ops",
"//tensorflow/python/module",
"//tensorflow/python/tpu",
],
)

View File

@ -0,0 +1,95 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility for eagerly executing operations in parallel on multiple devices."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
import threading
from tensorflow.python import _pywrap_parallel_device
from tensorflow.python.distribute.parallel_device import saving
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.tpu.ops import tpu_ops
_next_device_number = 0
_next_device_number_lock = threading.Lock()
# TODO(allenl): Expand this docstring once things like getting components on and
# off the device are stable.
class ParallelDevice(object):
  """A device which executes operations in parallel.

  Attributes:
    components: A tuple of the component device names passed to the
      constructor.
    name: The name under which this device was registered with the eager
      context, of the form "<host address space>/device:CUSTOM:<number>".
  """

  def __init__(self, components):
    """Creates a device which executes operations in parallel on `components`.

    The new device is registered with the current eager context. Its name is
    available afterwards as `self.name`.

    Args:
      components: A list of device names. Each operation executed on this
        device executes on these component devices.
    """
    # Only the counter is rebound here; the lock is merely read, so it does not
    # need a `global` declaration.
    global _next_device_number
    self.components = tuple(components)
    ctx = context.context()
    # Hold the lock just long enough to claim a unique device number.
    with _next_device_number_lock:
      # TODO(allenl): Better names for parallel devices (right now "CUSTOM" is
      # special-cased).
      self.name = "{}/device:CUSTOM:{}".format(
          ctx.host_address_space(), _next_device_number)
      _next_device_number += 1
    device, device_info = _pywrap_parallel_device.GetParallelDeviceCapsules(
        self.name, self.components)
    context.register_custom_device(device, self.name, device_info)

  def pack(self, tensors):
    """Create a tensor on the parallel device from a sequence of tensors.

    Args:
      tensors: A flat list of tensors, one per device in `self.components`.

    Returns:
      A single tensor placed on `self.name`.
    """
    with ops.device(self.name):
      return tpu_ops.tpu_replicated_input(inputs=tensors)

  def unpack(self, parallel_tensor):
    """Unpack a parallel tensor into its components.

    Args:
      parallel_tensor: A tensor placed on `self.name`.

    Returns:
      A flat list of tensors, one per `self.components`.
    """
    with ops.device(self.name):
      return tpu_ops.tpu_replicated_output(
          parallel_tensor, num_replicas=len(self.components))

  # TODO(allenl): Fixing saving in Python is a bit odd. One alternative would be
  # to provide a hook for the custom device to create save specs/etc., then call
  # that hook from the default variable implementation if the variable is on a
  # custom device. We'll likely want similar hooks for repr() and such.
  @contextlib.contextmanager
  def scope(self):
    """Runs ops in parallel, makes variables which save independent buffers."""
    with ops.device(self.name), saving.independent_buffers(self):
      yield

View File

@ -0,0 +1,254 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import threading
from tensorflow.python.distribute.parallel_device import parallel_device
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import collective_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training.tracking import util as tracking
from tensorflow.python.util import nest
# When running collectives asynchronously, we need to give each parallel device
# execution a unique ID so the collectives don't interfere. Since the op is
# replicated with group/instance key intact, the replicated nodes will
# communicate.
# TODO(allenl): Switch to using a collective manager.
_COUNTER_LOCK = threading.Lock()
_COUNTER = 0
def _collective_reduce(inputs, operation, num_replicas):
  """All-reduces every tensor in the structure `inputs` using `operation`.

  Each tensor is given the next value of the module-level counter, which is
  used as both its group key and its instance key.
  """

  def _reduce_tensor(tensor):
    global _COUNTER
    # Claim a fresh key under the lock so concurrent callers never share one.
    with _COUNTER_LOCK:
      unique_key = _COUNTER
      _COUNTER += 1
    return collective_ops.all_reduce(
        t=tensor,
        group_size=num_replicas,
        merge_op=operation,
        group_key=unique_key,
        instance_key=unique_key,
        final_op="Id")

  return nest.map_structure(_reduce_tensor, inputs)
def _collective_sum(inputs, num_replicas):
  """Collective reduce of `inputs` with the "Add" merge operation."""
  return _collective_reduce(inputs, "Add", num_replicas)
class _Dense(module.Module):
  """Dense layer whose all-ones kernel and bias are created on first call."""

  def __init__(self, output_size):
    self.output_size = output_size
    self.kernel = None
    self.bias = None

  def __call__(self, x):
    if self.kernel is None:
      # The kernel shape is [output_size, size of the last axis of `x`].
      input_size = array_ops.shape(x)[-1]
      kernel_shape = array_ops.stack([self.output_size, input_size])
      self.kernel = variables.Variable(array_ops.ones(kernel_shape))
      self.bias = variables.Variable(array_ops.ones([self.output_size]))
    projected = math_ops.matmul(x, self.kernel, transpose_b=True)
    return projected + self.bias
class _VirtualDeviceTestCase(test.TestCase):
  """Base test case exposing a two-CPU parallel device as `self.device`."""

  def setUp(self):
    super(_VirtualDeviceTestCase, self).setUp()
    ctx = context.context()
    physical_cpus = ctx.list_physical_devices("CPU")
    # Set 4 virtual CPUs
    virtual_cpus = [context.LogicalDeviceConfiguration() for _ in range(4)]
    ctx.set_logical_device_configuration(physical_cpus[0], virtual_cpus)
    # TODO(allenl): Make CPU:0 and CPU:1 work (right now "CPU:1" soft-places
    # onto CPU:0, which seems wrong).
    self.device = parallel_device.ParallelDevice([
        "/job:localhost/replica:0/task:0/device:CPU:0",
        "/job:localhost/replica:0/task:0/device:CPU:1"
    ])
# Tests of basic op execution, collectives and checkpointing on the parallel
# device created by _VirtualDeviceTestCase.setUp.
class ParallelDeviceTests(_VirtualDeviceTestCase):
# Adds two constants on the parallel device and checks that both unpacked
# components equal 3 and that each reports a backing device containing the
# corresponding component name.
def test_register_parallel_device(self):
with ops.device(self.device.name):
c = constant_op.constant(1.)
d = constant_op.constant(2.)
e = c + d
outputs = self.device.unpack(e)
self.assertAllClose([3., 3.], outputs)
self.assertIn(self.device.components[0], outputs[0].backing_device)
self.assertIn(self.device.components[1], outputs[1].backing_device)
# Packs -1.5 and 3.5 into one parallel tensor, sums it across the two
# replicas, and expects both components to hold 2.
def test_collective_reduce(self):
with ops.device(self.device.name):
x = self.device.pack(
[constant_op.constant(-1.5),
constant_op.constant(3.5)])
reduced = _collective_sum(x, num_replicas=2)
outputs = self.device.unpack(reduced)
self.assertAllClose([2., 2.], outputs)
self.assertIn(self.device.components[0], outputs[0].backing_device)
self.assertIn(self.device.components[1], outputs[1].backing_device)
# Saves a variable whose components differ (-1 and 3), overwrites it with 0,
# restores the checkpoint and checks the per-component values came back.
def test_checkpointing(self):
prefix = os.path.join(self.get_temp_dir(), "ckpt")
with self.device.scope():
different_values = self.device.pack(
[constant_op.constant(-1.),
constant_op.constant(3.)])
v = variables.Variable(different_values)
checkpoint = tracking.Checkpoint(v=v)
save_path = checkpoint.save(prefix)
with ops.device(self.device.name):
v.assign(constant_op.constant(0.))
# Make sure the checkpoint is actually written before we try to read it
context.async_wait()
checkpoint.restore(save_path).assert_consumed()
with ops.device(self.device.name):
outputs = self.device.unpack(v)
self.assertAllClose([-1., 3.], outputs)
# Tests which run the _Dense layer (forward pass, gradient updates and a
# checkpointed training loop) on the parallel device.
class LayerTests(_VirtualDeviceTestCase):
# With one shared input both components produce [[3.] * 5]; with different
# packed inputs the two component outputs must differ by more than 1e-5.
def test_layer_forward(self):
with ops.device(self.device.name):
layer = _Dense(5)
x = constant_op.constant([[2.]])
y = layer(x)
outputs = self.device.unpack(y)
self.assertAllClose([[3.] * 5], outputs[0])
self.assertAllClose([[3.] * 5], outputs[1])
self.assertIn(self.device.components[0], outputs[0].backing_device)
self.assertIn(self.device.components[1], outputs[1].backing_device)
# With different Layer inputs we get different outputs
with ops.device(self.device.name):
x = self.device.pack(
[constant_op.constant([[-0.5]]),
constant_op.constant([[0.5]])])
y = layer(x)
outputs = self.device.unpack(y)
self.assertGreater(
math_ops.reduce_max(math_ops.abs(outputs[0] - outputs[1])), 1e-5)
self.assertIn(self.device.components[0], outputs[0].backing_device)
self.assertIn(self.device.components[1], outputs[1].backing_device)
# Gradients are summed across the two replicas before being applied, so the
# kernel and bias are expected to end up equal on both components.
def test_layer_sync_training(self):
with ops.device(self.device.name):
layer = _Dense(5)
with backprop.GradientTape() as tape:
x = self.device.pack(
[constant_op.constant([[-0.5]]),
constant_op.constant([[0.5]])])
y = layer(x)
loss = (y - math_ops.range(5.))**2.
parameters = layer.trainable_variables
unreduced_gradients = tape.gradient(loss, parameters)
reduced_gradients = _collective_sum(unreduced_gradients, num_replicas=2)
for grad, param in zip(reduced_gradients, parameters):
param.assign_sub(0.01 * grad)
final_kernels = self.device.unpack(layer.kernel)
self.assertAllClose(final_kernels[0], final_kernels[1])
final_bias = self.device.unpack(layer.bias)
# Expected bias: the initial 1 minus 0.01 times the sum of both replicas'
# gradients.
expected_bias = (1. - 0.01 * 2. * (1. + .5 - math_ops.range(5.)) -
0.01 * 2. * (1. - .5 - math_ops.range(5.)))
self.assertAllClose(expected_bias, final_bias[0])
self.assertAllClose(expected_bias, final_bias[1])
self.assertIn(self.device.components[0], final_kernels[0].backing_device)
self.assertIn(self.device.components[1], final_kernels[1].backing_device)
# Gradients are applied without any cross-replica reduction, so the kernels
# are expected to differ and each bias reflects only its own replica's input.
def test_layer_divergent_buffer_training(self):
with ops.device(self.device.name):
layer = _Dense(5)
with backprop.GradientTape() as tape:
x = self.device.pack(
[constant_op.constant([[-0.5]]),
constant_op.constant([[0.5]])])
y = layer(x)
loss = (y - math_ops.range(5.))**2.
parameters = layer.trainable_variables
unreduced_gradients = tape.gradient(loss, parameters)
for grad, param in zip(unreduced_gradients, parameters):
param.assign_sub(0.01 * grad)
final_kernels = self.device.unpack(layer.kernel)
self.assertNotAllClose(final_kernels[0], final_kernels[1])
final_bias = self.device.unpack(layer.bias)
self.assertAllClose(1. - 0.01 * 2. * (1. - .5 - math_ops.range(5.)),
final_bias[0])
self.assertAllClose(1. - 0.01 * 2. * (1. + .5 - math_ops.range(5.)),
final_bias[1])
self.assertIn(self.device.components[0], final_kernels[0].backing_device)
self.assertIn(self.device.components[1], final_kernels[1].backing_device)
# Repeats five times: build a fresh layer and checkpoint manager, call
# restore_or_initialize, run ten reduced-gradient training steps, then save.
# The test makes no assertions of its own.
def test_training_loop(self):
for _ in range(5):
layer = _Dense(5)
checkpoint = tracking.Checkpoint(layer=layer)
manager = checkpoint_management.CheckpointManager(
checkpoint, directory=self.get_temp_dir(), max_to_keep=5)
manager.restore_or_initialize()
for _ in range(10):
with self.device.scope():
with backprop.GradientTape() as tape:
x = self.device.pack(
[constant_op.constant([[-0.5]]),
constant_op.constant([[0.5]])])
y = layer(x)
loss = (y - math_ops.range(5.))**2.
parameters = layer.trainable_variables
unreduced_gradients = tape.gradient(loss, parameters)
reduced_gradients = _collective_sum(
unreduced_gradients, num_replicas=len(self.device.components))
for grad, param in zip(reduced_gradients, parameters):
param.assign_sub(0.01 * grad)
manager.save()
# When run as a script, enable eager execution before handing control to the
# test runner.
if __name__ == "__main__":
ops.enable_eager_execution()
test.main()

View File

@ -0,0 +1,70 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "Python.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/parallel_device/parallel_device.h"
#include "tensorflow/python/lib/core/py_exception_registry.h"
#include "tensorflow/python/lib/core/pybind11_lib.h"
#include "tensorflow/python/lib/core/pybind11_status.h"
#include "tensorflow/python/lib/core/safe_ptr.h"
namespace py = pybind11;
// PyCapsule destructor for "TFE_CustomDevice" capsules: deletes the
// TFE_CustomDevice struct the capsule points at.
void CallDelete_Device(PyObject* capsule) {
  void* raw_device = PyCapsule_GetPointer(capsule, "TFE_CustomDevice");
  delete static_cast<TFE_CustomDevice*>(raw_device);
}
// PyCapsule destructor for "TFE_CustomDevice_DeviceInfo" capsules. The
// capsule's context is interpreted as the function that frees the device
// info, and is called with the capsule's pointer.
void CallDelete_DeviceInfo(PyObject* capsule) {
  using DeviceInfoDeleter = void (*)(void*);
  const auto delete_device_info =
      reinterpret_cast<DeviceInfoDeleter>(PyCapsule_GetContext(capsule));
  void* device_info =
      PyCapsule_GetPointer(capsule, "TFE_CustomDevice_DeviceInfo");
  delete_device_info(device_info);
}
PYBIND11_MODULE(_pywrap_parallel_device, m) {
  // GetParallelDeviceCapsules(name, underlying_devices) allocates a parallel
  // device and returns a 2-tuple of capsules:
  //   - a "TFE_CustomDevice" capsule holding the filled TFE_CustomDevice, and
  //   - a "TFE_CustomDevice_DeviceInfo" capsule holding the device info.
  // Each capsule's destructor frees what it holds. A Python error is raised
  // if either capsule cannot be created.
  m.def("GetParallelDeviceCapsules",
        [](const char* name, std::vector<std::string> underlying_devices) {
          // AllocateParallelDevice takes C strings; these point into
          // `underlying_devices`, which outlives the call below.
          std::vector<const char*> underlying_devices_c;
          underlying_devices_c.reserve(underlying_devices.size());
          for (const std::string& element : underlying_devices) {
            underlying_devices_c.push_back(element.c_str());
          }
          // `device` is owned by `device_capsule` once the capsule exists.
          TFE_CustomDevice* device = new TFE_CustomDevice;
          tensorflow::Safe_PyObjectPtr device_capsule(
              PyCapsule_New(device, "TFE_CustomDevice", &CallDelete_Device));
          if (device_capsule.get() == nullptr) {
            // No capsule took ownership, so free the struct here and stop
            // before allocating anything else.
            delete device;
            throw py::error_already_set();
          }
          void* device_info;
          tensorflow::eager::AllocateParallelDevice(
              name, underlying_devices_c.data(), underlying_devices_c.size(),
              device, &device_info);
          tensorflow::Safe_PyObjectPtr device_info_capsule(
              PyCapsule_New(device_info, "TFE_CustomDevice_DeviceInfo",
                            &CallDelete_DeviceInfo));
          if (device_info_capsule.get() == nullptr) {
            // Nothing owns `device_info` yet; release it with the device's
            // own deleter. `device` itself is freed by `device_capsule`.
            device->delete_device(device_info);
            throw py::error_already_set();
          }
          // The PyCapsule destructor needs a pointer to the destructor for
          // DeviceInfo.
          PyCapsule_SetContext(device_info_capsule.get(),
                               reinterpret_cast<void*>(device->delete_device));
          return tensorflow::PyoOrThrow(
              PyTuple_Pack(2, device_capsule.get(), device_info_capsule.get()));
        });
}

View File

@ -0,0 +1,131 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Special-cased checkpointing for variables on a parallel device."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
import functools
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training.saving import saveable_object
# Arguments, as used below:
#   handle: resource handle passed to `read_variable_op`; the read is placed
#     on `handle.device`.
#   dtype: dtype passed to `read_variable_op`.
#   replica_id: index of the unpacked component to keep.
#   parallel_device: object whose `unpack` splits the read value into
#     components.
# Returns the selected component with a new leading axis added.
def _read_component(handle, dtype, replica_id, parallel_device):
"""Read one component of a parallel variable and discard the rest."""
with ops.device(handle.device):
read = gen_resource_variable_ops.read_variable_op(
resource=handle, dtype=dtype)
all_components = parallel_device.unpack(read)
# We're pretending that parallel variables have a first axis with length
# num_components, so we need to add a dummy first axis to the shape that gets
# saved.
return all_components[replica_id][None, ...]
class _ParallelDeviceSaveable(saveable_object.SaveableObject):
"""Saves and restores a parallel variable."""
# Builds one SaveSpec per entry of `parallel_device.components`. Each spec
# lazily reads its own component via `_read_component` and is described as a
# slice of a larger tensor whose first axis has one entry per component.
def __init__(self, name, handle, dtype, component_shape, parallel_device):
# Each component device gets one spec with a tensor to save.
specs = []
for replica_id, device_name in enumerate(parallel_device.components):
# TODO(b/151773535): SaveableObjects with SaveSpecs on different devices
# will cause extra copying at the moment. We should fix that before doing
# anything serious with this code.
specs.append(
saveable_object.SaveSpec(
# A callable rather than a tensor, so the read happens when it is called.
tensor=functools.partial(
_read_component,
handle=handle,
dtype=dtype,
replica_id=replica_id,
parallel_device=parallel_device),
# Slice `replica_id` along the leading axis, covering the full component.
slice_spec=variables.Variable.SaveSliceInfo(
full_shape=([len(parallel_device.components)] +
component_shape),
var_offset=[replica_id] + [0] * len(component_shape),
var_shape=[1] + component_shape).spec,
device=device_name,
dtype=dtype,
name=name))
self._handle = handle
self._parallel_device = parallel_device
self._component_shape = component_shape
super(_ParallelDeviceSaveable, self).__init__(None, specs, name)
# Packs the restored per-component tensors into one parallel tensor and
# assigns it to the variable handle. `restored_shapes` is not used.
def restore(self, tensors, restored_shapes=None):
with ops.device(self._handle.device):
# Combine the restored tensors into one parallel tensor to assign.
bundled = self._parallel_device.pack(tensors)
gen_resource_variable_ops.assign_variable_op(
resource=self._handle,
# Squeeze out the dummy first axis we added when saving.
value=array_ops.squeeze(bundled, axis=0))
class VariableWithFixedCheckpointing(resource_variable_ops.ResourceVariable):
  """Overrides checkpointing behavior to save like a partitioned variable."""

  def __init__(self, parallel_device, **kwargs):
    self._parallel_device = parallel_device
    # These two arguments are dropped rather than forwarded to the base class.
    for dropped_argument in ("use_resource", "expected_shape"):
      kwargs.pop(dropped_argument, None)
    super(VariableWithFixedCheckpointing, self).__init__(**kwargs)

  def _gather_saveables_for_checkpoint(self):
    # Note VARIABLE_VALUE is the usual attribute name for variables. Using
    # something different means (a) the checkpointing infrastructure won't try
    # doing restore-on-create (which has shape issues), and (b) the saved
    # variables won't be compatible with regular variables. Both of those are
    # good in this case.
    saveable_factory = functools.partial(
        _ParallelDeviceSaveable,
        handle=self.handle,
        dtype=self.dtype,
        component_shape=self.shape,
        parallel_device=self._parallel_device)
    return {"PARALLEL_VARIABLE_VALUE": saveable_factory}
def _variable_creator(next_creator, parallel_device, **kwargs):
  """Variable creator which always builds a `VariableWithFixedCheckpointing`.

  `next_creator` is discarded; the remaining keyword arguments are forwarded
  along with `parallel_device`.
  """
  del next_creator
  kwargs["parallel_device"] = parallel_device
  return VariableWithFixedCheckpointing(**kwargs)
@contextlib.contextmanager
def independent_buffers(parallel_device):
  """Context manager which saves parallel buffers independently.

  While active, variables are created through `_variable_creator`, which
  produces a ParallelDevice-aware variable subclass that saves buffers for
  each device separately.

  Args:
    parallel_device: A ParallelDevice object on which variables are placed.

  Yields:
    Nothing.
  """
  creator = functools.partial(
      _variable_creator, parallel_device=parallel_device)
  with variable_scope.variable_creator_scope(creator):
    yield