Add support for a device ID op in parallel_device

The op doesn't really make sense to register kernels for, so I'm not registering it anywhere by default yet; it's currently just registered in the parallel device tests.

PiperOrigin-RevId: 311141160
Change-Id: Iff1839112dac6fe3406e4b31f0e6f7239809a5bb
Allen Lavoie 2020-05-12 09:31:06 -07:00 committed by TensorFlower Gardener
parent adb282e47c
commit 8e3bc844b1
7 changed files with 186 additions and 16 deletions
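
For context, here is a minimal usage sketch (an editorial illustration, not part of this commit) of what the new op enables through the Python wrapper changed below. The component device names are placeholders and assume two CPU devices (e.g. virtual CPUs) are configured; ParallelDevice, device_ids, and unpack are the APIs shown in the diffs that follow.

    from tensorflow.python.distribute.parallel_device import parallel_device

    # Placeholder component names; assumes two CPU devices are available,
    # as in the tests below.
    device = parallel_device.ParallelDevice(
        components=["/job:localhost/replica:0/task:0/device:CPU:0",
                    "/job:localhost/replica:0/task:0/device:CPU:1"])

    # device_ids is a parallel tensor holding a scalar int64 per component:
    # 0 on the first device, 1 on the second. unpack() returns the
    # per-device tensors.
    ids = device.unpack(device.device_ids)  # values [0, 1]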


@@ -44,6 +44,7 @@ tf_cc_test(
srcs = ["parallel_device_test.cc"],
deps = [
":parallel_device",
":parallel_device_ops",
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_experimental",
"//tensorflow/c/eager:c_api",
@@ -53,3 +54,19 @@ tf_cc_test(
"//tensorflow/core:test_main",
],
)
# Note: ParallelDevice-specific ops are experimental and not currently linked in
# to TensorFlow by default, just used in a few tests.
filegroup(
name = "parallel_device_ops_srcs",
srcs = ["parallel_device_ops.cc"],
visibility = ["//tensorflow/python/distribute/parallel_device:__pkg__"],
)
cc_library(
name = "parallel_device_ops",
srcs = [":parallel_device_ops_srcs"],
visibility = ["//tensorflow:internal"],
deps = ["//tensorflow/core:framework"],
alwayslink = 1,
)


@@ -92,6 +92,10 @@ class ParallelDevice {
TFE_TensorHandle* tensor,
TF_Status* status) const;
// A parallel tensor with scalar integers numbering component devices.
std::unique_ptr<ParallelTensor> DeviceIDs(TFE_Context* context,
TF_Status* status) const;
// Takes a description of a single operation being executed on the
// ParallelDevice, and in turn runs one operation per component device with
// its corresponding inputs from the input ParallelTensors (or
@@ -208,6 +212,46 @@ std::unique_ptr<ParallelTensor> ParallelDevice::CopyToParallelDevice(
status);
}
std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
TFE_Context* context, TF_Status* status) const {
// TODO(allenl): We could cache DeviceIDs (keyed by context).
std::vector<TensorHandlePtr> components;
components.reserve(underlying_devices_.size());
for (int device_index = 0; device_index < underlying_devices_.size();
++device_index) {
int64_t* device_id = new int64_t;
*device_id = device_index;
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
TF_NewTensor(
TF_INT64, /*dims=*/nullptr, /*num_dims=*/0, device_id,
sizeof(int64_t),
[](void* data, size_t, void* arg) {
delete reinterpret_cast<int64_t*>(data);
},
nullptr),
TF_DeleteTensor);
// TODO(allenl): Here and when executing regular operations, we could hold
// on to one TFE_Op per device and just call TFE_ResetOp to avoid parsing
// device names repeatedly.
OpPtr const_op(TFE_NewOp(context, "Const", status));
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetDevice(const_op.get(), underlying_devices_[device_index].c_str(),
status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrTensor(const_op.get(), "value", tensor.get(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrType(const_op.get(), "dtype", TF_INT64);
TFE_TensorHandle* device_handle;
int num_outputs = 1;
TFE_Execute(const_op.get(), &device_handle, &num_outputs, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
components.emplace_back(device_handle);
if (TF_GetCode(status) != TF_OK) return nullptr;
}
return ParallelTensor::FromTensorHandles(*this, std::move(components),
status);
}
absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
const char* operation_name, const TFE_OpAttrs* attributes,
@@ -282,6 +326,13 @@ absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
}
result.emplace(std::move(outputs));
return result;
} else if (operation_name == std::string("DeviceID")) {
std::vector<MaybeParallelTensorOwned> result_content;
result_content.reserve(1);
result_content.push_back(DeviceIDs(context, status));
if (TF_GetCode(status) != TF_OK) return result;
result.emplace(std::move(result_content));
return result;
}
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
maybe_parallel_results(


@@ -0,0 +1,26 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
// TODO(allenl): Figure out if we need this op, and if so whether we should move
// it to core TF. Right now the eager C API does some checking of op
// registrations before calling into custom devices, but we may be able to avoid
// that.
REGISTER_OP("DeviceID")
.Output("device_id: int64")
.SetIsStateful()
.SetShapeFn(tensorflow::shape_inference::ScalarShape);
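
To sketch how this registration is consumed (an illustration, not part of the change; the helper function name is hypothetical): the tf_gen_op_wrapper_py target added below generates gen_parallel_device_ops.device_id(), and since no kernel is registered for the op, it is only useful when placed on the parallel custom device, whose Execute() special-cases "DeviceID" as shown in the C++ diff above.

    from tensorflow.python.distribute.parallel_device import gen_parallel_device_ops
    from tensorflow.python.framework import ops

    def device_ids_for(parallel_device_name):
      # Assumes _parallel_device_ops.so has already been loaded
      # (parallel_device.py below does this via load_library.load_op_library).
      # Placing the op on the custom device routes it to ParallelDevice::Execute,
      # which returns one scalar int64 per component device.
      with ops.device(parallel_device_name):
        return gen_parallel_device_ops.device_id()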


@@ -278,14 +278,15 @@ TensorHandlePtr Multiply(TFE_Context* context, TFE_TensorHandle* first,
}
// Assert that `handle` is equal to `expected_value`.
void AssertScalarFloatEq(TFE_TensorHandle* handle, float expected_value) {
template <typename value_type>
void ExpectScalarEq(TFE_TensorHandle* handle, value_type expected_value) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> value_zero(
TFE_TensorHandleResolve(handle, status.get()), TF_DeleteTensor);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_EQ(expected_value,
*static_cast<float*>(TF_TensorData(value_zero.get())));
EXPECT_EQ(expected_value,
*static_cast<value_type*>(TF_TensorData(value_zero.get())));
}
template <std::size_t num_devices>
@@ -343,8 +344,8 @@ void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
ExtractPerDeviceValues(context, read.get(), &components, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
AssertScalarFloatEq(components[0].get(), 20.);
AssertScalarFloatEq(components[1].get(), 20.);
ExpectScalarEq<float>(components[0].get(), 20.);
ExpectScalarEq<float>(components[1].get(), 20.);
std::string first_device =
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
@@ -373,8 +374,8 @@ void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
ExtractPerDeviceValues(context, read.get(), &components, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
AssertScalarFloatEq(components[0].get(), 23.);
AssertScalarFloatEq(components[1].get(), 18.);
ExpectScalarEq<float>(components[0].get(), 23.);
ExpectScalarEq<float>(components[1].get(), 18.);
std::string first_device =
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
@@ -383,6 +384,32 @@ void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
ASSERT_EQ(underlying_devices[1], second_device);
}
// Compute the device ID twice and verify the result
for (int i = 0; i < 2; ++i) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "DeviceID", status.get()), TFE_DeleteOp);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_OpSetDevice(op.get(), device_name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_TensorHandle* result_handle;
int num_retvals = 1;
TFE_Execute(op.get(), &result_handle, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::array<TensorHandlePtr, 2> components;
ExtractPerDeviceValues(context, result_handle, &components, status.get());
TFE_DeleteTensorHandle(result_handle);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ExpectScalarEq<int64_t>(components[0].get(), 0);
ExpectScalarEq<int64_t>(components[1].get(), 1);
std::string first_device =
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
ASSERT_EQ(underlying_devices[0], first_device);
std::string second_device =
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
ASSERT_EQ(underlying_devices[1], second_device);
}
}
TEST(PARALLEL_DEVICE, TestBasicCPU) {
@@ -498,8 +525,8 @@ TEST(PARALLEL_DEVICE, TestExplicitCopies) {
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// The value of the original tensor is replicated on each device.
AssertScalarFloatEq(components[0].get(), 3.);
AssertScalarFloatEq(components[1].get(), 3.);
ExpectScalarEq<float>(components[0].get(), 3.);
ExpectScalarEq<float>(components[1].get(), 3.);
// Verify that the mirrors are placed on the component devices.
std::string first_device =
@@ -630,7 +657,7 @@ TEST(PARALLEL_DEVICE, TestNestedParallelDevices) {
&second_components, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
AssertScalarFloatEq(second_components[1].get(), 9.);
ExpectScalarEq<float>(second_components[1].get(), 9.);
// Verify that the mirrors are placed on the component devices.
std::string first_device = TFE_TensorHandleBackingDeviceName(
@@ -644,8 +671,8 @@ TEST(PARALLEL_DEVICE, TestNestedParallelDevices) {
std::array<TensorHandlePtr, 2> first_components;
ExtractPerDeviceValues(context.get(), second_components[0].get(),
&first_components, status.get());
AssertScalarFloatEq(first_components[0].get(), 3.);
AssertScalarFloatEq(first_components[1].get(), 6.);
ExpectScalarEq<float>(first_components[0].get(), 3.);
ExpectScalarEq<float>(first_components[1].get(), 6.);
first_device = TFE_TensorHandleBackingDeviceName(first_components[0].get(),
status.get());
@@ -806,8 +833,8 @@ TEST(PARALLEL_DEVICE, TestCollective) {
ExtractPerDeviceValues(context.get(), reduced.get(), &result_components,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
AssertScalarFloatEq(result_components[0].get(), 3.);
AssertScalarFloatEq(result_components[1].get(), 3.);
ExpectScalarEq<float>(result_components[0].get(), 3.);
ExpectScalarEq<float>(result_components[1].get(), 3.);
}
void RegisterCollectiveMulFunction(TFE_Context* context,
@@ -909,8 +936,8 @@ TEST(PARALLEL_DEVICE, TestFunction) {
ExtractPerDeviceValues(context.get(), reduced.get(), &result_components,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
AssertScalarFloatEq(result_components[0].get(), 7. * 9.);
AssertScalarFloatEq(result_components[1].get(), 7. * 9.);
ExpectScalarEq<float>(result_components[0].get(), 7. * 9.);
ExpectScalarEq<float>(result_components[1].get(), 7. * 9.);
std::string first_device = TFE_TensorHandleBackingDeviceName(
result_components[0].get(), status.get());


@@ -1,3 +1,6 @@
load("//tensorflow:tensorflow.bzl", "tf_custom_op_library", "tf_gen_op_wrapper_py")
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
package(
default_visibility = ["//tensorflow:internal"],
licenses = ["notice"], # Apache 2.0
@@ -14,6 +17,7 @@ py_library(
srcs = ["parallel_device.py"],
srcs_version = "PY2AND3",
deps = [
":parallel_device_ops",
":saving",
"//tensorflow/python:_pywrap_parallel_device",
],
@@ -26,6 +30,25 @@ py_library(
deps = ["//tensorflow/python:framework_ops"],
)
tf_gen_op_wrapper_py(
name = "parallel_device_ops_py",
out = "gen_parallel_device_ops.py",
deps = ["//tensorflow/c/eager/parallel_device:parallel_device_ops"],
)
tf_custom_op_library(
name = "_parallel_device_ops.so",
srcs = ["//tensorflow/c/eager/parallel_device:parallel_device_ops_srcs"],
)
tf_custom_op_py_library(
name = "parallel_device_ops",
dso = [":_parallel_device_ops.so"],
kernels = ["//tensorflow/c/eager/parallel_device:parallel_device_ops"],
visibility = ["//tensorflow:internal"],
deps = [":parallel_device_ops_py"],
)
py_test(
name = "parallel_device_test",
srcs = ["parallel_device_test.py"],


@@ -22,11 +22,17 @@ import contextlib
import threading
from tensorflow.python import _pywrap_parallel_device
from tensorflow.python.distribute.parallel_device import gen_parallel_device_ops
from tensorflow.python.distribute.parallel_device import saving
from tensorflow.python.eager import context
from tensorflow.python.framework import load_library
from tensorflow.python.framework import ops
from tensorflow.python.platform import resource_loader
from tensorflow.python.tpu.ops import tpu_ops
load_library.load_op_library(
resource_loader.get_path_to_datafile("_parallel_device_ops.so"))
_next_device_number = 0
_next_device_number_lock = threading.Lock()
@@ -58,6 +64,8 @@ class ParallelDevice(object):
device, device_info = _pywrap_parallel_device.GetParallelDeviceCapsules(
self.name, self.components)
context.register_custom_device(device, self.name, device_info)
with ops.device(self.name):
self._device_ids = gen_parallel_device_ops.device_id()
def pack(self, tensors):
"""Create a tensor on the parallel device from a sequence of tensors.
@@ -84,6 +92,18 @@ class ParallelDevice(object):
return tpu_ops.tpu_replicated_output(
parallel_tensor, num_replicas=len(self.components))
@property
def device_ids(self):
"""A parallel tensor with scalar integers numbering component devices.
Each device ID is placed on its corresponding device, in the same order as
the `components` constructor argument.
Returns:
A parallel tensor containing 0 on the first device, 1 on the second, etc.
"""
return self._device_ids
# TODO(allenl): Fixing saving in Python is a bit odd. One alternative would be
# to provide a hook for the custom device to create save specs/etc., then call
# that hook from the default variable implementation if the variable is on a

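The ordering contract documented in the device_ids property above (the i-th device ID lives on the i-th entry of components) can be checked along these lines; this is an editorial sketch, not part of the commit, and the device names are placeholders assuming two CPU devices are configured.

    from tensorflow.python.distribute.parallel_device import parallel_device

    device = parallel_device.ParallelDevice(
        components=["/job:localhost/replica:0/task:0/device:CPU:0",
                    "/job:localhost/replica:0/task:0/device:CPU:1"])
    ids = device.unpack(device.device_ids)
    for expected, (component, id_tensor) in enumerate(zip(device.components, ids)):
      assert int(id_tensor) == expected             # 0, 1, ... in constructor order
      assert component in id_tensor.backing_device  # each ID sits on its own component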

@@ -119,6 +119,12 @@ class ParallelDeviceTests(_VirtualDeviceTestCase):
self.assertIn(self.device.components[0], outputs[0].backing_device)
self.assertIn(self.device.components[1], outputs[1].backing_device)
def test_device_id(self):
device_ids = self.device.unpack(self.device.device_ids)
self.assertAllClose([0, 1], device_ids)
self.assertIn(self.device.components[0], device_ids[0].backing_device)
self.assertIn(self.device.components[1], device_ids[1].backing_device)
def test_collective_reduce(self):
with ops.device(self.device.name):
x = self.device.pack(