Factor out C++ types from the parallel device
PiperOrigin-RevId: 314807016
Change-Id: I4e41ac3e8a08ea0f1db93826652142a083f17fd1
parent 8186228713
commit bfc32671fd
tensorflow/c/eager/parallel_device/BUILD

@@ -12,28 +12,69 @@ package(
 # need a second rule that omits .cc files, in
 # tensorflow/python:_pywrap_parallel_device.
 filegroup(
-    name = "headers",
+    name = "lib_headers",
+    srcs = ["parallel_device_lib.h"],
+)
+
+filegroup(
+    name = "lib_sources",
+    srcs = ["parallel_device_lib.cc"],
+)
+
+filegroup(
+    name = "device_headers",
     srcs = ["parallel_device.h"],
+)
+
+filegroup(
+    name = "device_sources",
+    srcs = ["parallel_device.cc"],
+)
+
+filegroup(
+    name = "headers",
+    srcs = [
+        ":device_headers",
+        ":lib_headers",
+    ],
     visibility = ["//tensorflow/python:__pkg__"],
 )
 
 filegroup(
     name = "sources",
-    srcs = ["parallel_device.cc"],
+    srcs = [
+        ":device_sources",
+        ":lib_sources",
+    ],
     visibility = ["//tensorflow/python:__pkg__"],
 )
 
 cc_library(
     name = "parallel_device",
-    srcs = [":sources"],
-    hdrs = [":headers"],
+    srcs = [":device_sources"],
+    hdrs = [":device_headers"],
+    visibility = ["//tensorflow:internal"],
+    deps = [
+        ":parallel_device_lib",
+        "//tensorflow/c:c_api",
+        "//tensorflow/c/eager:c_api",
+        "//tensorflow/c/eager:c_api_experimental",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:optional",
+        "@com_google_absl//absl/types:variant",
+    ],
+)
+
+cc_library(
+    name = "parallel_device_lib",
+    srcs = [":lib_sources"],
+    hdrs = [":lib_headers"],
     visibility = ["//tensorflow:internal"],
     deps = [
         "//tensorflow/c:c_api",
         "//tensorflow/c/eager:c_api",
         "//tensorflow/c/eager:c_api_experimental",
         "//tensorflow/core:lib",
-        "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:optional",
         "@com_google_absl//absl/types:variant",
     ],
 )
tensorflow/c/eager/parallel_device/parallel_device.cc

@@ -23,25 +23,13 @@ limitations under the License.
 #include "tensorflow/c/c_api.h"
 #include "tensorflow/c/eager/c_api.h"
 #include "tensorflow/c/eager/c_api_experimental.h"
+#include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"
 #include "tensorflow/c/tf_status.h"
-#include "tensorflow/core/lib/gtl/cleanup.h"
 
 namespace tensorflow {
-namespace eager {
+namespace parallel_device {
 namespace {
 
-// Functor for making unique_ptrs slightly more ergonomic. Using
-// decltype(delete_fn) in the unique_ptr's second template argument requires
-// passing a function pointer to delete_fn when constructing the unique_ptr.
-class TensorHandleDeleter {
- public:
-  void operator()(TFE_TensorHandle* to_delete) const {
-    TFE_DeleteTensorHandle(to_delete);
-  }
-};
-
-using TensorHandlePtr = std::unique_ptr<TFE_TensorHandle, TensorHandleDeleter>;
-
 class OpDeleter {
  public:
   void operator()(TFE_Op* to_delete) const { TFE_DeleteOp(to_delete); }
@@ -49,224 +37,43 @@ class OpDeleter {
 
 using OpPtr = std::unique_ptr<TFE_Op, OpDeleter>;
 
-class ExecutorDeleter {
- public:
-  void operator()(TFE_Executor* to_delete) const {
-    TFE_DeleteExecutor(to_delete);
-  }
-};
-
-using ExecutorPtr = std::unique_ptr<TFE_Executor, ExecutorDeleter>;
-
-class ParallelTensor;
-
 using MaybeParallelTensorOwned =
     absl::variant<std::unique_ptr<ParallelTensor>, TensorHandlePtr>;
-using MaybeParallelTensorUnowned =
-    absl::variant<ParallelTensor*, TFE_TensorHandle*>;
 
-// Creates a vector of `count` new executors (threads).
-std::vector<ExecutorPtr> MakeExecutors(size_t count) {
-  std::vector<ExecutorPtr> executors;
-  executors.reserve(count);
-  for (int i = 0; i < count; ++i) {
-    executors.emplace_back(TFE_NewExecutor(true /* is_async */));
-  }
-  return executors;
-}
-
-// A representation of the custom device passed in and out of the TFE custom
-// device APIs, providing context about the parallel device to
-// ParallelDeviceExecute.
-class ParallelDevice {
+// A ParallelDevice on its own is not registered with a TFE_Context, and so has
+// no device name (e.g. for `tf.device`). `NamedParallelDevice` associates a
+// name with it, which lets us pack its `ParallelTensor`s into TFE_TensorHandles
+// placed on the parallel device.
+class NamedParallelDevice {
  public:
-  ParallelDevice(const std::string& name,
-                 const std::vector<std::string>& devices);
-
-  // Helper to copy a tensor handle from another device once for each component
-  // of the ParallelDevice.
-  //
-  // Sets a bad status and returns a nullptr if `tensor` is already on the
-  // ParallelDevice, or if the individual copies fail.
-  std::unique_ptr<ParallelTensor> CopyToParallelDevice(TFE_Context* context,
-                                                       TFE_TensorHandle* tensor,
-                                                       TF_Status* status) const;
-
-  // A parallel tensor with scalar integers numbering component devices.
-  std::unique_ptr<ParallelTensor> DeviceIDs(TFE_Context* context,
-                                            TF_Status* status) const;
-
-  // Takes a description of a single operation being executed on the
-  // ParallelDevice, and in turn runs one operation per component device with
-  // its corresponding inputs from the input ParallelTensors (or
-  // implicitly-mirrored tensors on other devices). Wraps the resulting
-  // per-device and per-output TFE_TensorHandles into one ParallelTensor per
-  // output of the original operation.
-  //
-  // `inputs` are either ParallelTensors, i.e. already on the ParallelDevice, or
-  // un-replicated TFE_TensorHandles on other devices. TPUReplicatedInput
-  // requires non-parallel tensors, and TPUReplicatedOutput requires a parallel
-  // tensor, but other operations will implicitly broadcast non-parallel input
-  // tensors across the ParallelDevice's component devices.
-  //
-  // Two special-cased operations, TPUReplicatedInput and TPUReplicatedOutput,
-  // pack and un-pack parallel tensors respectively. Only TPUReplicatedOutput
-  // causes `Execute` to return non-parallel tensors.
-  //
-  // Attributes are forwarded to executed operations unmodified.
-  //
-  // The returned optional has a value if and only if `status` evaluates to
-  // TF_OK.
-  absl::optional<std::vector<MaybeParallelTensorOwned>> Execute(
-      TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
-      const char* operation_name, const TFE_OpAttrs* attributes,
-      int expected_max_outputs, TF_Status* status) const;
-
-  // Implements the parallel case for `Execute`, where all of the outputs of the
-  // operation are ParallelTensors, and all inputs are either ParallelTensors or
-  // should be implicitly broadcast. This means the operation is not
-  // TPUReplicatedInput or TPUReplicatedOutput.
-  //
-  // The returned optional has a value if and only if `status` evaluates to
-  // TF_OK. Bad statuses are forwarded from underlying `TFE_Execute` calls, or
-  // if sanity checks on dtypes/metadata fail.
-  absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
-  ExecuteParallelOperation(TFE_Context* context,
-                           std::vector<MaybeParallelTensorUnowned> inputs,
-                           const char* operation_name,
-                           const TFE_OpAttrs* attributes,
-                           int expected_max_outputs, TF_Status* status) const;
-
-  const std::string& device_name() const { return device_name_; }
+  NamedParallelDevice(const std::string& name,
+                      std::unique_ptr<ParallelDevice> parallel_device)
+      : device_name_(name), parallel_device_(std::move(parallel_device)) {}
+  const std::string& name() const { return device_name_; }
+  const ParallelDevice& device() const { return *parallel_device_; }
 
  private:
-  // The name of the parallel device
-  // (e.g. "/job:localhost/replica:0/task:0/device:CUSTOM:0")
-  const std::string device_name_;
-  // A sequence of device names, indicating which devices replicated operations
-  // are forwarded to.
-  const std::vector<std::string> underlying_devices_;
-  // A sequence of TFE_Executors, one per device, for executing operations in
-  // parallel.
-  const std::vector<ExecutorPtr> executors_;
+  std::string device_name_;
+  std::unique_ptr<ParallelDevice> parallel_device_;
 };
 
-// The internal representation of a TFE_TensorHandle placed on a
-// ParallelDevice. Contains a tuple of tensors, one on each of the
-// `underlying_devices_` of the ParallelDevice.
-class ParallelTensor {
- public:
-  // Construct a ParallelTensor from TensorHandles placed on the component
-  // devices of a ParallelDevice.
-  static std::unique_ptr<ParallelTensor> FromTensorHandles(
+absl::optional<std::vector<MaybeParallelTensorOwned>> ExecuteWithSpecialOps(
     const ParallelDevice& parallel_device,
-    std::vector<TensorHandlePtr> components, TF_Status* status);
-
-  // Helper to wrap a ParallelTensor into a TFE_TensorHandle which contains it.
-  static TensorHandlePtr AsTensorHandle(TFE_Context* context,
-                                        std::unique_ptr<ParallelTensor> t,
-                                        TF_Status* status);
-
-  size_t num_tensors() const { return tensors_.size(); }
-  TFE_TensorHandle* tensor(size_t index) const { return tensors_[index].get(); }
-
- private:
-  ParallelTensor(const ParallelDevice& device,
-                 std::vector<TensorHandlePtr> tensors,
-                 std::vector<int64_t> shape, const TF_DataType dtype)
-      : device_(device),
-        tensors_(std::move(tensors)),
-        shape_(std::move(shape)),
-        dtype_(dtype) {}
-
-  const ParallelDevice& device_;
-  const std::vector<TensorHandlePtr> tensors_;
-  const std::vector<int64_t> shape_;
-  const TF_DataType dtype_;
-};
-
-ParallelDevice::ParallelDevice(const std::string& name,
-                               const std::vector<std::string>& devices)
-    : device_name_(name),
-      underlying_devices_(devices),
-      executors_(MakeExecutors(underlying_devices_.size())) {}
-
-std::unique_ptr<ParallelTensor> ParallelDevice::CopyToParallelDevice(
-    TFE_Context* context, TFE_TensorHandle* tensor, TF_Status* status) const {
-  const char* current_device = TFE_TensorHandleDeviceName(tensor, status);
-  if (device_name_ == current_device) {
-    std::string message(absl::StrCat(
-        "Tried to copy a TensorHandle to its existing device: ", device_name_));
-    TF_SetStatus(status, TF_INTERNAL, message.c_str());
-    return nullptr;
-  }
-  std::vector<TensorHandlePtr> components;
-  components.reserve(underlying_devices_.size());
-  for (const std::string& underlying_device_name : underlying_devices_) {
-    TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
-        tensor, context, underlying_device_name.c_str(), status);
-    if (TF_GetCode(status) != TF_OK) return nullptr;
-    components.emplace_back(t);
-  }
-  return ParallelTensor::FromTensorHandles(*this, std::move(components),
-                                           status);
-}
-
-std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
-    TFE_Context* context, TF_Status* status) const {
-  // TODO(allenl): We could cache DeviceIDs (keyed by context).
-  std::vector<TensorHandlePtr> components;
-  components.reserve(underlying_devices_.size());
-  for (int device_index = 0; device_index < underlying_devices_.size();
-       ++device_index) {
-    int64_t* device_id = new int64_t;
-    *device_id = device_index;
-    std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
-        TF_NewTensor(
-            TF_INT64, /*dims=*/nullptr, /*num_dims=*/0, device_id,
-            sizeof(int64_t),
-            [](void* data, size_t, void* arg) {
-              delete reinterpret_cast<int64_t*>(data);
-            },
-            nullptr),
-        TF_DeleteTensor);
-    // TODO(allenl): Here and when executing regular operations, we could hold
-    // on to one TFE_Op per device and just call TFE_ResetOp to avoid parsing
-    // device names repeatedly.
-    OpPtr const_op(TFE_NewOp(context, "Const", status));
-    if (TF_GetCode(status) != TF_OK) return nullptr;
-    TFE_OpSetDevice(const_op.get(), underlying_devices_[device_index].c_str(),
-                    status);
-    if (TF_GetCode(status) != TF_OK) return nullptr;
-    TFE_OpSetAttrTensor(const_op.get(), "value", tensor.get(), status);
-    if (TF_GetCode(status) != TF_OK) return nullptr;
-    TFE_OpSetAttrType(const_op.get(), "dtype", TF_INT64);
-    TFE_TensorHandle* device_handle;
-    int num_outputs = 1;
-    TFE_Execute(const_op.get(), &device_handle, &num_outputs, status);
-    if (TF_GetCode(status) != TF_OK) return nullptr;
-    components.emplace_back(device_handle);
-    if (TF_GetCode(status) != TF_OK) return nullptr;
-  }
-  return ParallelTensor::FromTensorHandles(*this, std::move(components),
-                                           status);
-}
-
-absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
-    TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
-    const char* operation_name, const TFE_OpAttrs* attributes,
-    int expected_max_outputs, TF_Status* status) const {
+    const std::string& parallel_device_name, TFE_Context* context,
+    std::vector<MaybeParallelTensorUnowned> inputs, const char* operation_name,
+    const TFE_OpAttrs* attributes, int expected_max_outputs,
+    TF_Status* status) {
   absl::optional<std::vector<MaybeParallelTensorOwned>> result;
   // TODO(allenl): We should remove "TPU" from these op names at the very least,
   // or consider other ways of packing/unpacking parallel tensors.
   if (operation_name == std::string("TPUReplicatedInput")) {
     // Special-cased operation for packing per-device tensors into one parallel
     // tensor.
-    if (inputs.size() != underlying_devices_.size()) {
+    if (inputs.size() != parallel_device.num_underlying_devices()) {
       std::string message(absl::StrCat(
-          "The parallel device ", device_name_, " expected ",
-          underlying_devices_.size(), " inputs to TPUReplicatedInput, but got ",
-          inputs.size()));
+          "The parallel device ", parallel_device_name, " expected ",
+          parallel_device.num_underlying_devices(),
+          " inputs to TPUReplicatedInput, but got ", inputs.size()));
       TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str());
       return result;
     }
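A minimal sketch of how the split types compose (illustrative only, not part of this change; the device names are hypothetical):

    // The un-named ParallelDevice does the per-device work; NamedParallelDevice
    // only pairs it with the name under which the custom device is registered.
    std::vector<std::string> underlying{
        "/job:localhost/replica:0/task:0/device:CPU:0",
        "/job:localhost/replica:0/task:0/device:CPU:1"};
    std::unique_ptr<ParallelDevice> parallel_device(
        new ParallelDevice(underlying));
    NamedParallelDevice named("/job:localhost/replica:0/task:0/device:CUSTOM:0",
                              std::move(parallel_device));
    named.device().num_underlying_devices();  // == 2
    named.name();  // used when packing ParallelTensors into TFE_TensorHandles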
@@ -289,7 +96,7 @@ absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
     std::vector<MaybeParallelTensorOwned> result_content;
     result_content.reserve(1);
     result_content.push_back(ParallelTensor::FromTensorHandles(
-        *this, std::move(components), status));
+        parallel_device, std::move(components), status));
     if (TF_GetCode(status) != TF_OK) return result;
     result.emplace(std::move(result_content));
     return result;
@@ -300,10 +107,10 @@ absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
     TFE_OpAddAttrs(op.get(), attributes);
     int expected_outputs = TFE_OpGetOutputLength(op.get(), "outputs", status);
     if (TF_GetCode(status) != TF_OK) return result;
-    if (expected_outputs != underlying_devices_.size()) {
+    if (expected_outputs != parallel_device.num_underlying_devices()) {
       std::string message(absl::StrCat(
-          "The parallel device ", device_name_, " expected ",
-          underlying_devices_.size(),
+          "The parallel device ", parallel_device_name, " expected ",
+          parallel_device.num_underlying_devices(),
           " outputs for TPUReplicatedOutput, but got ", expected_outputs));
       TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str());
       return result;
@@ -329,14 +136,14 @@ absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
   } else if (operation_name == std::string("DeviceID")) {
     std::vector<MaybeParallelTensorOwned> result_content;
     result_content.reserve(1);
-    result_content.push_back(DeviceIDs(context, status));
+    result_content.push_back(parallel_device.DeviceIDs(context, status));
     if (TF_GetCode(status) != TF_OK) return result;
     result.emplace(std::move(result_content));
     return result;
   }
   absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
       maybe_parallel_results(
-          ExecuteParallelOperation(context, std::move(inputs), operation_name,
-                                   attributes, expected_max_outputs, status));
+          parallel_device.Execute(context, std::move(inputs), operation_name,
+                                  attributes, expected_max_outputs, status));
   if (!maybe_parallel_results.has_value()) return result;
   std::vector<std::unique_ptr<ParallelTensor>> parallel_results(
@@ -351,153 +158,6 @@ absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
   return result;
 }
 
-absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
-ParallelDevice::ExecuteParallelOperation(
-    TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
-    const char* operation_name, const TFE_OpAttrs* attributes,
-    int expected_max_outputs, TF_Status* status) const {
-  absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> result;
-  // Compute per-device per-output tensors
-  std::vector<std::vector<TensorHandlePtr>> per_device_output_tensors;
-  per_device_output_tensors.reserve(underlying_devices_.size());
-  // TODO(allenl): Add a TFE_ExecuteWithExecutor API so we don't have to keep
-  // setting the thread-local executor like this.
-  TFE_Executor* previous_executor(TFE_ContextGetExecutorForThread(context));
-  auto reset_executor = gtl::MakeCleanup([context, previous_executor]() {
-    TFE_ContextSetExecutorForThread(context, previous_executor);
-    TFE_DeleteExecutor(previous_executor);
-  });
-  int first_op_output_count;
-  for (int device_index = 0; device_index < underlying_devices_.size();
-       ++device_index) {
-    TFE_Executor* executor = executors_[device_index].get();
-    // Note that the `reset_executor` cleanup sets the thread's executor back to
-    // the value before this function ran.
-    TFE_ContextSetExecutorForThread(context, executor);
-    OpPtr op(TFE_NewOp(context, operation_name, status));
-    if (TF_GetCode(status) != TF_OK) return result;
-    TFE_OpSetDevice(op.get(), underlying_devices_[device_index].c_str(),
-                    status);
-    TFE_OpAddAttrs(op.get(), attributes);
-    for (int input_index = 0; input_index < inputs.size(); ++input_index) {
-      if (absl::holds_alternative<TFE_TensorHandle*>(inputs[input_index])) {
-        // Non-parallel tensors are implicitly broadcast, i.e. set as the input
-        // to each parallel operation.
-        //
-        // TODO(allenl): There may be smarter ways to do this copy in some
-        // cases, i.e. with a collective broadcast. We'll need to be careful
-        // about things that are taken as inputs on the host or on their
-        // existing device (for multi-device functions).
-        TFE_OpAddInput(op.get(),
-                       absl::get<TFE_TensorHandle*>(inputs[input_index]),
-                       status);
-        if (TF_GetCode(status) != TF_OK) return result;
-      } else {
-        // Parallel tensors are divided between operations by device.
-        TFE_OpAddInput(op.get(),
-                       absl::get<ParallelTensor*>(inputs[input_index])
-                           ->tensor(device_index),
-                       status);
-        if (TF_GetCode(status) != TF_OK) return result;
-      }
-    }
-    std::vector<TFE_TensorHandle*> op_outputs(expected_max_outputs);
-    int real_num_outputs = expected_max_outputs;
-    // For nested devices, the inner device sees the async executor we've
-    // set. Inner parallel devices will just overwrite this with their own and
-    // then set it back to ours before returning. This means parallel devices
-    // which consist of several aliased parallel devices would hypothetically
-    // deadlock if the outer parallel device ran one collective with a group
-    // size equal to the total number of aliased physical devices. Currently
-    // physical devices cannot participate in a single collective reduction
-    // multiple times, so this would fail earlier.
-    //
-    // TODO(allenl): Keep a map from outer executor to list of inner executors
-    // rather than a single list of executors so aliased nested parallel devices
-    // don't re-use an executor.
-    TFE_Execute(op.get(), op_outputs.data(), &real_num_outputs, status);
-    if (device_index == 0) {
-      first_op_output_count = real_num_outputs;
-    } else {
-      if (real_num_outputs != first_op_output_count) {
-        TF_SetStatus(status, TF_INTERNAL,
-                     "Parallel ops produced different numbers of tensors.");
-        return result;
-      }
-    }
-    if (TF_GetCode(status) != TF_OK) return result;
-    std::vector<TensorHandlePtr> this_outputs;
-    this_outputs.reserve(real_num_outputs);
-    for (int output_num = 0; output_num < real_num_outputs; ++output_num) {
-      this_outputs.emplace_back(op_outputs[output_num]);
-    }
-    per_device_output_tensors.push_back(std::move(this_outputs));
-  }
-  for (int device_index = 0; device_index < underlying_devices_.size();
-       ++device_index) {
-    TFE_Executor* executor = executors_[device_index].get();
-    // TODO(b/157523095): Syncing the executor here shouldn't be
-    // necessary. Currently async+remote is missing cross-executor
-    // coordination.
-    TFE_ExecutorWaitForAllPendingNodes(executor, status);
-    if (TF_GetCode(status) != TF_OK) return result;
-  }
-  // For each output of the original operation, pack the per-device
-  // TensorHandles we've computed into a single parallel TensorHandle.
-  std::vector<std::unique_ptr<ParallelTensor>> per_device_outputs;
-  per_device_outputs.reserve(first_op_output_count);
-  for (int i = 0; i < first_op_output_count; ++i) {
-    std::vector<TensorHandlePtr> components;
-    components.reserve(underlying_devices_.size());
-    for (int j = 0; j < underlying_devices_.size(); ++j) {
-      components.push_back(std::move(per_device_output_tensors[j][i]));
-    }
-    per_device_outputs.push_back(ParallelTensor::FromTensorHandles(
-        *this, std::move(components), status));
-    if (TF_GetCode(status) != TF_OK) return result;
-  }
-  result.emplace(std::move(per_device_outputs));
-  return result;
-}
-
-std::unique_ptr<ParallelTensor> ParallelTensor::FromTensorHandles(
-    const ParallelDevice& parallel_device,
-    std::vector<TensorHandlePtr> components, TF_Status* status) {
-  TF_DataType dtype = TFE_TensorHandleDataType(components[0].get());
-  std::vector<int64_t> shape(
-      TFE_TensorHandleNumDims(components[0].get(), status));
-  if (TF_GetCode(status) != TF_OK) return nullptr;
-  for (int i = 0; i < shape.size(); ++i) {
-    shape[i] = TFE_TensorHandleDim(components[0].get(), i, status);
-    if (TF_GetCode(status) != TF_OK) return nullptr;
-  }
-
-  // Verify that the TensorHandle's shape and dtype match all of the component
-  // shapes and dtypes.
-  for (TensorHandlePtr& component : components) {
-    for (int i = 0; i < shape.size(); ++i) {
-      int64_t tensor_dim = TFE_TensorHandleDim(component.get(), i, status);
-      if (TF_GetCode(status) != TF_OK) return nullptr;
-      if (tensor_dim != shape[i]) {
-        // TODO(allenl): Allow shapes to differ.
-        TF_SetStatus(status, TF_UNIMPLEMENTED,
-                     "Components of a ParallelTensor must currently all have "
-                     "the same shape");
-        return nullptr;
-      }
-      if (TFE_TensorHandleDataType(component.get()) != dtype) {
-        TF_SetStatus(status, TF_INTERNAL,
-                     "Components of a ParallelTensor must all have "
-                     "the same dtype");
-        return nullptr;
-      }
-    }
-  }
-
-  return std::unique_ptr<ParallelTensor>(new ParallelTensor(
-      parallel_device, std::move(components), std::move(shape), dtype));
-}
-
 // Used as an argument to TFE_NewTensorHandleFromDeviceMemory, indicating how
 // ParallelTensors wrapped in TFE_TensorHandles should be cleaned up once their
 // reference counts drop to zero.
@@ -505,17 +165,18 @@ void ParallelTensorDeallocator(void* data, size_t len, void* arg) {
   delete reinterpret_cast<ParallelTensor*>(data);
 }
 
-TensorHandlePtr ParallelTensor::AsTensorHandle(
-    TFE_Context* context, std::unique_ptr<ParallelTensor> t,
-    TF_Status* status) {
+TensorHandlePtr ParallelTensorToTensorHandle(
+    const std::string& parallel_device_name, TFE_Context* context,
+    std::unique_ptr<ParallelTensor> t, TF_Status* status) {
   // The resulting TensorHandle owns an opaque pointer to "device memory", which
   // for a ParallelDevice is really a ParallelTensor. When the TensorHandle is
   // deleted, it will call ParallelTensorDeallocator to free the struct.
   ParallelTensor* t_released = t.release();
+  const std::vector<int64_t>& shape(t_released->shape());
   return TensorHandlePtr(TFE_NewTensorHandleFromDeviceMemory(
-      context, t_released->device_.device_name().c_str(), t_released->dtype_,
-      t_released->shape_.data(), t_released->shape_.size(), t_released, 1,
-      &ParallelTensorDeallocator, nullptr, status));
+      context, parallel_device_name.c_str(), t_released->dtype(), shape.data(),
+      shape.size(), t_released, 1, &ParallelTensorDeallocator, nullptr,
+      status));
 }
 
 // For TFE_CustomDevice::copy_tensor_to_device in the parallel device
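The ownership contract above, sketched in isolation (illustrative, not part of the diff; assumes a `NamedParallelDevice* named_device`, a live `TFE_Context* context` and `TF_Status* status`, and a `TFE_TensorHandle* tensor` placed on some other device):

    std::unique_ptr<ParallelTensor> packed(
        named_device->device().CopyToParallelDevice(context, tensor, status));
    if (TF_GetCode(status) == TF_OK) {
      // The ParallelTensor is released into the handle as opaque "device
      // memory"; ParallelTensorDeallocator deletes it when the handle's
      // reference count reaches zero, so there is no manual cleanup here.
      TensorHandlePtr handle = ParallelTensorToTensorHandle(
          named_device->name(), context, std::move(packed), status);
    }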
@@ -531,12 +192,14 @@ TensorHandlePtr ParallelTensor::AsTensorHandle(
 TFE_TensorHandle* CopyToParallelDevice(TFE_Context* context,
                                        TFE_TensorHandle* tensor,
                                        TF_Status* status, void* device_info) {
-  ParallelDevice* dev = reinterpret_cast<ParallelDevice*>(device_info);
+  NamedParallelDevice* named_device =
+      reinterpret_cast<NamedParallelDevice*>(device_info);
+  const ParallelDevice& dev = named_device->device();
   std::unique_ptr<ParallelTensor> parallel_tensor(
-      dev->CopyToParallelDevice(context, tensor, status));
+      dev.CopyToParallelDevice(context, tensor, status));
   if (TF_GetCode(status) != TF_OK) return nullptr;
-  return ParallelTensor::AsTensorHandle(context, std::move(parallel_tensor),
-                                        status)
+  return ParallelTensorToTensorHandle(named_device->name(), context,
+                                      std::move(parallel_tensor), status)
       .release();
 }
@@ -570,14 +233,15 @@ void ParallelDeviceExecute(TFE_Context* context, int num_inputs,
                            const TFE_OpAttrs* attributes, int* num_outputs,
                            TFE_TensorHandle** outputs, TF_Status* status,
                            void* device_info) {
-  ParallelDevice* dev = reinterpret_cast<ParallelDevice*>(device_info);
+  NamedParallelDevice* named_device =
+      reinterpret_cast<NamedParallelDevice*>(device_info);
   std::vector<MaybeParallelTensorUnowned> typed_inputs;
   typed_inputs.reserve(num_inputs);
   for (int i = 0; i < num_inputs; ++i) {
     const char* tensor_handle_device =
         TFE_TensorHandleDeviceName(inputs[i], status);
     if (TF_GetCode(status) != TF_OK) return;
-    if (dev->device_name() == tensor_handle_device) {
+    if (named_device->name() == tensor_handle_device) {
       // We assume that any tensors already placed on this device are
       // ParallelTensors.
       typed_inputs.emplace_back(reinterpret_cast<ParallelTensor*>(
@@ -589,8 +253,9 @@ void ParallelDeviceExecute(TFE_Context* context, int num_inputs,
   }
 
   absl::optional<std::vector<MaybeParallelTensorOwned>> maybe_typed_outputs(
-      dev->Execute(context, std::move(typed_inputs), operation_name, attributes,
-                   *num_outputs, status));
+      ExecuteWithSpecialOps(named_device->device(), named_device->name(),
+                            context, std::move(typed_inputs), operation_name,
+                            attributes, *num_outputs, status));
   if (TF_GetCode(status) != TF_OK) return;
   if (!maybe_typed_outputs.has_value()) {
     TF_SetStatus(status, TF_INTERNAL, "OK status but no value was returned.");
@@ -611,8 +276,8 @@ void ParallelDeviceExecute(TFE_Context* context, int num_inputs,
     if (absl::holds_alternative<TensorHandlePtr>(typed_output)) {
       outputs[i] = absl::get<TensorHandlePtr>(typed_output).release();
     } else {
-      outputs[i] = ParallelTensor::AsTensorHandle(
-          context,
+      outputs[i] = ParallelTensorToTensorHandle(
+          named_device->name(), context,
           std::move(absl::get<std::unique_ptr<ParallelTensor>>(
               typed_output)),
           status)
@@ -629,7 +294,7 @@ void ParallelDeviceExecute(TFE_Context* context, int num_inputs,
 // device_info is passed in using a C-style generic. It must always be a
 // ParallelDevice.
 void DeleteParallelDevice(void* device_info) {
-  delete reinterpret_cast<ParallelDevice*>(device_info);
+  delete reinterpret_cast<NamedParallelDevice*>(device_info);
 }
 
 }  // namespace
@@ -648,8 +313,10 @@ void AllocateParallelDevice(const char* device_name,
        ++device_index) {
     underlying_devices_vector.push_back(underlying_devices[device_index]);
   }
-  *device_info = new ParallelDevice(device_name, underlying_devices_vector);
+  std::unique_ptr<ParallelDevice> parallel_device(
+      new ParallelDevice(underlying_devices_vector));
+  *device_info =
+      new NamedParallelDevice{device_name, std::move(parallel_device)};
 }
 
-}  // namespace eager
+}  // namespace parallel_device
 }  // namespace tensorflow
tensorflow/c/eager/parallel_device/parallel_device.h

@@ -21,7 +21,7 @@ limitations under the License.
 #include "tensorflow/c/eager/c_api_experimental.h"
 
 namespace tensorflow {
-namespace eager {
+namespace parallel_device {
 
 // Allocate a parallel device named `device_name` which forwards operations to
 // `underlying_devices`, maintaining "parallel tensors" with components placed
@@ -59,7 +59,7 @@ void AllocateParallelDevice(const char* device_name,
                             int num_underlying_devices,
                             TFE_CustomDevice* device, void** device_info);
 
-}  // namespace eager
+}  // namespace parallel_device
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_H_
tensorflow/c/eager/parallel_device/parallel_device_lib.cc (new file)

@@ -0,0 +1,251 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"

#include "tensorflow/core/lib/gtl/cleanup.h"

namespace tensorflow {
namespace parallel_device {
namespace {

class OpDeleter {
 public:
  void operator()(TFE_Op* to_delete) const { TFE_DeleteOp(to_delete); }
};

using OpPtr = std::unique_ptr<TFE_Op, OpDeleter>;

// Creates a vector of `count` new executors (threads).
std::vector<ExecutorPtr> MakeExecutors(size_t count) {
  std::vector<ExecutorPtr> executors;
  executors.reserve(count);
  for (int i = 0; i < count; ++i) {
    executors.emplace_back(TFE_NewExecutor(true /* is_async */));
  }
  return executors;
}

}  // namespace

ParallelDevice::ParallelDevice(const std::vector<std::string>& devices)
    : underlying_devices_(devices),
      executors_(MakeExecutors(underlying_devices_.size())) {}

std::unique_ptr<ParallelTensor> ParallelDevice::CopyToParallelDevice(
    TFE_Context* context, TFE_TensorHandle* tensor, TF_Status* status) const {
  std::vector<TensorHandlePtr> components;
  components.reserve(underlying_devices_.size());
  for (const std::string& underlying_device_name : underlying_devices_) {
    TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
        tensor, context, underlying_device_name.c_str(), status);
    if (TF_GetCode(status) != TF_OK) return nullptr;
    components.emplace_back(t);
  }
  return ParallelTensor::FromTensorHandles(*this, std::move(components),
                                           status);
}

std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
    TFE_Context* context, TF_Status* status) const {
  // TODO(allenl): We could cache DeviceIDs (keyed by context).
  std::vector<TensorHandlePtr> components;
  components.reserve(underlying_devices_.size());
  for (int device_index = 0; device_index < underlying_devices_.size();
       ++device_index) {
    int64_t* device_id = new int64_t;
    *device_id = device_index;
    std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
        TF_NewTensor(
            TF_INT64, /*dims=*/nullptr, /*num_dims=*/0, device_id,
            sizeof(int64_t),
            [](void* data, size_t, void* arg) {
              delete reinterpret_cast<int64_t*>(data);
            },
            nullptr),
        TF_DeleteTensor);
    // TODO(allenl): Here and when executing regular operations, we could hold
    // on to one TFE_Op per device and just call TFE_ResetOp to avoid parsing
    // device names repeatedly.
    OpPtr const_op(TFE_NewOp(context, "Const", status));
    if (TF_GetCode(status) != TF_OK) return nullptr;
    TFE_OpSetDevice(const_op.get(), underlying_devices_[device_index].c_str(),
                    status);
    if (TF_GetCode(status) != TF_OK) return nullptr;
    TFE_OpSetAttrTensor(const_op.get(), "value", tensor.get(), status);
    if (TF_GetCode(status) != TF_OK) return nullptr;
    TFE_OpSetAttrType(const_op.get(), "dtype", TF_INT64);
    TFE_TensorHandle* device_handle;
    int num_outputs = 1;
    TFE_Execute(const_op.get(), &device_handle, &num_outputs, status);
    if (TF_GetCode(status) != TF_OK) return nullptr;
    components.emplace_back(device_handle);
    if (TF_GetCode(status) != TF_OK) return nullptr;
  }
  return ParallelTensor::FromTensorHandles(*this, std::move(components),
                                           status);
}

absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
ParallelDevice::Execute(TFE_Context* context,
                        std::vector<MaybeParallelTensorUnowned> inputs,
                        const char* operation_name,
                        const TFE_OpAttrs* attributes, int expected_max_outputs,
                        TF_Status* status) const {
  absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> result;
  // Compute per-device per-output tensors
  std::vector<std::vector<TensorHandlePtr>> per_device_output_tensors;
  per_device_output_tensors.reserve(underlying_devices_.size());
  // TODO(allenl): Add a TFE_ExecuteWithExecutor API so we don't have to keep
  // setting the thread-local executor like this.
  TFE_Executor* previous_executor(TFE_ContextGetExecutorForThread(context));
  auto reset_executor =
      tensorflow::gtl::MakeCleanup([context, previous_executor]() {
        TFE_ContextSetExecutorForThread(context, previous_executor);
        TFE_DeleteExecutor(previous_executor);
      });
  int first_op_output_count;
  for (int device_index = 0; device_index < underlying_devices_.size();
       ++device_index) {
    TFE_Executor* executor = executors_[device_index].get();
    // Note that the `reset_executor` cleanup sets the thread's executor back to
    // the value before this function ran.
    TFE_ContextSetExecutorForThread(context, executor);
    OpPtr op(TFE_NewOp(context, operation_name, status));
    if (TF_GetCode(status) != TF_OK) return result;
    TFE_OpSetDevice(op.get(), underlying_devices_[device_index].c_str(),
                    status);
    TFE_OpAddAttrs(op.get(), attributes);
    for (int input_index = 0; input_index < inputs.size(); ++input_index) {
      if (absl::holds_alternative<TFE_TensorHandle*>(inputs[input_index])) {
        // Non-parallel tensors are implicitly broadcast, i.e. set as the input
        // to each parallel operation.
        //
        // TODO(allenl): There may be smarter ways to do this copy in some
        // cases, i.e. with a collective broadcast. We'll need to be careful
        // about things that are taken as inputs on the host or on their
        // existing device (for multi-device functions).
        TFE_OpAddInput(op.get(),
                       absl::get<TFE_TensorHandle*>(inputs[input_index]),
                       status);
        if (TF_GetCode(status) != TF_OK) return result;
      } else {
        // Parallel tensors are divided between operations by device.
        TFE_OpAddInput(op.get(),
                       absl::get<ParallelTensor*>(inputs[input_index])
                           ->tensor(device_index),
                       status);
        if (TF_GetCode(status) != TF_OK) return result;
      }
    }
    std::vector<TFE_TensorHandle*> op_outputs(expected_max_outputs);
    int real_num_outputs = expected_max_outputs;
    // For nested devices, the inner device sees the async executor we've
    // set. Inner parallel devices will just overwrite this with their own and
    // then set it back to ours before returning. This means parallel devices
    // which consist of several aliased parallel devices would hypothetically
    // deadlock if the outer parallel device ran one collective with a group
    // size equal to the total number of aliased physical devices. Currently
    // physical devices cannot participate in a single collective reduction
    // multiple times, so this would fail earlier.
    //
    // TODO(allenl): Keep a map from outer executor to list of inner executors
    // rather than a single list of executors so aliased nested parallel devices
    // don't re-use an executor.
    TFE_Execute(op.get(), op_outputs.data(), &real_num_outputs, status);
    if (device_index == 0) {
      first_op_output_count = real_num_outputs;
    } else {
      if (real_num_outputs != first_op_output_count) {
        TF_SetStatus(status, TF_INTERNAL,
                     "Parallel ops produced different numbers of tensors.");
        return result;
      }
    }
    if (TF_GetCode(status) != TF_OK) return result;
    std::vector<TensorHandlePtr> this_outputs;
    this_outputs.reserve(real_num_outputs);
    for (int output_num = 0; output_num < real_num_outputs; ++output_num) {
      this_outputs.emplace_back(op_outputs[output_num]);
    }
    per_device_output_tensors.push_back(std::move(this_outputs));
  }
  for (int device_index = 0; device_index < underlying_devices_.size();
       ++device_index) {
    TFE_Executor* executor = executors_[device_index].get();
    // TODO(b/157523095): Syncing the executor here shouldn't be
    // necessary. Currently async+remote is missing cross-executor
    // coordination.
    TFE_ExecutorWaitForAllPendingNodes(executor, status);
    if (TF_GetCode(status) != TF_OK) return result;
  }
  // For each output of the original operation, pack the per-device
  // TensorHandles we've computed into a single parallel TensorHandle.
  std::vector<std::unique_ptr<ParallelTensor>> per_device_outputs;
  per_device_outputs.reserve(first_op_output_count);
  for (int i = 0; i < first_op_output_count; ++i) {
    std::vector<TensorHandlePtr> components;
    components.reserve(underlying_devices_.size());
    for (int j = 0; j < underlying_devices_.size(); ++j) {
      components.push_back(std::move(per_device_output_tensors[j][i]));
    }
    per_device_outputs.push_back(ParallelTensor::FromTensorHandles(
        *this, std::move(components), status));
    if (TF_GetCode(status) != TF_OK) return result;
  }
  result.emplace(std::move(per_device_outputs));
  return result;
}

std::unique_ptr<ParallelTensor> ParallelTensor::FromTensorHandles(
    const ParallelDevice& parallel_device,
    std::vector<TensorHandlePtr> components, TF_Status* status) {
  TF_DataType dtype = TFE_TensorHandleDataType(components[0].get());
  std::vector<int64_t> shape(
      TFE_TensorHandleNumDims(components[0].get(), status));
  if (TF_GetCode(status) != TF_OK) return nullptr;
  for (int i = 0; i < shape.size(); ++i) {
    shape[i] = TFE_TensorHandleDim(components[0].get(), i, status);
    if (TF_GetCode(status) != TF_OK) return nullptr;
  }

  // Verify that the TensorHandle's shape and dtype match all of the component
  // shapes and dtypes.
  for (TensorHandlePtr& component : components) {
    for (int i = 0; i < shape.size(); ++i) {
      int64_t tensor_dim = TFE_TensorHandleDim(component.get(), i, status);
      if (TF_GetCode(status) != TF_OK) return nullptr;
      if (tensor_dim != shape[i]) {
        // TODO(allenl): Allow shapes to differ.
        TF_SetStatus(status, TF_UNIMPLEMENTED,
                     "Components of a ParallelTensor must currently all have "
                     "the same shape");
        return nullptr;
      }
      if (TFE_TensorHandleDataType(component.get()) != dtype) {
        TF_SetStatus(status, TF_INTERNAL,
                     "Components of a ParallelTensor must all have "
                     "the same dtype");
        return nullptr;
      }
    }
  }

  return std::unique_ptr<ParallelTensor>(new ParallelTensor(
      parallel_device, std::move(components), std::move(shape), dtype));
}

}  // namespace parallel_device
}  // namespace tensorflow
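The executor-swap pattern in `ParallelDevice::Execute` above, reduced to its skeleton (a sketch, not library code; `per_device_executor` stands in for one of the long-lived `executors_` members):

    void SketchExecutorSwap(TFE_Context* context,
                            TFE_Executor* per_device_executor) {
      // Save the thread's current executor and restore it on every return
      // path, exactly as the cleanup in Execute does.
      TFE_Executor* previous = TFE_ContextGetExecutorForThread(context);
      auto reset = tensorflow::gtl::MakeCleanup([context, previous]() {
        TFE_ContextSetExecutorForThread(context, previous);
        TFE_DeleteExecutor(previous);
      });
      // Ops built from here on are enqueued on the per-device async executor,
      // so each component device's work can proceed in parallel.
      TFE_ContextSetExecutorForThread(context, per_device_executor);
      // ... TFE_NewOp / TFE_Execute as in the loop body above ...
    }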
tensorflow/c/eager/parallel_device/parallel_device_lib.h (new file)

@@ -0,0 +1,141 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_LIB_H_
#define TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_LIB_H_

#include <memory>
#include <string>
#include <vector>

#include "absl/types/optional.h"
#include "absl/types/variant.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"

namespace tensorflow {
namespace parallel_device {

// Functor for making unique_ptrs slightly more ergonomic. Using
// decltype(delete_fn) in the unique_ptr's second template argument requires
// passing a function pointer to delete_fn when constructing the unique_ptr.
class TensorHandleDeleter {
 public:
  void operator()(TFE_TensorHandle* to_delete) const {
    TFE_DeleteTensorHandle(to_delete);
  }
};

using TensorHandlePtr = std::unique_ptr<TFE_TensorHandle, TensorHandleDeleter>;

class ExecutorDeleter {
 public:
  void operator()(TFE_Executor* to_delete) const {
    TFE_DeleteExecutor(to_delete);
  }
};

using ExecutorPtr = std::unique_ptr<TFE_Executor, ExecutorDeleter>;

class ParallelTensor;

using MaybeParallelTensorUnowned =
    absl::variant<ParallelTensor*, TFE_TensorHandle*>;

// Forwards operations to `devices`, maintaining ParallelTensor with components
// placed on each underlying device.
class ParallelDevice {
 public:
  explicit ParallelDevice(const std::vector<std::string>& devices);

  // Helper to copy a tensor handle from another device once for each component
  // of the ParallelDevice.
  //
  // Sets a bad status and returns a nullptr if `tensor` is already on the
  // ParallelDevice, or if the individual copies fail.
  std::unique_ptr<ParallelTensor> CopyToParallelDevice(TFE_Context* context,
                                                       TFE_TensorHandle* tensor,
                                                       TF_Status* status) const;

  // A parallel tensor with scalar integers numbering component devices.
  std::unique_ptr<ParallelTensor> DeviceIDs(TFE_Context* context,
                                            TF_Status* status) const;

  // The number of devices operations run on.
  size_t num_underlying_devices() const { return underlying_devices_.size(); }

  // Takes a description of a single operation being executed on the
  // ParallelDevice, and in turn runs one operation per component device with
  // its corresponding inputs from the input ParallelTensors (or
  // implicitly-mirrored tensors on other devices). Wraps the resulting
  // per-device and per-output TFE_TensorHandles into one ParallelTensor per
  // output of the original operation.
  //
  // Attributes are forwarded to executed operations unmodified.
  //
  // The returned optional has a value if and only if `status` evaluates to
  // TF_OK. Bad statuses are forwarded from underlying `TFE_Execute` calls, or
  // if sanity checks on dtypes/metadata fail.
  absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> Execute(
      TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
      const char* operation_name, const TFE_OpAttrs* attributes,
      int expected_max_outputs, TF_Status* status) const;

 private:
  // A sequence of device names, indicating which devices replicated operations
  // are forwarded to.
  const std::vector<std::string> underlying_devices_;
  // A sequence of TFE_Executors, one per device, for executing operations in
  // parallel.
  const std::vector<ExecutorPtr> executors_;
};

// Contains a tuple of tensors, one on each of the `underlying_devices_` of the
// ParallelDevice.
class ParallelTensor {
 public:
  // Construct a ParallelTensor from TensorHandles placed on the component
  // devices of a ParallelDevice.
  static std::unique_ptr<ParallelTensor> FromTensorHandles(
      const ParallelDevice& parallel_device,
      std::vector<TensorHandlePtr> components, TF_Status* status);

  size_t num_tensors() const { return tensors_.size(); }
  TFE_TensorHandle* tensor(size_t index) const { return tensors_[index].get(); }

  // A generalization of the shapes of the underlying tensors.
  const std::vector<int64_t>& shape() const { return shape_; }
  TF_DataType dtype() const { return dtype_; }

 private:
  ParallelTensor(const ParallelDevice& device,
                 std::vector<TensorHandlePtr> tensors,
                 std::vector<int64_t> shape, const TF_DataType dtype)
      : device_(device),
        tensors_(std::move(tensors)),
        shape_(std::move(shape)),
        dtype_(dtype) {}

  const ParallelDevice& device_;
  const std::vector<TensorHandlePtr> tensors_;
  const std::vector<int64_t> shape_;
  const TF_DataType dtype_;
};

}  // namespace parallel_device
}  // namespace tensorflow

#endif  // TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_LIB_H_
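With the interface above, the parallel-execution machinery is usable from C++ without registering a custom device, which is the point of this split. A hedged sketch (assumes a live `TFE_Context* context` and `TF_Status* status`; the device names are illustrative):

    using tensorflow::parallel_device::ParallelDevice;
    using tensorflow::parallel_device::ParallelTensor;

    ParallelDevice devices({"/job:localhost/replica:0/task:0/device:CPU:0",
                            "/job:localhost/replica:0/task:0/device:CPU:1"});
    // One scalar int64 per component device, numbering the devices.
    std::unique_ptr<ParallelTensor> ids = devices.DeviceIDs(context, status);
    if (TF_GetCode(status) == TF_OK) {
      ids->num_tensors();  // == devices.num_underlying_devices()
      ids->dtype();        // TF_INT64
      ids->tensor(0);      // TFE_TensorHandle* placed on the first CPU
    }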
@@ -165,7 +165,7 @@ void RegisterParallelDevice(
     TF_Status* status) {
   TFE_CustomDevice device;
   void* device_info;
-  tensorflow::eager::AllocateParallelDevice(
+  tensorflow::parallel_device::AllocateParallelDevice(
       device_name, underlying_devices.data(), underlying_devices.size(),
       &device, &device_info);
   TFE_RegisterCustomDevice(context, device, device_name, device_info, status);
@@ -52,7 +52,7 @@ PYBIND11_MODULE(_pywrap_parallel_device, m) {
   tensorflow::Safe_PyObjectPtr device_capsule(
       PyCapsule_New(device, "TFE_CustomDevice", &CallDelete_Device));
   void* device_info;
-  tensorflow::eager::AllocateParallelDevice(
+  tensorflow::parallel_device::AllocateParallelDevice(
       name, underlying_devices_c.data(), underlying_devices_c.size(),
       device, &device_info);
   if (PyErr_Occurred()) throw py::error_already_set();