Factor out C++ types from the parallel device
PiperOrigin-RevId: 314807016
Change-Id: I4e41ac3e8a08ea0f1db93826652142a083f17fd1

parent 8186228713
commit bfc32671fd
@@ -12,28 +12,69 @@ package(
# need a second rule that omits .cc files, in
# tensorflow/python:_pywrap_parallel_device.
filegroup(
    name = "headers",
    name = "lib_headers",
    srcs = ["parallel_device_lib.h"],
)

filegroup(
    name = "lib_sources",
    srcs = ["parallel_device_lib.cc"],
)

filegroup(
    name = "device_headers",
    srcs = ["parallel_device.h"],
)

filegroup(
    name = "device_sources",
    srcs = ["parallel_device.cc"],
)

filegroup(
    name = "headers",
    srcs = [
        ":device_headers",
        ":lib_headers",
    ],
    visibility = ["//tensorflow/python:__pkg__"],
)

filegroup(
    name = "sources",
    srcs = ["parallel_device.cc"],
    srcs = [
        ":device_sources",
        ":lib_sources",
    ],
    visibility = ["//tensorflow/python:__pkg__"],
)

cc_library(
    name = "parallel_device",
    srcs = [":sources"],
    hdrs = [":headers"],
    srcs = [":device_sources"],
    hdrs = [":device_headers"],
    visibility = ["//tensorflow:internal"],
    deps = [
        ":parallel_device_lib",
        "//tensorflow/c:c_api",
        "//tensorflow/c/eager:c_api",
        "//tensorflow/c/eager:c_api_experimental",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:optional",
        "@com_google_absl//absl/types:variant",
    ],
)

cc_library(
    name = "parallel_device_lib",
    srcs = [":lib_sources"],
    hdrs = [":lib_headers"],
    visibility = ["//tensorflow:internal"],
    deps = [
        "//tensorflow/c:c_api",
        "//tensorflow/c/eager:c_api",
        "//tensorflow/c/eager:c_api_experimental",
        "//tensorflow/core:lib",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:optional",
        "@com_google_absl//absl/types:variant",
    ],

@@ -23,25 +23,13 @@ limitations under the License.
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"

namespace tensorflow {
namespace eager {
namespace parallel_device {
namespace {

// Functor for making unique_ptrs slightly more ergonomic. Using
// decltype(delete_fn) in the unique_ptr's second template argument requires
// passing a function pointer to delete_fn when constructing the unique_ptr.
class TensorHandleDeleter {
 public:
  void operator()(TFE_TensorHandle* to_delete) const {
    TFE_DeleteTensorHandle(to_delete);
  }
};

using TensorHandlePtr = std::unique_ptr<TFE_TensorHandle, TensorHandleDeleter>;

class OpDeleter {
 public:
  void operator()(TFE_Op* to_delete) const { TFE_DeleteOp(to_delete); }

@@ -49,224 +37,43 @@ class OpDeleter {

using OpPtr = std::unique_ptr<TFE_Op, OpDeleter>;

class ExecutorDeleter {
 public:
  void operator()(TFE_Executor* to_delete) const {
    TFE_DeleteExecutor(to_delete);
  }
};

using ExecutorPtr = std::unique_ptr<TFE_Executor, ExecutorDeleter>;

class ParallelTensor;

using MaybeParallelTensorOwned =
    absl::variant<std::unique_ptr<ParallelTensor>, TensorHandlePtr>;
using MaybeParallelTensorUnowned =
    absl::variant<ParallelTensor*, TFE_TensorHandle*>;

// Creates a vector of `count` new executors (threads).
std::vector<ExecutorPtr> MakeExecutors(size_t count) {
  std::vector<ExecutorPtr> executors;
  executors.reserve(count);
  for (int i = 0; i < count; ++i) {
    executors.emplace_back(TFE_NewExecutor(true /* is_async */));
  }
  return executors;
}

// A representation of the custom device passed in and out of the TFE custom
// device APIs, providing context about the parallel device to
// ParallelDeviceExecute.
class ParallelDevice {
// A ParallelDevice on its own is not registered with a TFE_Context, and so has
// no device name (e.g. for `tf.device`). `NamedParallelDevice` associates a
// name with it, which lets us pack its `ParallelTensor`s into TFE_TensorHandles
// placed on the parallel device.
class NamedParallelDevice {
 public:
  ParallelDevice(const std::string& name,
                 const std::vector<std::string>& devices);

  // Helper to copy a tensor handle from another device once for each component
  // of the ParallelDevice.
  //
  // Sets a bad status and returns a nullptr if `tensor` is already on the
  // ParallelDevice, or if the individual copies fail.
  std::unique_ptr<ParallelTensor> CopyToParallelDevice(TFE_Context* context,
                                                       TFE_TensorHandle* tensor,
                                                       TF_Status* status) const;

  // A parallel tensor with scalar integers numbering component devices.
  std::unique_ptr<ParallelTensor> DeviceIDs(TFE_Context* context,
                                            TF_Status* status) const;

  // Takes a description of a single operation being executed on the
  // ParallelDevice, and in turn runs one operation per component device with
  // its corresponding inputs from the input ParallelTensors (or
  // implicitly-mirrored tensors on other devices). Wraps the resulting
  // per-device and per-output TFE_TensorHandles into one ParallelTensor per
  // output of the original operation.
  //
  // `inputs` are either ParallelTensors, i.e. already on the ParallelDevice, or
  // un-replicated TFE_TensorHandles on other devices. TPUReplicatedInput
  // requires non-parallel tensors, and TPUReplicatedOutput requires a parallel
  // tensor, but other operations will implicitly broadcast non-parallel input
  // tensors across the ParallelDevice's component devices.
  //
  // Two special-cased operations, TPUReplicatedInput and TPUReplicatedOutput,
  // pack and un-pack parallel tensors respectively. Only TPUReplicatedOutput
  // causes `Execute` to return non-parallel tensors.
  //
  // Attributes are forwarded to executed operations unmodified.
  //
  // The returned optional has a value if and only if `status` evaluates to
  // TF_OK.
  absl::optional<std::vector<MaybeParallelTensorOwned>> Execute(
      TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
      const char* operation_name, const TFE_OpAttrs* attributes,
      int expected_max_outputs, TF_Status* status) const;

  // Implements the parallel case for `Execute`, where all of the outputs of the
  // operation are ParallelTensors, and all inputs are either ParallelTensors or
  // should be implicitly broadcast. This means the operation is not
  // TPUReplicatedInput or TPUReplicatedOutput.
  //
  // The returned optional has a value if and only if `status` evaluates to
  // TF_OK. Bad statuses are forwarded from underlying `TFE_Execute` calls, or
  // if sanity checks on dtypes/metadata fail.
  absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
  ExecuteParallelOperation(TFE_Context* context,
                           std::vector<MaybeParallelTensorUnowned> inputs,
                           const char* operation_name,
                           const TFE_OpAttrs* attributes,
                           int expected_max_outputs, TF_Status* status) const;

  const std::string& device_name() const { return device_name_; }
  NamedParallelDevice(const std::string& name,
                      std::unique_ptr<ParallelDevice> parallel_device)
      : device_name_(name), parallel_device_(std::move(parallel_device)) {}
  const std::string& name() const { return device_name_; }
  const ParallelDevice& device() const { return *parallel_device_; }

 private:
  // The name of the parallel device
  // (e.g. "/job:localhost/replica:0/task:0/device:CUSTOM:0")
  const std::string device_name_;
  // A sequence of device names, indicating which devices replicated operations
  // are forwarded to.
  const std::vector<std::string> underlying_devices_;
  // A sequence of TFE_Executors, one per device, for executing operations in
  // parallel.
  const std::vector<ExecutorPtr> executors_;
  std::string device_name_;
  std::unique_ptr<ParallelDevice> parallel_device_;
};

// The internal representation of a TFE_TensorHandle placed on a
// ParallelDevice. Contains a tuple of tensors, one on each of the
// `underlying_devices_` of the ParallelDevice.
class ParallelTensor {
 public:
  // Construct a ParallelTensor from TensorHandles placed on the component
  // devices of a ParallelDevice.
  static std::unique_ptr<ParallelTensor> FromTensorHandles(
      const ParallelDevice& parallel_device,
      std::vector<TensorHandlePtr> components, TF_Status* status);

  // Helper to wrap a ParallelTensor into a TFE_TensorHandle which contains it.
  static TensorHandlePtr AsTensorHandle(TFE_Context* context,
                                        std::unique_ptr<ParallelTensor> t,
                                        TF_Status* status);

  size_t num_tensors() const { return tensors_.size(); }
  TFE_TensorHandle* tensor(size_t index) const { return tensors_[index].get(); }

 private:
  ParallelTensor(const ParallelDevice& device,
                 std::vector<TensorHandlePtr> tensors,
                 std::vector<int64_t> shape, const TF_DataType dtype)
      : device_(device),
        tensors_(std::move(tensors)),
        shape_(std::move(shape)),
        dtype_(dtype) {}

  const ParallelDevice& device_;
  const std::vector<TensorHandlePtr> tensors_;
  const std::vector<int64_t> shape_;
  const TF_DataType dtype_;
};

ParallelDevice::ParallelDevice(const std::string& name,
                               const std::vector<std::string>& devices)
    : device_name_(name),
      underlying_devices_(devices),
      executors_(MakeExecutors(underlying_devices_.size())) {}

std::unique_ptr<ParallelTensor> ParallelDevice::CopyToParallelDevice(
    TFE_Context* context, TFE_TensorHandle* tensor, TF_Status* status) const {
  const char* current_device = TFE_TensorHandleDeviceName(tensor, status);
  if (device_name_ == current_device) {
    std::string message(absl::StrCat(
        "Tried to copy a TensorHandle to its existing device: ", device_name_));
    TF_SetStatus(status, TF_INTERNAL, message.c_str());
    return nullptr;
  }
  std::vector<TensorHandlePtr> components;
  components.reserve(underlying_devices_.size());
  for (const std::string& underlying_device_name : underlying_devices_) {
    TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
        tensor, context, underlying_device_name.c_str(), status);
    if (TF_GetCode(status) != TF_OK) return nullptr;
    components.emplace_back(t);
  }
  return ParallelTensor::FromTensorHandles(*this, std::move(components),
                                           status);
}

std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
    TFE_Context* context, TF_Status* status) const {
  // TODO(allenl): We could cache DeviceIDs (keyed by context).
  std::vector<TensorHandlePtr> components;
  components.reserve(underlying_devices_.size());
  for (int device_index = 0; device_index < underlying_devices_.size();
       ++device_index) {
    int64_t* device_id = new int64_t;
    *device_id = device_index;
    std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
        TF_NewTensor(
            TF_INT64, /*dims=*/nullptr, /*num_dims=*/0, device_id,
            sizeof(int64_t),
            [](void* data, size_t, void* arg) {
              delete reinterpret_cast<int64_t*>(data);
            },
            nullptr),
        TF_DeleteTensor);
    // TODO(allenl): Here and when executing regular operations, we could hold
    // on to one TFE_Op per device and just call TFE_ResetOp to avoid parsing
    // device names repeatedly.
    OpPtr const_op(TFE_NewOp(context, "Const", status));
    if (TF_GetCode(status) != TF_OK) return nullptr;
    TFE_OpSetDevice(const_op.get(), underlying_devices_[device_index].c_str(),
                    status);
    if (TF_GetCode(status) != TF_OK) return nullptr;
    TFE_OpSetAttrTensor(const_op.get(), "value", tensor.get(), status);
    if (TF_GetCode(status) != TF_OK) return nullptr;
    TFE_OpSetAttrType(const_op.get(), "dtype", TF_INT64);
    TFE_TensorHandle* device_handle;
    int num_outputs = 1;
    TFE_Execute(const_op.get(), &device_handle, &num_outputs, status);
    if (TF_GetCode(status) != TF_OK) return nullptr;
    components.emplace_back(device_handle);
    if (TF_GetCode(status) != TF_OK) return nullptr;
  }
  return ParallelTensor::FromTensorHandles(*this, std::move(components),
                                           status);
}

absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
    TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
    const char* operation_name, const TFE_OpAttrs* attributes,
    int expected_max_outputs, TF_Status* status) const {
absl::optional<std::vector<MaybeParallelTensorOwned>> ExecuteWithSpecialOps(
    const ParallelDevice& parallel_device,
    const std::string& parallel_device_name, TFE_Context* context,
    std::vector<MaybeParallelTensorUnowned> inputs, const char* operation_name,
    const TFE_OpAttrs* attributes, int expected_max_outputs,
    TF_Status* status) {
  absl::optional<std::vector<MaybeParallelTensorOwned>> result;
  // TODO(allenl): We should remove "TPU" from these op names at the very least,
  // or consider other ways of packing/unpacking parallel tensors.
  if (operation_name == std::string("TPUReplicatedInput")) {
    // Special-cased operation for packing per-device tensors into one parallel
    // tensor.
    if (inputs.size() != underlying_devices_.size()) {
    if (inputs.size() != parallel_device.num_underlying_devices()) {
      std::string message(absl::StrCat(
          "The parallel device ", device_name_, " expected ",
          underlying_devices_.size(), " inputs to TPUReplicatedInput, but got ",
          inputs.size()));
          "The parallel device ", parallel_device_name, " expected ",
          parallel_device.num_underlying_devices(),
          " inputs to TPUReplicatedInput, but got ", inputs.size()));
      TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str());
      return result;
    }

@@ -289,7 +96,7 @@ absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
    std::vector<MaybeParallelTensorOwned> result_content;
    result_content.reserve(1);
    result_content.push_back(ParallelTensor::FromTensorHandles(
        *this, std::move(components), status));
        parallel_device, std::move(components), status));
    if (TF_GetCode(status) != TF_OK) return result;
    result.emplace(std::move(result_content));
    return result;

@@ -300,10 +107,10 @@ absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
    TFE_OpAddAttrs(op.get(), attributes);
    int expected_outputs = TFE_OpGetOutputLength(op.get(), "outputs", status);
    if (TF_GetCode(status) != TF_OK) return result;
    if (expected_outputs != underlying_devices_.size()) {
    if (expected_outputs != parallel_device.num_underlying_devices()) {
      std::string message(absl::StrCat(
          "The parallel device ", device_name_, " expected ",
          underlying_devices_.size(),
          "The parallel device ", parallel_device_name, " expected ",
          parallel_device.num_underlying_devices(),
          " outputs for TPUReplicatedOutput, but got ", expected_outputs));
      TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str());
      return result;

@@ -329,15 +136,15 @@ absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
  } else if (operation_name == std::string("DeviceID")) {
    std::vector<MaybeParallelTensorOwned> result_content;
    result_content.reserve(1);
    result_content.push_back(DeviceIDs(context, status));
    result_content.push_back(parallel_device.DeviceIDs(context, status));
    if (TF_GetCode(status) != TF_OK) return result;
    result.emplace(std::move(result_content));
    return result;
  }
  absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
      maybe_parallel_results(
          ExecuteParallelOperation(context, std::move(inputs), operation_name,
                                   attributes, expected_max_outputs, status));
          parallel_device.Execute(context, std::move(inputs), operation_name,
                                  attributes, expected_max_outputs, status));
  if (!maybe_parallel_results.has_value()) return result;
  std::vector<std::unique_ptr<ParallelTensor>> parallel_results(
      std::move(maybe_parallel_results.value()));

@@ -351,153 +158,6 @@ absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
  return result;
}

absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
ParallelDevice::ExecuteParallelOperation(
    TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
    const char* operation_name, const TFE_OpAttrs* attributes,
    int expected_max_outputs, TF_Status* status) const {
  absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> result;
  // Compute per-device per-output tensors
  std::vector<std::vector<TensorHandlePtr>> per_device_output_tensors;
  per_device_output_tensors.reserve(underlying_devices_.size());
  // TODO(allenl): Add a TFE_ExecuteWithExecutor API so we don't have to keep
  // setting the thread-local executor like this.
  TFE_Executor* previous_executor(TFE_ContextGetExecutorForThread(context));
  auto reset_executor = gtl::MakeCleanup([context, previous_executor]() {
    TFE_ContextSetExecutorForThread(context, previous_executor);
    TFE_DeleteExecutor(previous_executor);
  });
  int first_op_output_count;
  for (int device_index = 0; device_index < underlying_devices_.size();
       ++device_index) {
    TFE_Executor* executor = executors_[device_index].get();
    // Note that the `reset_executor` cleanup sets the thread's executor back to
    // the value before this function ran.
    TFE_ContextSetExecutorForThread(context, executor);
    OpPtr op(TFE_NewOp(context, operation_name, status));
    if (TF_GetCode(status) != TF_OK) return result;
    TFE_OpSetDevice(op.get(), underlying_devices_[device_index].c_str(),
                    status);
    TFE_OpAddAttrs(op.get(), attributes);
    for (int input_index = 0; input_index < inputs.size(); ++input_index) {
      if (absl::holds_alternative<TFE_TensorHandle*>(inputs[input_index])) {
        // Non-parallel tensors are implicitly broadcast, i.e. set as the input
        // to each parallel operation.
        //
        // TODO(allenl): There may be smarter ways to do this copy in some
        // cases, i.e. with a collective broadcast. We'll need to be careful
        // about things that are taken as inputs on the host or on their
        // existing device (for multi-device functions).
        TFE_OpAddInput(op.get(),
                       absl::get<TFE_TensorHandle*>(inputs[input_index]),
                       status);
        if (TF_GetCode(status) != TF_OK) return result;
      } else {
        // Parallel tensors are divided between operations by device.
        TFE_OpAddInput(op.get(),
                       absl::get<ParallelTensor*>(inputs[input_index])
                           ->tensor(device_index),
                       status);
        if (TF_GetCode(status) != TF_OK) return result;
      }
    }
    std::vector<TFE_TensorHandle*> op_outputs(expected_max_outputs);
    int real_num_outputs = expected_max_outputs;
    // For nested devices, the inner device sees the async executor we've
    // set. Inner parallel devices will just overwrite this with their own and
    // then set it back to ours before returning. This means parallel devices
    // which consist of several aliased parallel devices would hypothetically
    // deadlock if the outer parallel device ran one collective with a group
    // size equal to the total number of aliased physical devices. Currently
    // physical devices cannot participate in a single collective reduction
    // multiple times, so this would fail earlier.
    //
    // TODO(allenl): Keep a map from outer executor to list of inner executors
    // rather than a single list of executors so aliased nested parallel devices
    // don't re-use an executor.
    TFE_Execute(op.get(), op_outputs.data(), &real_num_outputs, status);
    if (device_index == 0) {
      first_op_output_count = real_num_outputs;
    } else {
      if (real_num_outputs != first_op_output_count) {
        TF_SetStatus(status, TF_INTERNAL,
                     "Parallel ops produced different numbers of tensors.");
        return result;
      }
    }
    if (TF_GetCode(status) != TF_OK) return result;
    std::vector<TensorHandlePtr> this_outputs;
    this_outputs.reserve(real_num_outputs);
    for (int output_num = 0; output_num < real_num_outputs; ++output_num) {
      this_outputs.emplace_back(op_outputs[output_num]);
    }
    per_device_output_tensors.push_back(std::move(this_outputs));
  }
  for (int device_index = 0; device_index < underlying_devices_.size();
       ++device_index) {
    TFE_Executor* executor = executors_[device_index].get();
    // TODO(b/157523095): Syncing the executor here shouldn't be
    // necessary. Currently async+remote is missing cross-executor
    // coordination.
    TFE_ExecutorWaitForAllPendingNodes(executor, status);
    if (TF_GetCode(status) != TF_OK) return result;
  }
  // For each output of the original operation, pack the per-device
  // TensorHandles we've computed into a single parallel TensorHandle.
  std::vector<std::unique_ptr<ParallelTensor>> per_device_outputs;
  per_device_outputs.reserve(first_op_output_count);
  for (int i = 0; i < first_op_output_count; ++i) {
    std::vector<TensorHandlePtr> components;
    components.reserve(underlying_devices_.size());
    for (int j = 0; j < underlying_devices_.size(); ++j) {
      components.push_back(std::move(per_device_output_tensors[j][i]));
    }
    per_device_outputs.push_back(ParallelTensor::FromTensorHandles(
        *this, std::move(components), status));
    if (TF_GetCode(status) != TF_OK) return result;
  }
  result.emplace(std::move(per_device_outputs));
  return result;
}

std::unique_ptr<ParallelTensor> ParallelTensor::FromTensorHandles(
    const ParallelDevice& parallel_device,
    std::vector<TensorHandlePtr> components, TF_Status* status) {
  TF_DataType dtype = TFE_TensorHandleDataType(components[0].get());
  std::vector<int64_t> shape(
      TFE_TensorHandleNumDims(components[0].get(), status));
  if (TF_GetCode(status) != TF_OK) return nullptr;
  for (int i = 0; i < shape.size(); ++i) {
    shape[i] = TFE_TensorHandleDim(components[0].get(), i, status);
    if (TF_GetCode(status) != TF_OK) return nullptr;
  }

  // Verify that the TensorHandle's shape and dtype match all of the component
  // shapes and dtypes.
  for (TensorHandlePtr& component : components) {
    for (int i = 0; i < shape.size(); ++i) {
      int64_t tensor_dim = TFE_TensorHandleDim(component.get(), i, status);
      if (TF_GetCode(status) != TF_OK) return nullptr;
      if (tensor_dim != shape[i]) {
        // TODO(allenl): Allow shapes to differ.
        TF_SetStatus(status, TF_UNIMPLEMENTED,
                     "Components of a ParallelTensor must currently all have "
                     "the same shape");
        return nullptr;
      }
      if (TFE_TensorHandleDataType(component.get()) != dtype) {
        TF_SetStatus(status, TF_INTERNAL,
                     "Components of a ParallelTensor must all have "
                     "the same dtype");
        return nullptr;
      }
    }
  }

  return std::unique_ptr<ParallelTensor>(new ParallelTensor(
      parallel_device, std::move(components), std::move(shape), dtype));
}

// Used as an argument to TFE_NewTensorHandleFromDeviceMemory, indicating how
// ParallelTensors wrapped in TFE_TensorHandles should be cleaned up once their
// reference counts drop to zero.

@@ -505,17 +165,18 @@ void ParallelTensorDeallocator(void* data, size_t len, void* arg) {
  delete reinterpret_cast<ParallelTensor*>(data);
}

TensorHandlePtr ParallelTensor::AsTensorHandle(
    TFE_Context* context, std::unique_ptr<ParallelTensor> t,
    TF_Status* status) {
TensorHandlePtr ParallelTensorToTensorHandle(
    const std::string& parallel_device_name, TFE_Context* context,
    std::unique_ptr<ParallelTensor> t, TF_Status* status) {
  // The resulting TensorHandle owns an opaque pointer to "device memory", which
  // for a ParallelDevice is really a ParallelTensor. When the TensorHandle is
  // deleted, it will call ParallelTensorDeallocator to free the struct.
  ParallelTensor* t_released = t.release();
  const std::vector<int64_t>& shape(t_released->shape());
  return TensorHandlePtr(TFE_NewTensorHandleFromDeviceMemory(
      context, t_released->device_.device_name().c_str(), t_released->dtype_,
      t_released->shape_.data(), t_released->shape_.size(), t_released, 1,
      &ParallelTensorDeallocator, nullptr, status));
      context, parallel_device_name.c_str(), t_released->dtype(), shape.data(),
      shape.size(), t_released, 1, &ParallelTensorDeallocator, nullptr,
      status));
}

// For TFE_CustomDevice::copy_tensor_to_device in the parallel device

@@ -531,12 +192,14 @@ TensorHandlePtr ParallelTensor::AsTensorHandle(
TFE_TensorHandle* CopyToParallelDevice(TFE_Context* context,
                                       TFE_TensorHandle* tensor,
                                       TF_Status* status, void* device_info) {
  ParallelDevice* dev = reinterpret_cast<ParallelDevice*>(device_info);
  NamedParallelDevice* named_device =
      reinterpret_cast<NamedParallelDevice*>(device_info);
  const ParallelDevice& dev = named_device->device();
  std::unique_ptr<ParallelTensor> parallel_tensor(
      dev->CopyToParallelDevice(context, tensor, status));
      dev.CopyToParallelDevice(context, tensor, status));
  if (TF_GetCode(status) != TF_OK) return nullptr;
  return ParallelTensor::AsTensorHandle(context, std::move(parallel_tensor),
                                        status)
  return ParallelTensorToTensorHandle(named_device->name(), context,
                                      std::move(parallel_tensor), status)
      .release();
}

@@ -570,14 +233,15 @@ void ParallelDeviceExecute(TFE_Context* context, int num_inputs,
                           const TFE_OpAttrs* attributes, int* num_outputs,
                           TFE_TensorHandle** outputs, TF_Status* status,
                           void* device_info) {
  ParallelDevice* dev = reinterpret_cast<ParallelDevice*>(device_info);
  NamedParallelDevice* named_device =
      reinterpret_cast<NamedParallelDevice*>(device_info);
  std::vector<MaybeParallelTensorUnowned> typed_inputs;
  typed_inputs.reserve(num_inputs);
  for (int i = 0; i < num_inputs; ++i) {
    const char* tensor_handle_device =
        TFE_TensorHandleDeviceName(inputs[i], status);
    if (TF_GetCode(status) != TF_OK) return;
    if (dev->device_name() == tensor_handle_device) {
    if (named_device->name() == tensor_handle_device) {
      // We assume that any tensors already placed on this device are
      // ParallelTensors.
      typed_inputs.emplace_back(reinterpret_cast<ParallelTensor*>(

@@ -589,8 +253,9 @@ void ParallelDeviceExecute(TFE_Context* context, int num_inputs,
  }

  absl::optional<std::vector<MaybeParallelTensorOwned>> maybe_typed_outputs(
      dev->Execute(context, std::move(typed_inputs), operation_name, attributes,
                   *num_outputs, status));
      ExecuteWithSpecialOps(named_device->device(), named_device->name(),
                            context, std::move(typed_inputs), operation_name,
                            attributes, *num_outputs, status));
  if (TF_GetCode(status) != TF_OK) return;
  if (!maybe_typed_outputs.has_value()) {
    TF_SetStatus(status, TF_INTERNAL, "OK status but no value was returned.");

@@ -611,8 +276,8 @@ void ParallelDeviceExecute(TFE_Context* context, int num_inputs,
    if (absl::holds_alternative<TensorHandlePtr>(typed_output)) {
      outputs[i] = absl::get<TensorHandlePtr>(typed_output).release();
    } else {
      outputs[i] = ParallelTensor::AsTensorHandle(
          context,
      outputs[i] = ParallelTensorToTensorHandle(
          named_device->name(), context,
          std::move(absl::get<std::unique_ptr<ParallelTensor>>(
              typed_output)),
          status)

@@ -629,7 +294,7 @@ void ParallelDeviceExecute(TFE_Context* context, int num_inputs,
// device_info is passed in using a C-style generic. It must always be a
// ParallelDevice.
void DeleteParallelDevice(void* device_info) {
  delete reinterpret_cast<ParallelDevice*>(device_info);
  delete reinterpret_cast<NamedParallelDevice*>(device_info);
}

} // namespace

@@ -648,8 +313,10 @@ void AllocateParallelDevice(const char* device_name,
       ++device_index) {
    underlying_devices_vector.push_back(underlying_devices[device_index]);
  }
  *device_info = new ParallelDevice(device_name, underlying_devices_vector);
  std::unique_ptr<ParallelDevice> parallel_device(
      new ParallelDevice(underlying_devices_vector));
  *device_info =
      new NamedParallelDevice{device_name, std::move(parallel_device)};
}

} // namespace eager
} // namespace parallel_device
} // namespace tensorflow
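
The hunks above split the old all-in-one type in two: execution state (component device names and per-device executors) moves into the library ParallelDevice, while the file-local NamedParallelDevice only pairs that state with the custom-device name used for placement. A minimal sketch of how the two compose inside parallel_device.cc after this change; it is illustrative only, since NamedParallelDevice lives in the anonymous namespace, and the device names below are made up:

  // Sketch: would sit alongside NamedParallelDevice in parallel_device.cc.
  std::vector<std::string> underlying{
      "/job:localhost/replica:0/task:0/device:CPU:0",
      "/job:localhost/replica:0/task:0/device:CPU:1"};
  // The library type owns only executors and component device names.
  auto parallel_device = std::make_unique<ParallelDevice>(underlying);
  // The wrapper attaches the name that TFE_TensorHandles are placed on.
  NamedParallelDevice named_device(
      "/job:localhost/replica:0/task:0/device:CUSTOM:0",
      std::move(parallel_device));
  // named_device.name() answers placement checks in ParallelDeviceExecute;
  // named_device.device() performs the per-device execution.
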
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_experimental.h"

namespace tensorflow {
namespace eager {
namespace parallel_device {

// Allocate a parallel device named `device_name` which forwards operations to
// `underlying_devices`, maintaining "parallel tensors" with components placed

@@ -59,7 +59,7 @@ void AllocateParallelDevice(const char* device_name,
                            int num_underlying_devices,
                            TFE_CustomDevice* device, void** device_info);

} // namespace eager
} // namespace parallel_device
} // namespace tensorflow

#endif  // TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_H_

@@ -0,0 +1,251 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"

#include "tensorflow/core/lib/gtl/cleanup.h"

namespace tensorflow {
namespace parallel_device {
namespace {

class OpDeleter {
 public:
  void operator()(TFE_Op* to_delete) const { TFE_DeleteOp(to_delete); }
};

using OpPtr = std::unique_ptr<TFE_Op, OpDeleter>;

// Creates a vector of `count` new executors (threads).
std::vector<ExecutorPtr> MakeExecutors(size_t count) {
  std::vector<ExecutorPtr> executors;
  executors.reserve(count);
  for (int i = 0; i < count; ++i) {
    executors.emplace_back(TFE_NewExecutor(true /* is_async */));
  }
  return executors;
}

} // namespace

ParallelDevice::ParallelDevice(const std::vector<std::string>& devices)
    : underlying_devices_(devices),
      executors_(MakeExecutors(underlying_devices_.size())) {}

std::unique_ptr<ParallelTensor> ParallelDevice::CopyToParallelDevice(
    TFE_Context* context, TFE_TensorHandle* tensor, TF_Status* status) const {
  std::vector<TensorHandlePtr> components;
  components.reserve(underlying_devices_.size());
  for (const std::string& underlying_device_name : underlying_devices_) {
    TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
        tensor, context, underlying_device_name.c_str(), status);
    if (TF_GetCode(status) != TF_OK) return nullptr;
    components.emplace_back(t);
  }
  return ParallelTensor::FromTensorHandles(*this, std::move(components),
                                           status);
}

std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
    TFE_Context* context, TF_Status* status) const {
  // TODO(allenl): We could cache DeviceIDs (keyed by context).
  std::vector<TensorHandlePtr> components;
  components.reserve(underlying_devices_.size());
  for (int device_index = 0; device_index < underlying_devices_.size();
       ++device_index) {
    int64_t* device_id = new int64_t;
    *device_id = device_index;
    std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
        TF_NewTensor(
            TF_INT64, /*dims=*/nullptr, /*num_dims=*/0, device_id,
            sizeof(int64_t),
            [](void* data, size_t, void* arg) {
              delete reinterpret_cast<int64_t*>(data);
            },
            nullptr),
        TF_DeleteTensor);
    // TODO(allenl): Here and when executing regular operations, we could hold
    // on to one TFE_Op per device and just call TFE_ResetOp to avoid parsing
    // device names repeatedly.
    OpPtr const_op(TFE_NewOp(context, "Const", status));
    if (TF_GetCode(status) != TF_OK) return nullptr;
    TFE_OpSetDevice(const_op.get(), underlying_devices_[device_index].c_str(),
                    status);
    if (TF_GetCode(status) != TF_OK) return nullptr;
    TFE_OpSetAttrTensor(const_op.get(), "value", tensor.get(), status);
    if (TF_GetCode(status) != TF_OK) return nullptr;
    TFE_OpSetAttrType(const_op.get(), "dtype", TF_INT64);
    TFE_TensorHandle* device_handle;
    int num_outputs = 1;
    TFE_Execute(const_op.get(), &device_handle, &num_outputs, status);
    if (TF_GetCode(status) != TF_OK) return nullptr;
    components.emplace_back(device_handle);
    if (TF_GetCode(status) != TF_OK) return nullptr;
  }
  return ParallelTensor::FromTensorHandles(*this, std::move(components),
                                           status);
}

absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
ParallelDevice::Execute(TFE_Context* context,
                        std::vector<MaybeParallelTensorUnowned> inputs,
                        const char* operation_name,
                        const TFE_OpAttrs* attributes, int expected_max_outputs,
                        TF_Status* status) const {
  absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> result;
  // Compute per-device per-output tensors
  std::vector<std::vector<TensorHandlePtr>> per_device_output_tensors;
  per_device_output_tensors.reserve(underlying_devices_.size());
  // TODO(allenl): Add a TFE_ExecuteWithExecutor API so we don't have to keep
  // setting the thread-local executor like this.
  TFE_Executor* previous_executor(TFE_ContextGetExecutorForThread(context));
  auto reset_executor =
      tensorflow::gtl::MakeCleanup([context, previous_executor]() {
        TFE_ContextSetExecutorForThread(context, previous_executor);
        TFE_DeleteExecutor(previous_executor);
      });
  int first_op_output_count;
  for (int device_index = 0; device_index < underlying_devices_.size();
       ++device_index) {
    TFE_Executor* executor = executors_[device_index].get();
    // Note that the `reset_executor` cleanup sets the thread's executor back to
    // the value before this function ran.
    TFE_ContextSetExecutorForThread(context, executor);
    OpPtr op(TFE_NewOp(context, operation_name, status));
    if (TF_GetCode(status) != TF_OK) return result;
    TFE_OpSetDevice(op.get(), underlying_devices_[device_index].c_str(),
                    status);
    TFE_OpAddAttrs(op.get(), attributes);
    for (int input_index = 0; input_index < inputs.size(); ++input_index) {
      if (absl::holds_alternative<TFE_TensorHandle*>(inputs[input_index])) {
        // Non-parallel tensors are implicitly broadcast, i.e. set as the input
        // to each parallel operation.
        //
        // TODO(allenl): There may be smarter ways to do this copy in some
        // cases, i.e. with a collective broadcast. We'll need to be careful
        // about things that are taken as inputs on the host or on their
        // existing device (for multi-device functions).
        TFE_OpAddInput(op.get(),
                       absl::get<TFE_TensorHandle*>(inputs[input_index]),
                       status);
        if (TF_GetCode(status) != TF_OK) return result;
      } else {
        // Parallel tensors are divided between operations by device.
        TFE_OpAddInput(op.get(),
                       absl::get<ParallelTensor*>(inputs[input_index])
                           ->tensor(device_index),
                       status);
        if (TF_GetCode(status) != TF_OK) return result;
      }
    }
    std::vector<TFE_TensorHandle*> op_outputs(expected_max_outputs);
    int real_num_outputs = expected_max_outputs;
    // For nested devices, the inner device sees the async executor we've
    // set. Inner parallel devices will just overwrite this with their own and
    // then set it back to ours before returning. This means parallel devices
    // which consist of several aliased parallel devices would hypothetically
    // deadlock if the outer parallel device ran one collective with a group
    // size equal to the total number of aliased physical devices. Currently
    // physical devices cannot participate in a single collective reduction
    // multiple times, so this would fail earlier.
    //
    // TODO(allenl): Keep a map from outer executor to list of inner executors
    // rather than a single list of executors so aliased nested parallel devices
    // don't re-use an executor.
    TFE_Execute(op.get(), op_outputs.data(), &real_num_outputs, status);
    if (device_index == 0) {
      first_op_output_count = real_num_outputs;
    } else {
      if (real_num_outputs != first_op_output_count) {
        TF_SetStatus(status, TF_INTERNAL,
                     "Parallel ops produced different numbers of tensors.");
        return result;
      }
    }
    if (TF_GetCode(status) != TF_OK) return result;
    std::vector<TensorHandlePtr> this_outputs;
    this_outputs.reserve(real_num_outputs);
    for (int output_num = 0; output_num < real_num_outputs; ++output_num) {
      this_outputs.emplace_back(op_outputs[output_num]);
    }
    per_device_output_tensors.push_back(std::move(this_outputs));
  }
  for (int device_index = 0; device_index < underlying_devices_.size();
       ++device_index) {
    TFE_Executor* executor = executors_[device_index].get();
    // TODO(b/157523095): Syncing the executor here shouldn't be
    // necessary. Currently async+remote is missing cross-executor
    // coordination.
    TFE_ExecutorWaitForAllPendingNodes(executor, status);
    if (TF_GetCode(status) != TF_OK) return result;
  }
  // For each output of the original operation, pack the per-device
  // TensorHandles we've computed into a single parallel TensorHandle.
  std::vector<std::unique_ptr<ParallelTensor>> per_device_outputs;
  per_device_outputs.reserve(first_op_output_count);
  for (int i = 0; i < first_op_output_count; ++i) {
    std::vector<TensorHandlePtr> components;
    components.reserve(underlying_devices_.size());
    for (int j = 0; j < underlying_devices_.size(); ++j) {
      components.push_back(std::move(per_device_output_tensors[j][i]));
    }
    per_device_outputs.push_back(ParallelTensor::FromTensorHandles(
        *this, std::move(components), status));
    if (TF_GetCode(status) != TF_OK) return result;
  }
  result.emplace(std::move(per_device_outputs));
  return result;
}

std::unique_ptr<ParallelTensor> ParallelTensor::FromTensorHandles(
    const ParallelDevice& parallel_device,
    std::vector<TensorHandlePtr> components, TF_Status* status) {
  TF_DataType dtype = TFE_TensorHandleDataType(components[0].get());
  std::vector<int64_t> shape(
      TFE_TensorHandleNumDims(components[0].get(), status));
  if (TF_GetCode(status) != TF_OK) return nullptr;
  for (int i = 0; i < shape.size(); ++i) {
    shape[i] = TFE_TensorHandleDim(components[0].get(), i, status);
    if (TF_GetCode(status) != TF_OK) return nullptr;
  }

  // Verify that the TensorHandle's shape and dtype match all of the component
  // shapes and dtypes.
  for (TensorHandlePtr& component : components) {
    for (int i = 0; i < shape.size(); ++i) {
      int64_t tensor_dim = TFE_TensorHandleDim(component.get(), i, status);
      if (TF_GetCode(status) != TF_OK) return nullptr;
      if (tensor_dim != shape[i]) {
        // TODO(allenl): Allow shapes to differ.
        TF_SetStatus(status, TF_UNIMPLEMENTED,
                     "Components of a ParallelTensor must currently all have "
                     "the same shape");
        return nullptr;
      }
      if (TFE_TensorHandleDataType(component.get()) != dtype) {
        TF_SetStatus(status, TF_INTERNAL,
                     "Components of a ParallelTensor must all have "
                     "the same dtype");
        return nullptr;
      }
    }
  }

  return std::unique_ptr<ParallelTensor>(new ParallelTensor(
      parallel_device, std::move(components), std::move(shape), dtype));
}

} // namespace parallel_device
} // namespace tensorflow
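
Because the execution machinery now lives in this standalone translation unit, it can be exercised without registering a custom device at all. A small sketch driving the DeviceIDs helper defined above; the device names are illustrative and error handling is abbreviated:

  #include <memory>
  #include <string>
  #include <vector>

  #include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"

  // Sketch: builds a two-way parallel device and materializes one scalar
  // int64 handle per component device (component i holds the value i).
  void DeviceIdsExample(TFE_Context* context, TF_Status* status) {
    std::vector<std::string> devices{
        "/job:localhost/replica:0/task:0/device:CPU:0",
        "/job:localhost/replica:0/task:0/device:CPU:1"};
    tensorflow::parallel_device::ParallelDevice parallel_device(devices);
    std::unique_ptr<tensorflow::parallel_device::ParallelTensor> ids =
        parallel_device.DeviceIDs(context, status);
    if (TF_GetCode(status) != TF_OK) return;
    // ids->num_tensors() == 2; ids->tensor(0) is placed on CPU:0 and
    // ids->tensor(1) on CPU:1.
  }
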
@@ -0,0 +1,141 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_LIB_H_
#define TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_LIB_H_

#include <memory>
#include <string>
#include <vector>

#include "absl/types/optional.h"
#include "absl/types/variant.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"

namespace tensorflow {
namespace parallel_device {

// Functor for making unique_ptrs slightly more ergonomic. Using
// decltype(delete_fn) in the unique_ptr's second template argument requires
// passing a function pointer to delete_fn when constructing the unique_ptr.
class TensorHandleDeleter {
 public:
  void operator()(TFE_TensorHandle* to_delete) const {
    TFE_DeleteTensorHandle(to_delete);
  }
};

using TensorHandlePtr = std::unique_ptr<TFE_TensorHandle, TensorHandleDeleter>;

class ExecutorDeleter {
 public:
  void operator()(TFE_Executor* to_delete) const {
    TFE_DeleteExecutor(to_delete);
  }
};

using ExecutorPtr = std::unique_ptr<TFE_Executor, ExecutorDeleter>;

class ParallelTensor;

using MaybeParallelTensorUnowned =
    absl::variant<ParallelTensor*, TFE_TensorHandle*>;

// Forwards operations to `devices`, maintaining ParallelTensor with components
// placed on each underlying device.
class ParallelDevice {
 public:
  explicit ParallelDevice(const std::vector<std::string>& devices);

  // Helper to copy a tensor handle from another device once for each component
  // of the ParallelDevice.
  //
  // Sets a bad status and returns a nullptr if `tensor` is already on the
  // ParallelDevice, or if the individual copies fail.
  std::unique_ptr<ParallelTensor> CopyToParallelDevice(TFE_Context* context,
                                                       TFE_TensorHandle* tensor,
                                                       TF_Status* status) const;

  // A parallel tensor with scalar integers numbering component devices.
  std::unique_ptr<ParallelTensor> DeviceIDs(TFE_Context* context,
                                            TF_Status* status) const;

  // The number of devices operations run on.
  size_t num_underlying_devices() const { return underlying_devices_.size(); }

  // Takes a description of a single operation being executed on the
  // ParallelDevice, and in turn runs one operation per component device with
  // its corresponding inputs from the input ParallelTensors (or
  // implicitly-mirrored tensors on other devices). Wraps the resulting
  // per-device and per-output TFE_TensorHandles into one ParallelTensor per
  // output of the original operation.
  //
  // Attributes are forwarded to executed operations unmodified.
  //
  // The returned optional has a value if and only if `status` evaluates to
  // TF_OK. Bad statuses are forwarded from underlying `TFE_Execute` calls, or
  // if sanity checks on dtypes/metadata fail.
  absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> Execute(
      TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
      const char* operation_name, const TFE_OpAttrs* attributes,
      int expected_max_outputs, TF_Status* status) const;

 private:
  // A sequence of device names, indicating which devices replicated operations
  // are forwarded to.
  const std::vector<std::string> underlying_devices_;
  // A sequence of TFE_Executors, one per device, for executing operations in
  // parallel.
  const std::vector<ExecutorPtr> executors_;
};

// Contains a tuple of tensors, one on each of the `underlying_devices_` of the
// ParallelDevice.
class ParallelTensor {
 public:
  // Construct a ParallelTensor from TensorHandles placed on the component
  // devices of a ParallelDevice.
  static std::unique_ptr<ParallelTensor> FromTensorHandles(
      const ParallelDevice& parallel_device,
      std::vector<TensorHandlePtr> components, TF_Status* status);

  size_t num_tensors() const { return tensors_.size(); }
  TFE_TensorHandle* tensor(size_t index) const { return tensors_[index].get(); }

  // A generalization of the shapes of the underlying tensors.
  const std::vector<int64_t>& shape() const { return shape_; }
  TF_DataType dtype() const { return dtype_; }

 private:
  ParallelTensor(const ParallelDevice& device,
                 std::vector<TensorHandlePtr> tensors,
                 std::vector<int64_t> shape, const TF_DataType dtype)
      : device_(device),
        tensors_(std::move(tensors)),
        shape_(std::move(shape)),
        dtype_(dtype) {}

  const ParallelDevice& device_;
  const std::vector<TensorHandlePtr> tensors_;
  const std::vector<int64_t> shape_;
  const TF_DataType dtype_;
};

} // namespace parallel_device
} // namespace tensorflow

#endif  // TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_LIB_H_
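
A minimal sketch of consuming this new header directly, mirroring an ordinary tensor handle onto every component device; it assumes an eager context with two CPU devices and uses illustrative names:

  #include <memory>
  #include <string>
  #include <vector>

  #include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"

  void MirrorExample(TFE_Context* context, TFE_TensorHandle* tensor,
                     TF_Status* status) {
    std::vector<std::string> devices{
        "/job:localhost/replica:0/task:0/device:CPU:0",
        "/job:localhost/replica:0/task:0/device:CPU:1"};
    tensorflow::parallel_device::ParallelDevice parallel_device(devices);
    // One copy of `tensor` is made per underlying device.
    std::unique_ptr<tensorflow::parallel_device::ParallelTensor> mirrored =
        parallel_device.CopyToParallelDevice(context, tensor, status);
    if (TF_GetCode(status) != TF_OK) return;
    // Each component is an ordinary TFE_TensorHandle placed on one device.
    TFE_TensorHandle* on_first_device = mirrored->tensor(0);
    (void)on_first_device;
  }
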
@@ -165,7 +165,7 @@ void RegisterParallelDevice(
    TF_Status* status) {
  TFE_CustomDevice device;
  void* device_info;
  tensorflow::eager::AllocateParallelDevice(
  tensorflow::parallel_device::AllocateParallelDevice(
      device_name, underlying_devices.data(), underlying_devices.size(),
      &device, &device_info);
  TFE_RegisterCustomDevice(context, device, device_name, device_info, status);

@@ -52,7 +52,7 @@ PYBIND11_MODULE(_pywrap_parallel_device, m) {
  tensorflow::Safe_PyObjectPtr device_capsule(
      PyCapsule_New(device, "TFE_CustomDevice", &CallDelete_Device));
  void* device_info;
  tensorflow::eager::AllocateParallelDevice(
  tensorflow::parallel_device::AllocateParallelDevice(
      name, underlying_devices_c.data(), underlying_devices_c.size(),
      device, &device_info);
  if (PyErr_Occurred()) throw py::error_already_set();
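
Both call sites above change only the namespace of AllocateParallelDevice; the registration sequence itself is untouched. For reference, a hedged sketch of the full flow from C++, with illustrative device names:

  #include <vector>

  #include "tensorflow/c/eager/c_api_experimental.h"
  #include "tensorflow/c/eager/parallel_device/parallel_device.h"

  void RegisterTwoWayParallelDevice(TFE_Context* context, TF_Status* status) {
    const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
    std::vector<const char*> underlying{
        "/job:localhost/replica:0/task:0/device:CPU:0",
        "/job:localhost/replica:0/task:0/device:CPU:1"};
    TFE_CustomDevice device;
    void* device_info;
    tensorflow::parallel_device::AllocateParallelDevice(
        device_name, underlying.data(), static_cast<int>(underlying.size()),
        &device, &device_info);
    // device_info is later freed through the custom device's delete callback.
    TFE_RegisterCustomDevice(context, device, device_name, device_info, status);
  }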