Factor out C++ types from the parallel device

PiperOrigin-RevId: 314807016
Change-Id: I4e41ac3e8a08ea0f1db93826652142a083f17fd1
Allen Lavoie 2020-06-04 14:44:21 -07:00 committed by TensorFlower Gardener
parent 8186228713
commit bfc32671fd
7 changed files with 500 additions and 400 deletions
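The practical effect of the split: the ParallelDevice and ParallelTensor C++ types now live in parallel_device_lib.{h,cc} and can be used directly, while parallel_device.cc keeps only the TFE custom-device glue. A minimal sketch of direct use (not part of this commit; assumes the context exposes two CPU devices, and trims error handling):

#include <memory>
#include <string>
#include <vector>

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"
#include "tensorflow/c/tf_status.h"

namespace tpd = tensorflow::parallel_device;

// Builds a two-CPU ParallelDevice, mirrors `tensor` onto both components, and
// inspects the per-device handles. Everything is cleaned up on return.
void MirrorToTwoCPUs(TFE_Context* context, TFE_TensorHandle* tensor,
                     TF_Status* status) {
  tpd::ParallelDevice device(std::vector<std::string>{
      "/job:localhost/replica:0/task:0/device:CPU:0",
      "/job:localhost/replica:0/task:0/device:CPU:1"});
  std::unique_ptr<tpd::ParallelTensor> mirrored =
      device.CopyToParallelDevice(context, tensor, status);
  if (TF_GetCode(status) != TF_OK) return;
  for (size_t i = 0; i < mirrored->num_tensors(); ++i) {
    // Each component is an ordinary eager handle placed on one underlying device.
    TFE_TensorHandle* component = mirrored->tensor(i);
    (void)component;
  }
}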

File: tensorflow/c/eager/parallel_device/BUILD

@@ -12,28 +12,69 @@ package(
# need a second rule that omits .cc files, in
# tensorflow/python:_pywrap_parallel_device.
filegroup(
- name = "headers",
+ name = "lib_headers",
+ srcs = ["parallel_device_lib.h"],
+ )
+ filegroup(
+ name = "lib_sources",
+ srcs = ["parallel_device_lib.cc"],
+ )
+ filegroup(
+ name = "device_headers",
srcs = ["parallel_device.h"],
+ )
+ filegroup(
+ name = "device_sources",
+ srcs = ["parallel_device.cc"],
+ )
+ filegroup(
+ name = "headers",
+ srcs = [
+ ":device_headers",
+ ":lib_headers",
+ ],
visibility = ["//tensorflow/python:__pkg__"],
)
filegroup(
name = "sources",
- srcs = ["parallel_device.cc"],
+ srcs = [
+ ":device_sources",
+ ":lib_sources",
+ ],
visibility = ["//tensorflow/python:__pkg__"],
)
cc_library(
name = "parallel_device",
- srcs = [":sources"],
- hdrs = [":headers"],
+ srcs = [":device_sources"],
+ hdrs = [":device_headers"],
+ visibility = ["//tensorflow:internal"],
+ deps = [
+ ":parallel_device_lib",
+ "//tensorflow/c:c_api",
+ "//tensorflow/c/eager:c_api",
+ "//tensorflow/c/eager:c_api_experimental",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
+ "@com_google_absl//absl/types:variant",
+ ],
+ )
+ cc_library(
+ name = "parallel_device_lib",
+ srcs = [":lib_sources"],
+ hdrs = [":lib_headers"],
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/c:c_api",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_experimental",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:variant",
],

File: tensorflow/c/eager/parallel_device/parallel_device.cc

@@ -23,25 +23,13 @@ limitations under the License.
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
+ #include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"
#include "tensorflow/c/tf_status.h"
- #include "tensorflow/core/lib/gtl/cleanup.h"
namespace tensorflow {
- namespace eager {
+ namespace parallel_device {
namespace {
- // Functor for making unique_ptrs slightly more ergonomic. Using
- // decltype(delete_fn) in the unique_ptr's second template argument requires
- // passing a function pointer to delete_fn when constructing the unique_ptr.
- class TensorHandleDeleter {
- public:
- void operator()(TFE_TensorHandle* to_delete) const {
- TFE_DeleteTensorHandle(to_delete);
- }
- };
- using TensorHandlePtr = std::unique_ptr<TFE_TensorHandle, TensorHandleDeleter>;
class OpDeleter {
public:
void operator()(TFE_Op* to_delete) const { TFE_DeleteOp(to_delete); }
@ -49,224 +37,43 @@ class OpDeleter {
using OpPtr = std::unique_ptr<TFE_Op, OpDeleter>; using OpPtr = std::unique_ptr<TFE_Op, OpDeleter>;
class ExecutorDeleter {
public:
void operator()(TFE_Executor* to_delete) const {
TFE_DeleteExecutor(to_delete);
}
};
using ExecutorPtr = std::unique_ptr<TFE_Executor, ExecutorDeleter>;
class ParallelTensor;
using MaybeParallelTensorOwned = using MaybeParallelTensorOwned =
absl::variant<std::unique_ptr<ParallelTensor>, TensorHandlePtr>; absl::variant<std::unique_ptr<ParallelTensor>, TensorHandlePtr>;
using MaybeParallelTensorUnowned =
absl::variant<ParallelTensor*, TFE_TensorHandle*>;
// Creates a vector of `count` new executors (threads). // A ParallelDevice on its own is not registered with a TFE_Context, and so has
std::vector<ExecutorPtr> MakeExecutors(size_t count) { // no device name (e.g. for `tf.device`). `NamedParallelDevice` associates a
std::vector<ExecutorPtr> executors; // name with it, which lets us pack its `ParallelTensor`s into TFE_TensorHandles
executors.reserve(count); // placed on the parallel device.
for (int i = 0; i < count; ++i) { class NamedParallelDevice {
executors.emplace_back(TFE_NewExecutor(true /* is_async */));
}
return executors;
}
// A representation of the custom device passed in and out of the TFE custom
// device APIs, providing context about the parallel device to
// ParallelDeviceExecute.
class ParallelDevice {
public: public:
ParallelDevice(const std::string& name, NamedParallelDevice(const std::string& name,
const std::vector<std::string>& devices); std::unique_ptr<ParallelDevice> parallel_device)
: device_name_(name), parallel_device_(std::move(parallel_device)) {}
// Helper to copy a tensor handle from another device once for each component const std::string& name() const { return device_name_; }
// of the ParallelDevice. const ParallelDevice& device() const { return *parallel_device_; }
//
// Sets a bad status and returns a nullptr if `tensor` is already on the
// ParallelDevice, or if the individual copies fail.
std::unique_ptr<ParallelTensor> CopyToParallelDevice(TFE_Context* context,
TFE_TensorHandle* tensor,
TF_Status* status) const;
// A parallel tensor with scalar integers numbering component devices.
std::unique_ptr<ParallelTensor> DeviceIDs(TFE_Context* context,
TF_Status* status) const;
// Takes a description of a single operation being executed on the
// ParallelDevice, and in turn runs one operation per component device with
// its corresponding inputs from the input ParallelTensors (or
// implicitly-mirrored tensors on other devices). Wraps the resulting
// per-device and per-output TFE_TensorHandles into one ParallelTensor per
// output of the original operation.
//
// `inputs` are either ParallelTensors, i.e. already on the ParallelDevice, or
// un-replicated TFE_TensorHandles on other devices. TPUReplicatedInput
// requires non-parallel tensors, and TPUReplicatedOutput requires a parallel
// tensor, but other operations will implicitly broadcast non-parallel input
// tensors across the ParallelDevice's component devices.
//
// Two special-cased operations, TPUReplicatedInput and TPUReplicatedOutput,
// pack and un-pack parallel tensors respectively. Only TPUReplicatedOutput
// causes `Execute` to return non-parallel tensors.
//
// Attributes are forwarded to executed operations unmodified.
//
// The returned optional has a value if and only if `status` evaluates to
// TF_OK.
absl::optional<std::vector<MaybeParallelTensorOwned>> Execute(
TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
const char* operation_name, const TFE_OpAttrs* attributes,
int expected_max_outputs, TF_Status* status) const;
// Implements the parallel case for `Execute`, where all of the outputs of the
// operation are ParallelTensors, and all inputs are either ParallelTensors or
// should be implicitly broadcast. This means the operation is not
// TPUReplicatedInput or TPUReplicatedOutput.
//
// The returned optional has a value if and only if `status` evaluates to
// TF_OK. Bad statuses are forwarded from underlying `TFE_Execute` calls, or
// if sanity checks on dtypes/metadata fail.
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
ExecuteParallelOperation(TFE_Context* context,
std::vector<MaybeParallelTensorUnowned> inputs,
const char* operation_name,
const TFE_OpAttrs* attributes,
int expected_max_outputs, TF_Status* status) const;
- const std::string& device_name() const { return device_name_; }
private:
- // The name of the parallel device
- // (e.g. "/job:localhost/replica:0/task:0/device:CUSTOM:0")
- const std::string device_name_;
- // A sequence of device names, indicating which devices replicated operations
- // are forwarded to.
- const std::vector<std::string> underlying_devices_;
- // A sequence of TFE_Executors, one per device, for executing operations in
- // parallel.
- const std::vector<ExecutorPtr> executors_;
+ std::string device_name_;
+ std::unique_ptr<ParallelDevice> parallel_device_;
};
+ absl::optional<std::vector<MaybeParallelTensorOwned>> ExecuteWithSpecialOps(
+ const ParallelDevice& parallel_device,
+ const std::string& parallel_device_name, TFE_Context* context,
+ std::vector<MaybeParallelTensorUnowned> inputs, const char* operation_name,
+ const TFE_OpAttrs* attributes, int expected_max_outputs,
+ TF_Status* status) {
- // The internal representation of a TFE_TensorHandle placed on a
- // ParallelDevice. Contains a tuple of tensors, one on each of the
- // `underlying_devices_` of the ParallelDevice.
- class ParallelTensor {
- public:
- // Construct a ParallelTensor from TensorHandles placed on the component
- // devices of a ParallelDevice.
- static std::unique_ptr<ParallelTensor> FromTensorHandles(
- const ParallelDevice& parallel_device,
- std::vector<TensorHandlePtr> components, TF_Status* status);
- // Helper to wrap a ParallelTensor into a TFE_TensorHandle which contains it.
- static TensorHandlePtr AsTensorHandle(TFE_Context* context,
std::unique_ptr<ParallelTensor> t,
TF_Status* status);
size_t num_tensors() const { return tensors_.size(); }
TFE_TensorHandle* tensor(size_t index) const { return tensors_[index].get(); }
private:
ParallelTensor(const ParallelDevice& device,
std::vector<TensorHandlePtr> tensors,
std::vector<int64_t> shape, const TF_DataType dtype)
: device_(device),
tensors_(std::move(tensors)),
shape_(std::move(shape)),
dtype_(dtype) {}
const ParallelDevice& device_;
const std::vector<TensorHandlePtr> tensors_;
const std::vector<int64_t> shape_;
const TF_DataType dtype_;
};
ParallelDevice::ParallelDevice(const std::string& name,
const std::vector<std::string>& devices)
: device_name_(name),
underlying_devices_(devices),
executors_(MakeExecutors(underlying_devices_.size())) {}
std::unique_ptr<ParallelTensor> ParallelDevice::CopyToParallelDevice(
TFE_Context* context, TFE_TensorHandle* tensor, TF_Status* status) const {
const char* current_device = TFE_TensorHandleDeviceName(tensor, status);
if (device_name_ == current_device) {
std::string message(absl::StrCat(
"Tried to copy a TensorHandle to its existing device: ", device_name_));
TF_SetStatus(status, TF_INTERNAL, message.c_str());
return nullptr;
}
std::vector<TensorHandlePtr> components;
components.reserve(underlying_devices_.size());
for (const std::string& underlying_device_name : underlying_devices_) {
TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
tensor, context, underlying_device_name.c_str(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
components.emplace_back(t);
}
return ParallelTensor::FromTensorHandles(*this, std::move(components),
status);
}
std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
TFE_Context* context, TF_Status* status) const {
// TODO(allenl): We could cache DeviceIDs (keyed by context).
std::vector<TensorHandlePtr> components;
components.reserve(underlying_devices_.size());
for (int device_index = 0; device_index < underlying_devices_.size();
++device_index) {
int64_t* device_id = new int64_t;
*device_id = device_index;
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
TF_NewTensor(
TF_INT64, /*dims=*/nullptr, /*num_dims=*/0, device_id,
sizeof(int64_t),
[](void* data, size_t, void* arg) {
delete reinterpret_cast<int64_t*>(data);
},
nullptr),
TF_DeleteTensor);
// TODO(allenl): Here and when executing regular operations, we could hold
// on to one TFE_Op per device and just call TFE_ResetOp to avoid parsing
// device names repeatedly.
OpPtr const_op(TFE_NewOp(context, "Const", status));
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetDevice(const_op.get(), underlying_devices_[device_index].c_str(),
status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrTensor(const_op.get(), "value", tensor.get(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrType(const_op.get(), "dtype", TF_INT64);
TFE_TensorHandle* device_handle;
int num_outputs = 1;
TFE_Execute(const_op.get(), &device_handle, &num_outputs, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
components.emplace_back(device_handle);
if (TF_GetCode(status) != TF_OK) return nullptr;
}
return ParallelTensor::FromTensorHandles(*this, std::move(components),
status);
}
absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
const char* operation_name, const TFE_OpAttrs* attributes,
int expected_max_outputs, TF_Status* status) const {
absl::optional<std::vector<MaybeParallelTensorOwned>> result;
// TODO(allenl): We should remove "TPU" from these op names at the very least,
// or consider other ways of packing/unpacking parallel tensors.
if (operation_name == std::string("TPUReplicatedInput")) {
// Special-cased operation for packing per-device tensors into one parallel
// tensor.
- if (inputs.size() != underlying_devices_.size()) {
+ if (inputs.size() != parallel_device.num_underlying_devices()) {
std::string message(absl::StrCat(
- "The parallel device ", device_name_, " expected ",
- underlying_devices_.size(), " inputs to TPUReplicatedInput, but got ",
- inputs.size()));
+ "The parallel device ", parallel_device_name, " expected ",
+ parallel_device.num_underlying_devices(),
+ " inputs to TPUReplicatedInput, but got ", inputs.size()));
TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str());
return result;
}
@@ -289,7 +96,7 @@ absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
std::vector<MaybeParallelTensorOwned> result_content;
result_content.reserve(1);
result_content.push_back(ParallelTensor::FromTensorHandles(
- *this, std::move(components), status));
+ parallel_device, std::move(components), status));
if (TF_GetCode(status) != TF_OK) return result;
result.emplace(std::move(result_content));
return result;
@@ -300,10 +107,10 @@ absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
TFE_OpAddAttrs(op.get(), attributes);
int expected_outputs = TFE_OpGetOutputLength(op.get(), "outputs", status);
if (TF_GetCode(status) != TF_OK) return result;
- if (expected_outputs != underlying_devices_.size()) {
+ if (expected_outputs != parallel_device.num_underlying_devices()) {
std::string message(absl::StrCat(
- "The parallel device ", device_name_, " expected ",
- underlying_devices_.size(),
+ "The parallel device ", parallel_device_name, " expected ",
+ parallel_device.num_underlying_devices(),
" outputs for TPUReplicatedOutput, but got ", expected_outputs));
TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str());
return result;
@@ -329,14 +136,14 @@ absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
} else if (operation_name == std::string("DeviceID")) {
std::vector<MaybeParallelTensorOwned> result_content;
result_content.reserve(1);
- result_content.push_back(DeviceIDs(context, status));
+ result_content.push_back(parallel_device.DeviceIDs(context, status));
if (TF_GetCode(status) != TF_OK) return result;
result.emplace(std::move(result_content));
return result;
}
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
maybe_parallel_results(
- ExecuteParallelOperation(context, std::move(inputs), operation_name,
+ parallel_device.Execute(context, std::move(inputs), operation_name,
attributes, expected_max_outputs, status));
if (!maybe_parallel_results.has_value()) return result;
std::vector<std::unique_ptr<ParallelTensor>> parallel_results(
@@ -351,153 +158,6 @@ absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
return result;
}
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
ParallelDevice::ExecuteParallelOperation(
TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
const char* operation_name, const TFE_OpAttrs* attributes,
int expected_max_outputs, TF_Status* status) const {
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> result;
// Compute per-device per-output tensors
std::vector<std::vector<TensorHandlePtr>> per_device_output_tensors;
per_device_output_tensors.reserve(underlying_devices_.size());
// TODO(allenl): Add a TFE_ExecuteWithExecutor API so we don't have to keep
// setting the thread-local executor like this.
TFE_Executor* previous_executor(TFE_ContextGetExecutorForThread(context));
auto reset_executor = gtl::MakeCleanup([context, previous_executor]() {
TFE_ContextSetExecutorForThread(context, previous_executor);
TFE_DeleteExecutor(previous_executor);
});
int first_op_output_count;
for (int device_index = 0; device_index < underlying_devices_.size();
++device_index) {
TFE_Executor* executor = executors_[device_index].get();
// Note that the `reset_executor` cleanup sets the thread's executor back to
// the value before this function ran.
TFE_ContextSetExecutorForThread(context, executor);
OpPtr op(TFE_NewOp(context, operation_name, status));
if (TF_GetCode(status) != TF_OK) return result;
TFE_OpSetDevice(op.get(), underlying_devices_[device_index].c_str(),
status);
TFE_OpAddAttrs(op.get(), attributes);
for (int input_index = 0; input_index < inputs.size(); ++input_index) {
if (absl::holds_alternative<TFE_TensorHandle*>(inputs[input_index])) {
// Non-parallel tensors are implicitly broadcast, i.e. set as the input
// to each parallel operation.
//
// TODO(allenl): There may be smarter ways to do this copy in some
// cases, i.e. with a collective broadcast. We'll need to be careful
// about things that are taken as inputs on the host or on their
// existing device (for multi-device functions).
TFE_OpAddInput(op.get(),
absl::get<TFE_TensorHandle*>(inputs[input_index]),
status);
if (TF_GetCode(status) != TF_OK) return result;
} else {
// Parallel tensors are divided between operations by device.
TFE_OpAddInput(op.get(),
absl::get<ParallelTensor*>(inputs[input_index])
->tensor(device_index),
status);
if (TF_GetCode(status) != TF_OK) return result;
}
}
std::vector<TFE_TensorHandle*> op_outputs(expected_max_outputs);
int real_num_outputs = expected_max_outputs;
// For nested devices, the inner device sees the async executor we've
// set. Inner parallel devices will just overwrite this with their own and
// then set it back to ours before returning. This means parallel devices
// which consist of several aliased parallel devices would hypothetically
// deadlock if the outer parallel device ran one collective with a group
// size equal to the total number of aliased physical devices. Currently
// physical devices cannot participate in a single collective reduction
// multiple times, so this would fail earlier.
//
// TODO(allenl): Keep a map from outer executor to list of inner executors
// rather than a single list of executors so aliased nested parallel devices
// don't re-use an executor.
TFE_Execute(op.get(), op_outputs.data(), &real_num_outputs, status);
if (device_index == 0) {
first_op_output_count = real_num_outputs;
} else {
if (real_num_outputs != first_op_output_count) {
TF_SetStatus(status, TF_INTERNAL,
"Parallel ops produced different numbers of tensors.");
return result;
}
}
if (TF_GetCode(status) != TF_OK) return result;
std::vector<TensorHandlePtr> this_outputs;
this_outputs.reserve(real_num_outputs);
for (int output_num = 0; output_num < real_num_outputs; ++output_num) {
this_outputs.emplace_back(op_outputs[output_num]);
}
per_device_output_tensors.push_back(std::move(this_outputs));
}
for (int device_index = 0; device_index < underlying_devices_.size();
++device_index) {
TFE_Executor* executor = executors_[device_index].get();
// TODO(b/157523095): Syncing the executor here shouldn't be
// necessary. Currently async+remote is missing cross-executor
// coordination.
TFE_ExecutorWaitForAllPendingNodes(executor, status);
if (TF_GetCode(status) != TF_OK) return result;
}
// For each output of the original operation, pack the per-device
// TensorHandles we've computed into a single parallel TensorHandle.
std::vector<std::unique_ptr<ParallelTensor>> per_device_outputs;
per_device_outputs.reserve(first_op_output_count);
for (int i = 0; i < first_op_output_count; ++i) {
std::vector<TensorHandlePtr> components;
components.reserve(underlying_devices_.size());
for (int j = 0; j < underlying_devices_.size(); ++j) {
components.push_back(std::move(per_device_output_tensors[j][i]));
}
per_device_outputs.push_back(ParallelTensor::FromTensorHandles(
*this, std::move(components), status));
if (TF_GetCode(status) != TF_OK) return result;
}
result.emplace(std::move(per_device_outputs));
return result;
}
std::unique_ptr<ParallelTensor> ParallelTensor::FromTensorHandles(
const ParallelDevice& parallel_device,
std::vector<TensorHandlePtr> components, TF_Status* status) {
TF_DataType dtype = TFE_TensorHandleDataType(components[0].get());
std::vector<int64_t> shape(
TFE_TensorHandleNumDims(components[0].get(), status));
if (TF_GetCode(status) != TF_OK) return nullptr;
for (int i = 0; i < shape.size(); ++i) {
shape[i] = TFE_TensorHandleDim(components[0].get(), i, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
}
// Verify that the TensorHandle's shape and dtype match all of the component
// shapes and dtypes.
for (TensorHandlePtr& component : components) {
for (int i = 0; i < shape.size(); ++i) {
int64_t tensor_dim = TFE_TensorHandleDim(component.get(), i, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
if (tensor_dim != shape[i]) {
// TODO(allenl): Allow shapes to differ.
TF_SetStatus(status, TF_UNIMPLEMENTED,
"Components of a ParallelTensor must currently all have "
"the same shape");
return nullptr;
}
if (TFE_TensorHandleDataType(component.get()) != dtype) {
TF_SetStatus(status, TF_INTERNAL,
"Components of a ParallelTensor must all have "
"the same dtype");
return nullptr;
}
}
}
return std::unique_ptr<ParallelTensor>(new ParallelTensor(
parallel_device, std::move(components), std::move(shape), dtype));
}
// Used as an argument to TFE_NewTensorHandleFromDeviceMemory, indicating how
// ParallelTensors wrapped in TFE_TensorHandles should be cleaned up once their
// reference counts drop to zero.
@@ -505,17 +165,18 @@ void ParallelTensorDeallocator(void* data, size_t len, void* arg) {
delete reinterpret_cast<ParallelTensor*>(data);
}
- TensorHandlePtr ParallelTensor::AsTensorHandle(
- TFE_Context* context, std::unique_ptr<ParallelTensor> t,
- TF_Status* status) {
+ TensorHandlePtr ParallelTensorToTensorHandle(
+ const std::string& parallel_device_name, TFE_Context* context,
+ std::unique_ptr<ParallelTensor> t, TF_Status* status) {
// The resulting TensorHandle owns an opaque pointer to "device memory", which
// for a ParallelDevice is really a ParallelTensor. When the TensorHandle is
// deleted, it will call ParallelTensorDeallocator to free the struct.
ParallelTensor* t_released = t.release();
+ const std::vector<int64_t>& shape(t_released->shape());
return TensorHandlePtr(TFE_NewTensorHandleFromDeviceMemory(
- context, t_released->device_.device_name().c_str(), t_released->dtype_,
- t_released->shape_.data(), t_released->shape_.size(), t_released, 1,
- &ParallelTensorDeallocator, nullptr, status));
+ context, parallel_device_name.c_str(), t_released->dtype(), shape.data(),
+ shape.size(), t_released, 1, &ParallelTensorDeallocator, nullptr,
+ status));
}
// For TFE_CustomDevice::copy_tensor_to_device in the parallel device // For TFE_CustomDevice::copy_tensor_to_device in the parallel device
@@ -531,12 +192,14 @@ TensorHandlePtr ParallelTensor::AsTensorHandle(
TFE_TensorHandle* CopyToParallelDevice(TFE_Context* context,
TFE_TensorHandle* tensor,
TF_Status* status, void* device_info) {
- ParallelDevice* dev = reinterpret_cast<ParallelDevice*>(device_info);
+ NamedParallelDevice* named_device =
+ reinterpret_cast<NamedParallelDevice*>(device_info);
+ const ParallelDevice& dev = named_device->device();
std::unique_ptr<ParallelTensor> parallel_tensor(
- dev->CopyToParallelDevice(context, tensor, status));
+ dev.CopyToParallelDevice(context, tensor, status));
if (TF_GetCode(status) != TF_OK) return nullptr;
- return ParallelTensor::AsTensorHandle(context, std::move(parallel_tensor),
- status)
+ return ParallelTensorToTensorHandle(named_device->name(), context,
+ std::move(parallel_tensor), status)
.release();
}
@@ -570,14 +233,15 @@ void ParallelDeviceExecute(TFE_Context* context, int num_inputs,
const TFE_OpAttrs* attributes, int* num_outputs,
TFE_TensorHandle** outputs, TF_Status* status,
void* device_info) {
- ParallelDevice* dev = reinterpret_cast<ParallelDevice*>(device_info);
+ NamedParallelDevice* named_device =
+ reinterpret_cast<NamedParallelDevice*>(device_info);
std::vector<MaybeParallelTensorUnowned> typed_inputs;
typed_inputs.reserve(num_inputs);
for (int i = 0; i < num_inputs; ++i) {
const char* tensor_handle_device =
TFE_TensorHandleDeviceName(inputs[i], status);
if (TF_GetCode(status) != TF_OK) return;
- if (dev->device_name() == tensor_handle_device) {
+ if (named_device->name() == tensor_handle_device) {
// We assume that any tensors already placed on this device are
// ParallelTensors.
typed_inputs.emplace_back(reinterpret_cast<ParallelTensor*>(
@@ -589,8 +253,9 @@ void ParallelDeviceExecute(TFE_Context* context, int num_inputs,
}
absl::optional<std::vector<MaybeParallelTensorOwned>> maybe_typed_outputs(
- dev->Execute(context, std::move(typed_inputs), operation_name, attributes,
- *num_outputs, status));
+ ExecuteWithSpecialOps(named_device->device(), named_device->name(),
+ context, std::move(typed_inputs), operation_name,
+ attributes, *num_outputs, status));
if (TF_GetCode(status) != TF_OK) return;
if (!maybe_typed_outputs.has_value()) {
TF_SetStatus(status, TF_INTERNAL, "OK status but no value was returned.");
@@ -611,8 +276,8 @@ void ParallelDeviceExecute(TFE_Context* context, int num_inputs,
if (absl::holds_alternative<TensorHandlePtr>(typed_output)) {
outputs[i] = absl::get<TensorHandlePtr>(typed_output).release();
} else {
- outputs[i] = ParallelTensor::AsTensorHandle(
- context,
+ outputs[i] = ParallelTensorToTensorHandle(
+ named_device->name(), context,
std::move(absl::get<std::unique_ptr<ParallelTensor>>(
typed_output)),
status)
@@ -629,7 +294,7 @@ void ParallelDeviceExecute(TFE_Context* context, int num_inputs,
// device_info is passed in using a C-style generic. It must always be a
// ParallelDevice.
void DeleteParallelDevice(void* device_info) {
- delete reinterpret_cast<ParallelDevice*>(device_info);
+ delete reinterpret_cast<NamedParallelDevice*>(device_info);
}
} // namespace
@@ -648,8 +313,10 @@ void AllocateParallelDevice(const char* device_name,
++device_index) {
underlying_devices_vector.push_back(underlying_devices[device_index]);
}
- *device_info = new ParallelDevice(device_name, underlying_devices_vector);
+ std::unique_ptr<ParallelDevice> parallel_device(
+ new ParallelDevice(underlying_devices_vector));
+ *device_info =
+ new NamedParallelDevice{device_name, std::move(parallel_device)};
}
- } // namespace eager
+ } // namespace parallel_device
} // namespace tensorflow
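For orientation, this is roughly how client code exercises the special-cased packing op once the parallel device has been registered (not part of this commit; the device name is a placeholder, and the two inputs are assumed to live on the two underlying devices):

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/tf_status.h"

// Packs two per-device handles into a single handle placed on the parallel
// device by running TPUReplicatedInput on that device, which routes the call
// to ParallelDeviceExecute above.
TFE_TensorHandle* PackOnParallelDevice(TFE_Context* context,
                                       TFE_TensorHandle* first,
                                       TFE_TensorHandle* second,
                                       const char* parallel_device_name,
                                       TF_Status* status) {
  TFE_Op* op = TFE_NewOp(context, "TPUReplicatedInput", status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  TFE_OpSetAttrInt(op, "N", 2);
  TFE_OpAddInput(op, first, status);
  if (TF_GetCode(status) == TF_OK) TFE_OpAddInput(op, second, status);
  if (TF_GetCode(status) == TF_OK) {
    // Placing the op on the custom device hands execution to the parallel device.
    TFE_OpSetDevice(op, parallel_device_name, status);
  }
  TFE_TensorHandle* packed = nullptr;
  int num_outputs = 1;
  if (TF_GetCode(status) == TF_OK) TFE_Execute(op, &packed, &num_outputs, status);
  TFE_DeleteOp(op);
  return TF_GetCode(status) == TF_OK ? packed : nullptr;
}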

File: tensorflow/c/eager/parallel_device/parallel_device.h

@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_experimental.h"
namespace tensorflow {
- namespace eager {
+ namespace parallel_device {
// Allocate a parallel device named `device_name` which forwards operations to
// `underlying_devices`, maintaining "parallel tensors" with components placed
@@ -59,7 +59,7 @@ void AllocateParallelDevice(const char* device_name,
int num_underlying_devices,
TFE_CustomDevice* device, void** device_info);
- } // namespace eager
+ } // namespace parallel_device
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_H_
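A sketch of how this allocation entry point pairs with TFE_RegisterCustomDevice, mirroring the registration helpers later in this commit (the device names here are placeholders):

#include <vector>

#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/parallel_device/parallel_device.h"

void RegisterTwoDeviceParallelDevice(TFE_Context* context, TF_Status* status) {
  const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  std::vector<const char*> underlying_devices{
      "/job:localhost/replica:0/task:0/device:CPU:0",
      "/job:localhost/replica:0/task:0/device:CPU:1"};
  TFE_CustomDevice device;
  void* device_info;
  tensorflow::parallel_device::AllocateParallelDevice(
      device_name, underlying_devices.data(), underlying_devices.size(),
      &device, &device_info);
  // Registration hands `device_info` (a NamedParallelDevice) to the context;
  // it is freed through DeleteParallelDevice when the custom device is torn down.
  TFE_RegisterCustomDevice(context, device, device_name, device_info, status);
}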

File: tensorflow/c/eager/parallel_device/parallel_device_lib.cc

@@ -0,0 +1,251 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
namespace tensorflow {
namespace parallel_device {
namespace {
class OpDeleter {
public:
void operator()(TFE_Op* to_delete) const { TFE_DeleteOp(to_delete); }
};
using OpPtr = std::unique_ptr<TFE_Op, OpDeleter>;
// Creates a vector of `count` new executors (threads).
std::vector<ExecutorPtr> MakeExecutors(size_t count) {
std::vector<ExecutorPtr> executors;
executors.reserve(count);
for (int i = 0; i < count; ++i) {
executors.emplace_back(TFE_NewExecutor(true /* is_async */));
}
return executors;
}
} // namespace
ParallelDevice::ParallelDevice(const std::vector<std::string>& devices)
: underlying_devices_(devices),
executors_(MakeExecutors(underlying_devices_.size())) {}
std::unique_ptr<ParallelTensor> ParallelDevice::CopyToParallelDevice(
TFE_Context* context, TFE_TensorHandle* tensor, TF_Status* status) const {
std::vector<TensorHandlePtr> components;
components.reserve(underlying_devices_.size());
for (const std::string& underlying_device_name : underlying_devices_) {
TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
tensor, context, underlying_device_name.c_str(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
components.emplace_back(t);
}
return ParallelTensor::FromTensorHandles(*this, std::move(components),
status);
}
std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
TFE_Context* context, TF_Status* status) const {
// TODO(allenl): We could cache DeviceIDs (keyed by context).
std::vector<TensorHandlePtr> components;
components.reserve(underlying_devices_.size());
for (int device_index = 0; device_index < underlying_devices_.size();
++device_index) {
int64_t* device_id = new int64_t;
*device_id = device_index;
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
TF_NewTensor(
TF_INT64, /*dims=*/nullptr, /*num_dims=*/0, device_id,
sizeof(int64_t),
[](void* data, size_t, void* arg) {
delete reinterpret_cast<int64_t*>(data);
},
nullptr),
TF_DeleteTensor);
// TODO(allenl): Here and when executing regular operations, we could hold
// on to one TFE_Op per device and just call TFE_ResetOp to avoid parsing
// device names repeatedly.
OpPtr const_op(TFE_NewOp(context, "Const", status));
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetDevice(const_op.get(), underlying_devices_[device_index].c_str(),
status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrTensor(const_op.get(), "value", tensor.get(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrType(const_op.get(), "dtype", TF_INT64);
TFE_TensorHandle* device_handle;
int num_outputs = 1;
TFE_Execute(const_op.get(), &device_handle, &num_outputs, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
components.emplace_back(device_handle);
if (TF_GetCode(status) != TF_OK) return nullptr;
}
return ParallelTensor::FromTensorHandles(*this, std::move(components),
status);
}
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
ParallelDevice::Execute(TFE_Context* context,
std::vector<MaybeParallelTensorUnowned> inputs,
const char* operation_name,
const TFE_OpAttrs* attributes, int expected_max_outputs,
TF_Status* status) const {
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> result;
// Compute per-device per-output tensors
std::vector<std::vector<TensorHandlePtr>> per_device_output_tensors;
per_device_output_tensors.reserve(underlying_devices_.size());
// TODO(allenl): Add a TFE_ExecuteWithExecutor API so we don't have to keep
// setting the thread-local executor like this.
TFE_Executor* previous_executor(TFE_ContextGetExecutorForThread(context));
auto reset_executor =
tensorflow::gtl::MakeCleanup([context, previous_executor]() {
TFE_ContextSetExecutorForThread(context, previous_executor);
TFE_DeleteExecutor(previous_executor);
});
int first_op_output_count;
for (int device_index = 0; device_index < underlying_devices_.size();
++device_index) {
TFE_Executor* executor = executors_[device_index].get();
// Note that the `reset_executor` cleanup sets the thread's executor back to
// the value before this function ran.
TFE_ContextSetExecutorForThread(context, executor);
OpPtr op(TFE_NewOp(context, operation_name, status));
if (TF_GetCode(status) != TF_OK) return result;
TFE_OpSetDevice(op.get(), underlying_devices_[device_index].c_str(),
status);
TFE_OpAddAttrs(op.get(), attributes);
for (int input_index = 0; input_index < inputs.size(); ++input_index) {
if (absl::holds_alternative<TFE_TensorHandle*>(inputs[input_index])) {
// Non-parallel tensors are implicitly broadcast, i.e. set as the input
// to each parallel operation.
//
// TODO(allenl): There may be smarter ways to do this copy in some
// cases, i.e. with a collective broadcast. We'll need to be careful
// about things that are taken as inputs on the host or on their
// existing device (for multi-device functions).
TFE_OpAddInput(op.get(),
absl::get<TFE_TensorHandle*>(inputs[input_index]),
status);
if (TF_GetCode(status) != TF_OK) return result;
} else {
// Parallel tensors are divided between operations by device.
TFE_OpAddInput(op.get(),
absl::get<ParallelTensor*>(inputs[input_index])
->tensor(device_index),
status);
if (TF_GetCode(status) != TF_OK) return result;
}
}
std::vector<TFE_TensorHandle*> op_outputs(expected_max_outputs);
int real_num_outputs = expected_max_outputs;
// For nested devices, the inner device sees the async executor we've
// set. Inner parallel devices will just overwrite this with their own and
// then set it back to ours before returning. This means parallel devices
// which consist of several aliased parallel devices would hypothetically
// deadlock if the outer parallel device ran one collective with a group
// size equal to the total number of aliased physical devices. Currently
// physical devices cannot participate in a single collective reduction
// multiple times, so this would fail earlier.
//
// TODO(allenl): Keep a map from outer executor to list of inner executors
// rather than a single list of executors so aliased nested parallel devices
// don't re-use an executor.
TFE_Execute(op.get(), op_outputs.data(), &real_num_outputs, status);
if (device_index == 0) {
first_op_output_count = real_num_outputs;
} else {
if (real_num_outputs != first_op_output_count) {
TF_SetStatus(status, TF_INTERNAL,
"Parallel ops produced different numbers of tensors.");
return result;
}
}
if (TF_GetCode(status) != TF_OK) return result;
std::vector<TensorHandlePtr> this_outputs;
this_outputs.reserve(real_num_outputs);
for (int output_num = 0; output_num < real_num_outputs; ++output_num) {
this_outputs.emplace_back(op_outputs[output_num]);
}
per_device_output_tensors.push_back(std::move(this_outputs));
}
for (int device_index = 0; device_index < underlying_devices_.size();
++device_index) {
TFE_Executor* executor = executors_[device_index].get();
// TODO(b/157523095): Syncing the executor here shouldn't be
// necessary. Currently async+remote is missing cross-executor
// coordination.
TFE_ExecutorWaitForAllPendingNodes(executor, status);
if (TF_GetCode(status) != TF_OK) return result;
}
// For each output of the original operation, pack the per-device
// TensorHandles we've computed into a single parallel TensorHandle.
std::vector<std::unique_ptr<ParallelTensor>> per_device_outputs;
per_device_outputs.reserve(first_op_output_count);
for (int i = 0; i < first_op_output_count; ++i) {
std::vector<TensorHandlePtr> components;
components.reserve(underlying_devices_.size());
for (int j = 0; j < underlying_devices_.size(); ++j) {
components.push_back(std::move(per_device_output_tensors[j][i]));
}
per_device_outputs.push_back(ParallelTensor::FromTensorHandles(
*this, std::move(components), status));
if (TF_GetCode(status) != TF_OK) return result;
}
result.emplace(std::move(per_device_outputs));
return result;
}
std::unique_ptr<ParallelTensor> ParallelTensor::FromTensorHandles(
const ParallelDevice& parallel_device,
std::vector<TensorHandlePtr> components, TF_Status* status) {
TF_DataType dtype = TFE_TensorHandleDataType(components[0].get());
std::vector<int64_t> shape(
TFE_TensorHandleNumDims(components[0].get(), status));
if (TF_GetCode(status) != TF_OK) return nullptr;
for (int i = 0; i < shape.size(); ++i) {
shape[i] = TFE_TensorHandleDim(components[0].get(), i, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
}
// Verify that the TensorHandle's shape and dtype match all of the component
// shapes and dtypes.
for (TensorHandlePtr& component : components) {
for (int i = 0; i < shape.size(); ++i) {
int64_t tensor_dim = TFE_TensorHandleDim(component.get(), i, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
if (tensor_dim != shape[i]) {
// TODO(allenl): Allow shapes to differ.
TF_SetStatus(status, TF_UNIMPLEMENTED,
"Components of a ParallelTensor must currently all have "
"the same shape");
return nullptr;
}
if (TFE_TensorHandleDataType(component.get()) != dtype) {
TF_SetStatus(status, TF_INTERNAL,
"Components of a ParallelTensor must all have "
"the same dtype");
return nullptr;
}
}
}
return std::unique_ptr<ParallelTensor>(new ParallelTensor(
parallel_device, std::move(components), std::move(shape), dtype));
}
} // namespace parallel_device
} // namespace tensorflow
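A rough sketch of driving ParallelDevice::Execute directly, without the custom-device wrapper (not part of this commit; assumes a context with two CPU devices, an int64 `scalar` handle, and the experimental TFE_OpGetAttrs accessor as one way to obtain a TFE_OpAttrs to forward):

#include "absl/types/optional.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"

using tensorflow::parallel_device::MaybeParallelTensorUnowned;
using tensorflow::parallel_device::ParallelDevice;
using tensorflow::parallel_device::ParallelTensor;

// Runs one "AddV2" per component device: the per-device IDs are split across
// devices, while `scalar` is implicitly broadcast to every component op.
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> AddScalarToDeviceIDs(
    TFE_Context* context, const ParallelDevice& device,
    TFE_TensorHandle* scalar, TF_Status* status) {
  std::unique_ptr<ParallelTensor> ids = device.DeviceIDs(context, status);
  if (TF_GetCode(status) != TF_OK) return absl::nullopt;
  std::vector<MaybeParallelTensorUnowned> inputs;
  inputs.emplace_back(ids.get());
  inputs.emplace_back(scalar);
  // A throwaway op is one way to obtain a TFE_OpAttrs to forward; "T" is then
  // inferred from the inputs added to each component op.
  TFE_Op* attr_source = TFE_NewOp(context, "AddV2", status);
  if (TF_GetCode(status) != TF_OK) return absl::nullopt;
  auto result = device.Execute(context, std::move(inputs), "AddV2",
                               TFE_OpGetAttrs(attr_source),
                               /*expected_max_outputs=*/1, status);
  TFE_DeleteOp(attr_source);
  return result;
}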

File: tensorflow/c/eager/parallel_device/parallel_device_lib.h

@@ -0,0 +1,141 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_LIB_H_
#define TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_LIB_H_
#include <memory>
#include <string>
#include <vector>
#include "absl/types/optional.h"
#include "absl/types/variant.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
namespace tensorflow {
namespace parallel_device {
// Functor for making unique_ptrs slightly more ergonomic. Using
// decltype(delete_fn) in the unique_ptr's second template argument requires
// passing a function pointer to delete_fn when constructing the unique_ptr.
class TensorHandleDeleter {
public:
void operator()(TFE_TensorHandle* to_delete) const {
TFE_DeleteTensorHandle(to_delete);
}
};
using TensorHandlePtr = std::unique_ptr<TFE_TensorHandle, TensorHandleDeleter>;
class ExecutorDeleter {
public:
void operator()(TFE_Executor* to_delete) const {
TFE_DeleteExecutor(to_delete);
}
};
using ExecutorPtr = std::unique_ptr<TFE_Executor, ExecutorDeleter>;
class ParallelTensor;
using MaybeParallelTensorUnowned =
absl::variant<ParallelTensor*, TFE_TensorHandle*>;
// Forwards operations to `devices`, maintaining ParallelTensor with components
// placed on each underlying device.
class ParallelDevice {
public:
explicit ParallelDevice(const std::vector<std::string>& devices);
// Helper to copy a tensor handle from another device once for each component
// of the ParallelDevice.
//
// Sets a bad status and returns a nullptr if `tensor` is already on the
// ParallelDevice, or if the individual copies fail.
std::unique_ptr<ParallelTensor> CopyToParallelDevice(TFE_Context* context,
TFE_TensorHandle* tensor,
TF_Status* status) const;
// A parallel tensor with scalar integers numbering component devices.
std::unique_ptr<ParallelTensor> DeviceIDs(TFE_Context* context,
TF_Status* status) const;
// The number of devices operations run on.
size_t num_underlying_devices() const { return underlying_devices_.size(); }
// Takes a description of a single operation being executed on the
// ParallelDevice, and in turn runs one operation per component device with
// its corresponding inputs from the input ParallelTensors (or
// implicitly-mirrored tensors on other devices). Wraps the resulting
// per-device and per-output TFE_TensorHandles into one ParallelTensor per
// output of the original operation.
//
// Attributes are forwarded to executed operations unmodified.
//
// The returned optional has a value if and only if `status` evaluates to
// TF_OK. Bad statuses are forwarded from underlying `TFE_Execute` calls, or
// if sanity checks on dtypes/metadata fail.
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> Execute(
TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
const char* operation_name, const TFE_OpAttrs* attributes,
int expected_max_outputs, TF_Status* status) const;
private:
// A sequence of device names, indicating which devices replicated operations
// are forwarded to.
const std::vector<std::string> underlying_devices_;
// A sequence of TFE_Executors, one per device, for executing operations in
// parallel.
const std::vector<ExecutorPtr> executors_;
};
// Contains a tuple of tensors, one on each of the `underlying_devices_` of the
// ParallelDevice.
class ParallelTensor {
public:
// Construct a ParallelTensor from TensorHandles placed on the component
// devices of a ParallelDevice.
static std::unique_ptr<ParallelTensor> FromTensorHandles(
const ParallelDevice& parallel_device,
std::vector<TensorHandlePtr> components, TF_Status* status);
size_t num_tensors() const { return tensors_.size(); }
TFE_TensorHandle* tensor(size_t index) const { return tensors_[index].get(); }
// A generalization of the shapes of the underlying tensors.
const std::vector<int64_t>& shape() const { return shape_; }
TF_DataType dtype() const { return dtype_; }
private:
ParallelTensor(const ParallelDevice& device,
std::vector<TensorHandlePtr> tensors,
std::vector<int64_t> shape, const TF_DataType dtype)
: device_(device),
tensors_(std::move(tensors)),
shape_(std::move(shape)),
dtype_(dtype) {}
const ParallelDevice& device_;
const std::vector<TensorHandlePtr> tensors_;
const std::vector<int64_t> shape_;
const TF_DataType dtype_;
};
} // namespace parallel_device
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_LIB_H_
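For completeness, a small sketch of assembling a ParallelTensor by hand from component handles already placed on the underlying devices (not part of this commit; both handles are assumed to have matching shape and dtype, which FromTensorHandles verifies):

#include <cstdint>
#include <memory>
#include <vector>

#include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"

using tensorflow::parallel_device::ParallelDevice;
using tensorflow::parallel_device::ParallelTensor;
using tensorflow::parallel_device::TensorHandlePtr;

std::unique_ptr<ParallelTensor> WrapComponents(const ParallelDevice& device,
                                               TFE_TensorHandle* on_first_device,
                                               TFE_TensorHandle* on_second_device,
                                               TF_Status* status) {
  std::vector<TensorHandlePtr> components;
  components.reserve(2);
  // TensorHandlePtr takes ownership; the caller must not delete these handles.
  components.emplace_back(on_first_device);
  components.emplace_back(on_second_device);
  std::unique_ptr<ParallelTensor> wrapped =
      ParallelTensor::FromTensorHandles(device, std::move(components), status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  // Shape and dtype come from the first component and are checked against the rest.
  const std::vector<int64_t>& shape = wrapped->shape();
  (void)shape;
  return wrapped;
}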


@@ -165,7 +165,7 @@ void RegisterParallelDevice(
TF_Status* status) {
TFE_CustomDevice device;
void* device_info;
- tensorflow::eager::AllocateParallelDevice(
+ tensorflow::parallel_device::AllocateParallelDevice(
device_name, underlying_devices.data(), underlying_devices.size(),
&device, &device_info);
TFE_RegisterCustomDevice(context, device, device_name, device_info, status);


@@ -52,7 +52,7 @@ PYBIND11_MODULE(_pywrap_parallel_device, m) {
tensorflow::Safe_PyObjectPtr device_capsule(
PyCapsule_New(device, "TFE_CustomDevice", &CallDelete_Device));
void* device_info;
- tensorflow::eager::AllocateParallelDevice(
+ tensorflow::parallel_device::AllocateParallelDevice(
name, underlying_devices_c.data(), underlying_devices_c.size(),
device, &device_info);
if (PyErr_Occurred()) throw py::error_already_set();