diff --git a/tensorflow/c/eager/parallel_device/BUILD b/tensorflow/c/eager/parallel_device/BUILD
index 9d787d26433..f4dbcc6cead 100644
--- a/tensorflow/c/eager/parallel_device/BUILD
+++ b/tensorflow/c/eager/parallel_device/BUILD
@@ -7,10 +7,26 @@ package(
     licenses = ["notice"],  # Apache 2.0
 )
 
+# Currently pybind extension shared objects must use only C API headers since
+# the C API has static initializers duplicated in the Python bindings. So we
+# need a second rule that omits .cc files, in
+# tensorflow/python:_pywrap_parallel_device.
+filegroup(
+    name = "headers",
+    srcs = ["parallel_device.h"],
+    visibility = ["//tensorflow/python:__pkg__"],
+)
+
+filegroup(
+    name = "sources",
+    srcs = ["parallel_device.cc"],
+    visibility = ["//tensorflow/python:__pkg__"],
+)
+
 cc_library(
     name = "parallel_device",
-    srcs = ["parallel_device.cc"],
-    hdrs = ["parallel_device.h"],
+    srcs = [":sources"],
+    hdrs = [":headers"],
     deps = [
         "//tensorflow/c:c_api",
         "//tensorflow/c/eager:c_api",
diff --git a/tensorflow/c/eager/parallel_device/parallel_device.cc b/tensorflow/c/eager/parallel_device/parallel_device.cc
index bd5d8e777f2..e6846809fcf 100644
--- a/tensorflow/c/eager/parallel_device/parallel_device.cc
+++ b/tensorflow/c/eager/parallel_device/parallel_device.cc
@@ -574,23 +574,21 @@ void DeleteParallelDevice(void* device_info) {
 
 }  // namespace
 
-void RegisterParallelDevice(TFE_Context* context, const char* device_name,
-                            const char** underlying_devices,
-                            int num_underlying_devices, TF_Status* status) {
-  TFE_CustomDevice custom_device;
-  custom_device.copy_tensor_to_device = &CopyToParallelDevice;
-  custom_device.copy_tensor_from_device = &CopyTensorFromParallelDevice;
-  custom_device.delete_device = &DeleteParallelDevice;
-  custom_device.execute = &ParallelDeviceExecute;
+void AllocateParallelDevice(const char* device_name,
+                            const char* const* underlying_devices,
+                            int num_underlying_devices,
+                            TFE_CustomDevice* device, void** device_info) {
+  device->copy_tensor_to_device = &CopyToParallelDevice;
+  device->copy_tensor_from_device = &CopyTensorFromParallelDevice;
+  device->delete_device = &DeleteParallelDevice;
+  device->execute = &ParallelDeviceExecute;
   std::vector<std::string> underlying_devices_vector;
   underlying_devices_vector.reserve(num_underlying_devices);
   for (int device_index = 0; device_index < num_underlying_devices;
        ++device_index) {
     underlying_devices_vector.push_back(underlying_devices[device_index]);
   }
-  ParallelDevice* d =
-      new ParallelDevice(device_name, underlying_devices_vector);
-  TFE_RegisterCustomDevice(context, custom_device, device_name, d, status);
+  *device_info = new ParallelDevice(device_name, underlying_devices_vector);
 }
 
 }  // namespace eager
diff --git a/tensorflow/c/eager/parallel_device/parallel_device.h b/tensorflow/c/eager/parallel_device/parallel_device.h
index b106524401f..f448a4c5b83 100644
--- a/tensorflow/c/eager/parallel_device/parallel_device.h
+++ b/tensorflow/c/eager/parallel_device/parallel_device.h
@@ -16,12 +16,14 @@ limitations under the License.
 #ifndef TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_H_
 #define TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_H_
 
+#include "tensorflow/c/c_api.h"
 #include "tensorflow/c/eager/c_api.h"
+#include "tensorflow/c/eager/c_api_experimental.h"
 
 namespace tensorflow {
 namespace eager {
 
-// Register a parallel device named `device_name` which forwards operations to
+// Allocate a parallel device named `device_name` which forwards operations to
 // `underlying_devices`, maintaining "parallel tensors" with components placed
 // on each underlying device.
 //
@@ -50,11 +52,12 @@ namespace eager {
 // TPUReplicatedOutput(input=x, num_replicas=2)` un-packs the parallel tensor
 // into its components.
 //
-// `context` owns the parallel device. `underlying_devices` must stay valid
-// while the parallel device is in use.
-void RegisterParallelDevice(TFE_Context* context, const char* device_name,
-                            const char** underlying_devices,
-                            int num_underlying_devices, TF_Status* status);
+// The filled `device` struct and the allocated `device_info` struct may be
+// passed to TFE_RegisterCustomDevice. The `device_name` arguments must match.
+void AllocateParallelDevice(const char* device_name,
+                            const char* const* underlying_devices,
+                            int num_underlying_devices,
+                            TFE_CustomDevice* device, void** device_info);
 
 }  // namespace eager
 }  // namespace tensorflow
diff --git a/tensorflow/c/eager/parallel_device/parallel_device_test.cc b/tensorflow/c/eager/parallel_device/parallel_device_test.cc
index 41c7d64e231..9b0613b0391 100644
--- a/tensorflow/c/eager/parallel_device/parallel_device_test.cc
+++ b/tensorflow/c/eager/parallel_device/parallel_device_test.cc
@@ -288,6 +288,19 @@ void AssertScalarFloatEq(TFE_TensorHandle* handle, float expected_value) {
            *static_cast<float*>(TF_TensorData(value_zero.get())));
 }
 
+template <std::size_t num_devices>
+void RegisterParallelDevice(
+    TFE_Context* context, const char* device_name,
+    const std::array<const char*, num_devices>& underlying_devices,
+    TF_Status* status) {
+  TFE_CustomDevice device;
+  void* device_info;
+  tensorflow::eager::AllocateParallelDevice(
+      device_name, underlying_devices.data(), underlying_devices.size(),
+      &device, &device_info);
+  TFE_RegisterCustomDevice(context, device, device_name, device_info, status);
+}
+
 // Create and modify a variable placed on a parallel device which composes
 // `first_device` and `second_device`.
 void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
@@ -297,9 +310,8 @@ void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
       TF_NewStatus(), TF_DeleteStatus);
   const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
   std::array<const char*, 2> underlying_devices{first_device, second_device};
-  tensorflow::eager::RegisterParallelDevice(
-      context, device_name, underlying_devices.data(),
-      underlying_devices.size(), status.get());
+  RegisterParallelDevice(context, device_name, underlying_devices,
+                         status.get());
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
 
   // Create a variable handle (uninitialized to start) placed on the parallel
@@ -456,16 +468,14 @@ TEST(PARALLEL_DEVICE, TestExplicitCopies) {
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
 
   const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
-  std::vector<const char*> underlying_devices;
   const char* first_device_name =
       "/job:localhost/replica:0/task:0/device:CPU:0";
-  underlying_devices.push_back(first_device_name);
   const char* second_device_name =
       "/job:localhost/replica:0/task:0/device:CPU:1";
-  underlying_devices.push_back(second_device_name);
-  tensorflow::eager::RegisterParallelDevice(
-      context.get(), device_name, underlying_devices.data(),
-      underlying_devices.size(), status.get());
+  std::array<const char*, 2> underlying_devices{first_device_name,
+                                                second_device_name};
+  RegisterParallelDevice(context.get(), device_name, underlying_devices,
+                         status.get());
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
 
   TensorHandlePtr cpu_value(FloatTensorHandle(3., status.get()));
@@ -524,12 +534,11 @@ TEST(PARALLEL_DEVICE, TestDifferentShapes) {
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
 
   const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
-  std::vector<const char*> underlying_devices;
-  underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:0");
-  underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:1");
-  tensorflow::eager::RegisterParallelDevice(
-      context.get(), device_name, underlying_devices.data(),
-      underlying_devices.size(), status.get());
+  std::array<const char*, 2> underlying_devices{
+      "/job:localhost/replica:0/task:0/device:CPU:0",
+      "/job:localhost/replica:0/task:0/device:CPU:1"};
+  RegisterParallelDevice(context.get(), device_name, underlying_devices,
+                         status.get());
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
 
   // Create two vectors with different lengths
@@ -570,24 +579,22 @@ TEST(PARALLEL_DEVICE, TestNestedParallelDevices) {
   // Create a parallel device with two CPUs
   const char* first_device_name =
       "/job:localhost/replica:0/task:0/device:CUSTOM:0";
-  std::vector<const char*> first_underlying_devices{
+  std::array<const char*, 2> first_underlying_devices{
       "/job:localhost/replica:0/task:0/device:CPU:0",
       "/job:localhost/replica:0/task:0/device:CPU:1"};
-  tensorflow::eager::RegisterParallelDevice(
-      context.get(), first_device_name, first_underlying_devices.data(),
-      first_underlying_devices.size(), status.get());
+  RegisterParallelDevice(context.get(), first_device_name,
+                         first_underlying_devices, status.get());
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
 
   // Create a second parallel device with the first parallel device and one
   // additional CPU.
   const char* second_device_name =
       "/job:localhost/replica:0/task:0/device:CUSTOM:1";
-  std::vector<const char*> second_underlying_devices{
+  std::array<const char*, 2> second_underlying_devices{
       "/job:localhost/replica:0/task:0/device:CUSTOM:0",
       "/job:localhost/replica:0/task:0/device:CPU:2"};
-  tensorflow::eager::RegisterParallelDevice(
-      context.get(), second_device_name, second_underlying_devices.data(),
-      second_underlying_devices.size(), status.get());
+  RegisterParallelDevice(context.get(), second_device_name,
+                         second_underlying_devices, status.get());
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
 
   // Create a tensor on the first parallel device
@@ -656,11 +663,10 @@ TEST(PARALLEL_DEVICE, TestInvalidPacking) {
   std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
       TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
   const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
-  std::vector<const char*> underlying_devices;
-  underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:0");
-  tensorflow::eager::RegisterParallelDevice(
-      context.get(), device_name, underlying_devices.data(),
-      underlying_devices.size(), status.get());
+  std::array<const char*, 1> underlying_devices{
+      "/job:localhost/replica:0/task:0/device:CPU:0"};
+  RegisterParallelDevice(context.get(), device_name, underlying_devices,
+                         status.get());
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
 
   TensorHandlePtr value_one(FloatTensorHandle(1., status.get()));
@@ -775,12 +781,11 @@ TEST(PARALLEL_DEVICE, TestCollective) {
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
 
   const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
-  std::vector<const char*> underlying_devices;
-  underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:0");
-  underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:1");
-  tensorflow::eager::RegisterParallelDevice(
-      context.get(), device_name, underlying_devices.data(),
-      underlying_devices.size(), status.get());
+  std::array<const char*, 2> underlying_devices{
+      "/job:localhost/replica:0/task:0/device:CPU:0",
+      "/job:localhost/replica:0/task:0/device:CPU:1"};
+  RegisterParallelDevice(context.get(), device_name, underlying_devices,
+                         status.get());
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
 
   // Create a tensor on the parallel device
@@ -867,12 +872,11 @@ TEST(PARALLEL_DEVICE, TestFunction) {
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
 
   const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
-  std::vector<const char*> underlying_devices;
-  underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:0");
-  underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:1");
-  tensorflow::eager::RegisterParallelDevice(
-      context.get(), device_name, underlying_devices.data(),
-      underlying_devices.size(), status.get());
+  std::array<const char*, 2> underlying_devices{
+      "/job:localhost/replica:0/task:0/device:CPU:0",
+      "/job:localhost/replica:0/task:0/device:CPU:1"};
+  RegisterParallelDevice(context.get(), device_name, underlying_devices,
+                         status.get());
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
 
   const char* function_name = "test_reduce_mul";
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 27a1fd9645c..ee0ee4dd95d 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -8032,6 +8032,29 @@ py_binary(
     ],
 )
 
+tf_python_pybind_extension(
+    name = "_pywrap_parallel_device",
+    srcs = [
+        "lib/core/safe_ptr.h",
+        "//tensorflow/c:headers",
+        "//tensorflow/c/eager:headers",
+        "//tensorflow/c/eager/parallel_device:headers",
+        "//tensorflow/c/eager/parallel_device:sources",
+        "//tensorflow/python/distribute/parallel_device:pywrap_parallel_device.cc",
+    ],
+    module_name = "_pywrap_parallel_device",
+    visibility = ["//tensorflow/python/distribute/parallel_device:__pkg__"],
+    deps = [
+        "//tensorflow/core:framework_headers_lib",
+        "//tensorflow/core:lib_headers_for_pybind",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/python:pybind11_lib",
+        "//tensorflow/python:pybind11_status",
+        "//third_party/python_runtime:headers",
+        "@pybind11",
+    ],
+)
+
 pyx_library(
     name = "framework_fast_tensor_util",
     srcs = ["framework/fast_tensor_util.pyx"],
diff --git a/tensorflow/python/distribute/parallel_device/BUILD b/tensorflow/python/distribute/parallel_device/BUILD
new file mode 100644
index 00000000000..e7526a56f66
--- /dev/null
+++ b/tensorflow/python/distribute/parallel_device/BUILD
@@ -0,0 +1,45 @@
+package(
+    licenses = ["notice"],  # Apache 2.0
+)
+
+# Pybind rules must live in tensorflow/python due to header rule visibility.
+exports_files(
+    ["pywrap_parallel_device.cc"],
+    visibility = ["//tensorflow/python:__pkg__"],
+)
+
+py_library(
+    name = "parallel_device",
+    srcs = ["parallel_device.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":saving",
+        "//tensorflow/python:_pywrap_parallel_device",
+    ],
+)
+
+py_library(
+    name = "saving",
+    srcs = ["saving.py"],
+    srcs_version = "PY2AND3",
+    deps = ["//tensorflow/python:framework_ops"],
+)
+
+py_test(
+    name = "parallel_device_test",
+    srcs = ["parallel_device_test.py"],
+    python_version = "PY3",
+    tags = [
+        # Dependencies aren't otherwise included in the pip package yet.
+        "no_pip",
+    ],
+    deps = [
+        ":parallel_device",
+        ":saving",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:collective_ops",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python/module",
+        "//tensorflow/python/tpu",
+    ],
+)
diff --git a/tensorflow/python/distribute/parallel_device/parallel_device.py b/tensorflow/python/distribute/parallel_device/parallel_device.py
new file mode 100644
index 00000000000..982b061cdb7
--- /dev/null
+++ b/tensorflow/python/distribute/parallel_device/parallel_device.py
@@ -0,0 +1,95 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utility for eagerly executing operations in parallel on multiple devices."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import contextlib
+import threading
+
+from tensorflow.python import _pywrap_parallel_device
+from tensorflow.python.distribute.parallel_device import saving
+from tensorflow.python.eager import context
+from tensorflow.python.framework import ops
+from tensorflow.python.tpu.ops import tpu_ops
+
+_next_device_number = 0
+_next_device_number_lock = threading.Lock()
+
+
+# TODO(allenl): Expand this docstring once things like getting components on and
+# off the device are stable.
+class ParallelDevice(object):
+  """A device which executes operations in parallel."""
+
+  def __init__(self, components):
+    """Creates a device which executes operations in parallel on `components`.
+
+    Args:
+      components: A list of device names. Each operation executed on the
+        returned device executes on these component devices.
+
+    Returns:
+      A string with the name of the newly created device.
+    """
+    global _next_device_number, _next_device_number_lock
+    self.components = tuple(components)
+    ctx = context.context()
+    with _next_device_number_lock:
+      # TODO(allenl): Better names for parallel devices (right now "CUSTOM" is
+      # special-cased).
+      self.name = "{}/device:CUSTOM:{}".format(
+          ctx.host_address_space(), _next_device_number)
+      _next_device_number += 1
+    device, device_info = _pywrap_parallel_device.GetParallelDeviceCapsules(
+        self.name, self.components)
+    context.register_custom_device(device, self.name, device_info)
+
+  def pack(self, tensors):
+    """Create a tensor on the parallel device from a sequence of tensors.
+
+    Args:
+      tensors: A flat list of tensors, one per device in `self.components`.
+
+    Returns:
+      A single tensor placed on `self.name`.
+    """
+    with ops.device(self.name):
+      return tpu_ops.tpu_replicated_input(inputs=tensors)
+
+  def unpack(self, parallel_tensor):
+    """Unpack a parallel tensor into its components.
+
+    Args:
+      parallel_tensor: A tensor placed on `self.name`.
+
+    Returns:
+      A flat list of tensors, one per `self.components`.
+    """
+    with ops.device(self.name):
+      return tpu_ops.tpu_replicated_output(
+          parallel_tensor, num_replicas=len(self.components))
+
+  # TODO(allenl): Fixing saving in Python is a bit odd. One alternative would be
+  # to provide a hook for the custom device to create save specs/etc., then call
+  # that hook from the default variable implementation if the variable is on a
+  # custom device. We'll likely want similar hooks for repr() and such.
+  @contextlib.contextmanager
+  def scope(self):
+    """Runs ops in parallel, makes variables which save independent buffers."""
+    with ops.device(self.name), saving.independent_buffers(self):
+      yield
diff --git a/tensorflow/python/distribute/parallel_device/parallel_device_test.py b/tensorflow/python/distribute/parallel_device/parallel_device_test.py
new file mode 100644
index 00000000000..d3f3417eca9
--- /dev/null
+++ b/tensorflow/python/distribute/parallel_device/parallel_device_test.py
@@ -0,0 +1,254 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import threading
+
+from tensorflow.python.distribute.parallel_device import parallel_device
+from tensorflow.python.eager import backprop
+from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.module import module
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import collective_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.training import checkpoint_management
+from tensorflow.python.training.tracking import util as tracking
+from tensorflow.python.util import nest
+
+# When running collectives asynchronously, we need to give each parallel device
+# execution a unique ID so the collectives don't interfere. Since the op is
+# replicated with group/instance key intact, the replicated nodes will
+# communicate.
+# TODO(allenl): Switch to using a collective manager.
+_COUNTER_LOCK = threading.Lock()
+_COUNTER = 0
+
+
+def _collective_reduce(inputs, operation, num_replicas):
+
+  def _reduce_tensor(tensor):
+    with _COUNTER_LOCK:
+      global _COUNTER
+      keys = _COUNTER
+      _COUNTER += 1
+    return collective_ops.all_reduce(
+        t=tensor,
+        group_size=num_replicas,
+        merge_op=operation,
+        group_key=keys,
+        instance_key=keys,
+        final_op="Id")
+
+  return nest.map_structure(_reduce_tensor, inputs)
+
+
+def _collective_sum(inputs, num_replicas):
+  return _collective_reduce(
+      inputs=inputs, operation="Add", num_replicas=num_replicas)
+
+
+class _Dense(module.Module):
+
+  def __init__(self, output_size):
+    self.output_size = output_size
+    self.kernel = None
+    self.bias = None
+
+  def __call__(self, x):
+    if self.kernel is None:
+      self.kernel = variables.Variable(
+          array_ops.ones(
+              array_ops.stack([self.output_size,
+                               array_ops.shape(x)[-1]])))
+      self.bias = variables.Variable(array_ops.ones([self.output_size]))
+    return math_ops.matmul(x, self.kernel, transpose_b=True) + self.bias
+
+
+class _VirtualDeviceTestCase(test.TestCase):
+
+  def setUp(self):
+    super(_VirtualDeviceTestCase, self).setUp()
+    cpus = context.context().list_physical_devices("CPU")
+    # Set 4 virtual CPUs
+    context.context().set_logical_device_configuration(cpus[0], [
+        context.LogicalDeviceConfiguration(),
+        context.LogicalDeviceConfiguration(),
+        context.LogicalDeviceConfiguration(),
+        context.LogicalDeviceConfiguration()
+    ])
+
+    # TODO(allenl): Make CPU:0 and CPU:1 work (right now "CPU:1" soft-places
+    # onto CPU:0, which seems wrong).
+    components = [
+        "/job:localhost/replica:0/task:0/device:CPU:0",
+        "/job:localhost/replica:0/task:0/device:CPU:1"
+    ]
+    self.device = parallel_device.ParallelDevice(components)
+
+
+class ParallelDeviceTests(_VirtualDeviceTestCase):
+
+  def test_register_parallel_device(self):
+    with ops.device(self.device.name):
+      c = constant_op.constant(1.)
+      d = constant_op.constant(2.)
+      e = c + d
+      outputs = self.device.unpack(e)
+    self.assertAllClose([3., 3.], outputs)
+
+    self.assertIn(self.device.components[0], outputs[0].backing_device)
+    self.assertIn(self.device.components[1], outputs[1].backing_device)
+
+  def test_collective_reduce(self):
+    with ops.device(self.device.name):
+      x = self.device.pack(
+          [constant_op.constant(-1.5),
+           constant_op.constant(3.5)])
+      reduced = _collective_sum(x, num_replicas=2)
+      outputs = self.device.unpack(reduced)
+    self.assertAllClose([2., 2.], outputs)
+    self.assertIn(self.device.components[0], outputs[0].backing_device)
+    self.assertIn(self.device.components[1], outputs[1].backing_device)
+
+  def test_checkpointing(self):
+    prefix = os.path.join(self.get_temp_dir(), "ckpt")
+    with self.device.scope():
+      different_values = self.device.pack(
+          [constant_op.constant(-1.),
+           constant_op.constant(3.)])
+      v = variables.Variable(different_values)
+    checkpoint = tracking.Checkpoint(v=v)
+    save_path = checkpoint.save(prefix)
+    with ops.device(self.device.name):
+      v.assign(constant_op.constant(0.))
+    # Make sure the checkpoint is actually written before we try to read it
+    context.async_wait()
+    checkpoint.restore(save_path).assert_consumed()
+    with ops.device(self.device.name):
+      outputs = self.device.unpack(v)
+    self.assertAllClose([-1., 3.], outputs)
+
+
+class LayerTests(_VirtualDeviceTestCase):
+
+  def test_layer_forward(self):
+    with ops.device(self.device.name):
+      layer = _Dense(5)
+      x = constant_op.constant([[2.]])
+      y = layer(x)
+      outputs = self.device.unpack(y)
+    self.assertAllClose([[3.] * 5], outputs[0])
+    self.assertAllClose([[3.] * 5], outputs[1])
+    self.assertIn(self.device.components[0], outputs[0].backing_device)
+    self.assertIn(self.device.components[1], outputs[1].backing_device)
+
+    # With different Layer inputs we get different outputs
+    with ops.device(self.device.name):
+      x = self.device.pack(
+          [constant_op.constant([[-0.5]]),
+           constant_op.constant([[0.5]])])
+      y = layer(x)
+      outputs = self.device.unpack(y)
+    self.assertGreater(
+        math_ops.reduce_max(math_ops.abs(outputs[0] - outputs[1])), 1e-5)
+    self.assertIn(self.device.components[0], outputs[0].backing_device)
+    self.assertIn(self.device.components[1], outputs[1].backing_device)
+
+  def test_layer_sync_training(self):
+    with ops.device(self.device.name):
+      layer = _Dense(5)
+
+      with backprop.GradientTape() as tape:
+        x = self.device.pack(
+            [constant_op.constant([[-0.5]]),
+             constant_op.constant([[0.5]])])
+        y = layer(x)
+        loss = (y - math_ops.range(5.))**2.
+      parameters = layer.trainable_variables
+      unreduced_gradients = tape.gradient(loss, parameters)
+      reduced_gradients = _collective_sum(unreduced_gradients, num_replicas=2)
+      for grad, param in zip(reduced_gradients, parameters):
+        param.assign_sub(0.01 * grad)
+    final_kernels = self.device.unpack(layer.kernel)
+    self.assertAllClose(final_kernels[0], final_kernels[1])
+    final_bias = self.device.unpack(layer.bias)
+    expected_bias = (1. - 0.01 * 2. * (1. + .5 - math_ops.range(5.)) -
+                     0.01 * 2. * (1. - .5 - math_ops.range(5.)))
+    self.assertAllClose(expected_bias, final_bias[0])
+    self.assertAllClose(expected_bias, final_bias[1])
+    self.assertIn(self.device.components[0], final_kernels[0].backing_device)
+    self.assertIn(self.device.components[1], final_kernels[1].backing_device)
+
+  def test_layer_divergent_buffer_training(self):
+    with ops.device(self.device.name):
+      layer = _Dense(5)
+
+      with backprop.GradientTape() as tape:
+        x = self.device.pack(
+            [constant_op.constant([[-0.5]]),
+             constant_op.constant([[0.5]])])
+        y = layer(x)
+        loss = (y - math_ops.range(5.))**2.
+      parameters = layer.trainable_variables
+      unreduced_gradients = tape.gradient(loss, parameters)
+      for grad, param in zip(unreduced_gradients, parameters):
+        param.assign_sub(0.01 * grad)
+    final_kernels = self.device.unpack(layer.kernel)
+    self.assertNotAllClose(final_kernels[0], final_kernels[1])
+    final_bias = self.device.unpack(layer.bias)
+    self.assertAllClose(1. - 0.01 * 2. * (1. - .5 - math_ops.range(5.)),
+                        final_bias[0])
+    self.assertAllClose(1. - 0.01 * 2. * (1. + .5 - math_ops.range(5.)),
+                        final_bias[1])
+    self.assertIn(self.device.components[0], final_kernels[0].backing_device)
+    self.assertIn(self.device.components[1], final_kernels[1].backing_device)
+
+  def test_training_loop(self):
+    for _ in range(5):
+      layer = _Dense(5)
+      checkpoint = tracking.Checkpoint(layer=layer)
+      manager = checkpoint_management.CheckpointManager(
+          checkpoint, directory=self.get_temp_dir(), max_to_keep=5)
+      manager.restore_or_initialize()
+
+      for _ in range(10):
+        with self.device.scope():
+          with backprop.GradientTape() as tape:
+            x = self.device.pack(
+                [constant_op.constant([[-0.5]]),
+                 constant_op.constant([[0.5]])])
+            y = layer(x)
+            loss = (y - math_ops.range(5.))**2.
+          parameters = layer.trainable_variables
+          unreduced_gradients = tape.gradient(loss, parameters)
+          reduced_gradients = _collective_sum(
+              unreduced_gradients, num_replicas=len(self.device.components))
+          for grad, param in zip(reduced_gradients, parameters):
+            param.assign_sub(0.01 * grad)
+
+      manager.save()
+
+
+if __name__ == "__main__":
+  ops.enable_eager_execution()
+  test.main()
diff --git a/tensorflow/python/distribute/parallel_device/pywrap_parallel_device.cc b/tensorflow/python/distribute/parallel_device/pywrap_parallel_device.cc
new file mode 100644
index 00000000000..62488cb31e7
--- /dev/null
+++ b/tensorflow/python/distribute/parallel_device/pywrap_parallel_device.cc
@@ -0,0 +1,70 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "Python.h"
+#include "pybind11/pybind11.h"
+#include "pybind11/stl.h"
+#include "tensorflow/c/c_api.h"
+#include "tensorflow/c/c_api_experimental.h"
+#include "tensorflow/c/eager/c_api.h"
+#include "tensorflow/c/eager/c_api_experimental.h"
+#include "tensorflow/c/eager/parallel_device/parallel_device.h"
+#include "tensorflow/python/lib/core/py_exception_registry.h"
+#include "tensorflow/python/lib/core/pybind11_lib.h"
+#include "tensorflow/python/lib/core/pybind11_status.h"
+#include "tensorflow/python/lib/core/safe_ptr.h"
+
+namespace py = pybind11;
+
+void CallDelete_Device(PyObject* capsule) {
+  delete reinterpret_cast<TFE_CustomDevice*>(
+      PyCapsule_GetPointer(capsule, "TFE_CustomDevice"));
+}
+
+void CallDelete_DeviceInfo(PyObject* capsule) {
+  void (*destructor)(void*) =
+      reinterpret_cast<void (*)(void*)>(PyCapsule_GetContext(capsule));
+  destructor(PyCapsule_GetPointer(capsule, "TFE_CustomDevice_DeviceInfo"));
+}
+
+PYBIND11_MODULE(_pywrap_parallel_device, m) {
+  m.def("GetParallelDeviceCapsules",
+        [](const char* name, std::vector<std::string> underlying_devices) {
+          std::vector<const char*> underlying_devices_c;
+          underlying_devices_c.reserve(underlying_devices.size());
+          for (const std::string& element : underlying_devices) {
+            underlying_devices_c.push_back(element.c_str());
+          }
+          // `device` is owned by `device_capsule`.
+          TFE_CustomDevice* device = new TFE_CustomDevice;
+          tensorflow::Safe_PyObjectPtr device_capsule(
+              PyCapsule_New(device, "TFE_CustomDevice", &CallDelete_Device));
+          void* device_info;
+          tensorflow::eager::AllocateParallelDevice(
+              name, underlying_devices_c.data(), underlying_devices_c.size(),
+              device, &device_info);
+          if (PyErr_Occurred()) throw py::error_already_set();
+          tensorflow::Safe_PyObjectPtr device_info_capsule(
+              PyCapsule_New(device_info, "TFE_CustomDevice_DeviceInfo",
+                            &CallDelete_DeviceInfo));
+          if (PyErr_Occurred()) throw py::error_already_set();
+          // The PyCapsule destructor needs a pointer to the destructor for
+          // DeviceInfo.
+          PyCapsule_SetContext(device_info_capsule.get(),
+                               reinterpret_cast<void*>(device->delete_device));
+          return tensorflow::PyoOrThrow(
+              PyTuple_Pack(2, device_capsule.get(), device_info_capsule.get()));
+        });
+}
diff --git a/tensorflow/python/distribute/parallel_device/saving.py b/tensorflow/python/distribute/parallel_device/saving.py
new file mode 100644
index 00000000000..f2e7dadae41
--- /dev/null
+++ b/tensorflow/python/distribute/parallel_device/saving.py
@@ -0,0 +1,131 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Special-cased checkpointing for variables on a parallel device."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import contextlib
+import functools
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_resource_variable_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.training.saving import saveable_object
+
+
+def _read_component(handle, dtype, replica_id, parallel_device):
+  """Read one component of a parallel variable and discard the rest."""
+  with ops.device(handle.device):
+    read = gen_resource_variable_ops.read_variable_op(
+        resource=handle, dtype=dtype)
+  all_components = parallel_device.unpack(read)
+  # We're pretending that parallel variables have a first axis with length
+  # num_components, so we need to add a dummy first axis to the shape that gets
+  # saved.
+  return all_components[replica_id][None, ...]
+
+
+class _ParallelDeviceSaveable(saveable_object.SaveableObject):
+  """Saves and restores a parallel variable."""
+
+  def __init__(self, name, handle, dtype, component_shape, parallel_device):
+    # Each component device gets one spec with a tensor to save.
+    specs = []
+    for replica_id, device_name in enumerate(parallel_device.components):
+      # TODO(b/151773535): SaveableObjects with SaveSpecs on different devices
+      # will cause extra copying at the moment. We should fix that before doing
+      # anything serious with this code.
+      specs.append(
+          saveable_object.SaveSpec(
+              tensor=functools.partial(
+                  _read_component,
+                  handle=handle,
+                  dtype=dtype,
+                  replica_id=replica_id,
+                  parallel_device=parallel_device),
+              slice_spec=variables.Variable.SaveSliceInfo(
+                  full_shape=([len(parallel_device.components)] +
+                              component_shape),
+                  var_offset=[replica_id] + [0] * len(component_shape),
+                  var_shape=[1] + component_shape).spec,
+              device=device_name,
+              dtype=dtype,
+              name=name))
+    self._handle = handle
+    self._parallel_device = parallel_device
+    self._component_shape = component_shape
+    super(_ParallelDeviceSaveable, self).__init__(None, specs, name)
+
+  def restore(self, tensors, restored_shapes=None):
+    with ops.device(self._handle.device):
+      # Combine the restored tensors into one parallel tensor to assign.
+      bundled = self._parallel_device.pack(tensors)
+      gen_resource_variable_ops.assign_variable_op(
+          resource=self._handle,
+          # Squeeze out the dummy first axis we added when saving.
+          value=array_ops.squeeze(bundled, axis=0))
+
+
+class VariableWithFixedCheckpointing(resource_variable_ops.ResourceVariable):
+  """Overrides checkpointing behavior to save like a partitioned variable."""
+
+  def __init__(self, parallel_device, **kwargs):
+    self._parallel_device = parallel_device
+    kwargs = {k: v for k, v in kwargs.items()
+              if k not in ["use_resource", "expected_shape"]}
+    super(VariableWithFixedCheckpointing, self).__init__(**kwargs)
+
+  def _gather_saveables_for_checkpoint(self):
+    # Note VARIABLE_VALUE is the usual attribute name for variables. Using
+    # something different means (a) the checkpointing infrastructure won't try
+    # doing restore-on-create (which has shape issues), and (b) the saved
+    # variables won't be compatible with regular variables. Both of those are
+    # good in this case.
+    return dict(
+        PARALLEL_VARIABLE_VALUE=functools.partial(
+            _ParallelDeviceSaveable,
+            handle=self.handle,
+            dtype=self.dtype,
+            component_shape=self.shape,
+            parallel_device=self._parallel_device))
+
+
+def _variable_creator(next_creator, parallel_device, **kwargs):
+  del next_creator
+  return VariableWithFixedCheckpointing(
+      parallel_device=parallel_device, **kwargs)
+
+
+@contextlib.contextmanager
+def independent_buffers(parallel_device):
+  """Context manager which saves parallel buffers independently.
+
+  Creates a ParallelDevice-aware variable subclass which saves buffers for each
+  device separately.
+
+  Args:
+    parallel_device: A ParallelDevice object on which variables are placed.
+
+  Yields:
+    Nothing.
+  """
+  with variable_scope.variable_creator_scope(
+      functools.partial(_variable_creator, parallel_device=parallel_device)):
+    yield
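
A minimal usage sketch of the Python wrapper added by this change, modeled on test_register_parallel_device above. It assumes eager execution and two available CPU devices; the module is internal to tensorflow.python and is not exported as a public tf.* symbol by this patch.

    from tensorflow.python.distribute.parallel_device import parallel_device
    from tensorflow.python.framework import constant_op
    from tensorflow.python.framework import ops

    # Compose two underlying devices into one custom "parallel" device.
    device = parallel_device.ParallelDevice(
        components=["/job:localhost/replica:0/task:0/device:CPU:0",
                    "/job:localhost/replica:0/task:0/device:CPU:1"])

    with ops.device(device.name):
      c = constant_op.constant(1.)  # created on each component device
      d = constant_op.constant(2.)
      e = c + d                     # executed once per component device
    outputs = device.unpack(e)      # one tensor per component, here [3.0, 3.0]

Registration happens inside ParallelDevice.__init__ via _pywrap_parallel_device.GetParallelDeviceCapsules and context.register_custom_device, so no further setup is needed before placing ops on device.name.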
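
The checkpointing path in saving.py can be exercised the same way, following test_checkpointing: under ParallelDevice.scope(), variables are created as VariableWithFixedCheckpointing, so each component buffer is written as its own slice of a single checkpoint entry. A sketch under the same assumptions as above (the checkpoint prefix is only illustrative):

    import os

    from tensorflow.python.eager import context
    from tensorflow.python.framework import constant_op
    from tensorflow.python.ops import variables
    from tensorflow.python.training.tracking import util as tracking

    with device.scope():
      # One initial value per component device.
      v = variables.Variable(device.pack(
          [constant_op.constant(-1.), constant_op.constant(3.)]))
    checkpoint = tracking.Checkpoint(v=v)
    save_path = checkpoint.save(os.path.join("/tmp", "ckpt"))
    context.async_wait()  # make sure the write completed before restoring
    checkpoint.restore(save_path).assert_consumed()
    # device.unpack(v) now yields the restored per-component values.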