Introduce PackedTensorHandleData to TensorHandle. A PackedTensorHandleData refers to a list of TensorHandles of the same dtype and shape.
PiperOrigin-RevId: 308702161 Change-Id: Ide047f4cde1c17e7be9e0d64f78f499a022a430e
This commit is contained in:
parent
ea74b01d80
commit
4430ba27bb
@ -1025,10 +1025,10 @@ void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) {
|
||||
return t->data();
|
||||
}
|
||||
|
||||
if (handle->IsRemote()) {
|
||||
if (handle->Type() != tensorflow::TensorHandle::LOCAL) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"TFE_TensorHandleDevicePointer may not be called on a remote tensor "
|
||||
"handle.");
|
||||
"TFE_TensorHandleDevicePointer may not be called on a ",
|
||||
handle->TypeString(), " tensor handle.");
|
||||
return nullptr;
|
||||
}
|
||||
tensorflow::Device* device(absl::get<tensorflow::Device*>(handle->device()));
|
||||
@ -1099,10 +1099,10 @@ size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
|
||||
}
|
||||
tensorflow::TensorHandle* handle =
|
||||
tensorflow::TensorHandleFromInterface(h->handle);
|
||||
if (handle->IsRemote()) {
|
||||
if (handle->Type() != tensorflow::TensorHandle::LOCAL) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"TFE_TensorHandleDeviceMemorySize may not be called on a remote tensor "
|
||||
"handle.");
|
||||
"TFE_TensorHandleDeviceMemorySize may not be called on a ",
|
||||
handle->TypeString(), " tensor handle.");
|
||||
return 0;
|
||||
}
|
||||
const tensorflow::Tensor* tensor;
|
||||
|
@ -47,9 +47,9 @@ const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) {
|
||||
}
|
||||
tensorflow::TensorHandle* handle =
|
||||
tensorflow::TensorHandleFromInterface(h->handle);
|
||||
if (handle->IsRemote()) {
|
||||
if (handle->Type() != TensorHandle::LOCAL) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"DLPack doesn't support remote tensor");
|
||||
"DLPack doesn't support ", handle->TypeString(), " tensor");
|
||||
return nullptr;
|
||||
}
|
||||
const tensorflow::Tensor* tensor;
|
||||
|
@ -30,6 +30,14 @@ std::unique_ptr<CompositeDevice> CompositeDevice::MakeDevice(
|
||||
errors::InvalidArgument("underlying_devices should not be empty."));
|
||||
return nullptr;
|
||||
}
|
||||
std::set<string> unique_devices;
|
||||
for (const string& device : underlying_devices) {
|
||||
if (!unique_devices.insert(device).second) {
|
||||
status->Update(errors::InvalidArgument(
|
||||
"Got a duplicated device in underlying_devices: ", device));
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
DeviceNameUtils::ParsedName parsed_name;
|
||||
if (!DeviceNameUtils::ParseFullName(underlying_devices.at(0), &parsed_name)) {
|
||||
status->Update(tensorflow::errors::InvalidArgument(
|
||||
|
@ -47,6 +47,21 @@ TEST(CompositeDeviceTest, Basic) {
|
||||
EXPECT_EQ(underlying_devices, *composite_device->underlying_devices());
|
||||
}
|
||||
|
||||
{
|
||||
Status status;
|
||||
underlying_devices.push_back(
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0");
|
||||
std::unique_ptr<CompositeDevice> composite_device =
|
||||
CompositeDevice::MakeDevice(underlying_devices, /*unique_device_id=*/1,
|
||||
&status);
|
||||
EXPECT_EQ(composite_device, nullptr);
|
||||
EXPECT_EQ(error::INVALID_ARGUMENT, status.code());
|
||||
EXPECT_TRUE(
|
||||
absl::StrContains(status.error_message(), "Got a duplicated device"))
|
||||
<< status.ToString();
|
||||
underlying_devices.pop_back();
|
||||
}
|
||||
|
||||
{
|
||||
Status status;
|
||||
underlying_devices.push_back(
|
||||
|
@ -72,6 +72,7 @@ tf_cuda_library(
|
||||
deps = [
|
||||
":eager_executor",
|
||||
":kernel_and_device",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"//tensorflow/c:tf_tensor_internal",
|
||||
"//tensorflow/c/eager:context_interface",
|
||||
"//tensorflow/c/eager:tensor_handle_interface",
|
||||
|
@ -883,6 +883,29 @@ Status EagerContext::RegisterCustomDevice(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status EagerContext::FindOrCreateCompositeDevice(
|
||||
const std::vector<string>& underlying_devices,
|
||||
CompositeDevice** composite_device) {
|
||||
const uint64 hash_key = Fingerprint64(absl::StrJoin(underlying_devices, ","));
|
||||
|
||||
mutex_lock l(composite_devices_mu_);
|
||||
auto iter = composite_devices_.find(hash_key);
|
||||
if (iter != composite_devices_.end()) {
|
||||
*composite_device = iter->second.get();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status s;
|
||||
auto device = CompositeDevice::MakeDevice(underlying_devices,
|
||||
composite_devices_.size(), &s);
|
||||
TF_RETURN_IF_ERROR(s);
|
||||
*composite_device = device.get();
|
||||
// TODO(b/145922293): Add the composite device to the device set of pflr in
|
||||
// order to make placer recognize it.
|
||||
composite_devices_.emplace(hash_key, std::move(device));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool EagerContext::OnSameTask(const Device* first, const Device* second) const {
|
||||
if (first == nullptr) first = HostCPU();
|
||||
if (second == nullptr) second = HostCPU();
|
||||
|
@ -32,8 +32,10 @@ limitations under the License.
|
||||
// clang-format on
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "tensorflow/c/eager/context_interface.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
|
||||
#include "tensorflow/core/common_runtime/composite_device.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
|
||||
@ -487,6 +489,11 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted {
|
||||
Status RegisterCustomDevice(const string& name,
|
||||
std::unique_ptr<CustomDevice> device);
|
||||
|
||||
// Find or create a composite device with the given `underlying_devices`.
|
||||
Status FindOrCreateCompositeDevice(
|
||||
const std::vector<string>& underlying_devices,
|
||||
CompositeDevice** composite_device);
|
||||
|
||||
bool OnSameTask(const Device* first, const Device* second) const;
|
||||
// Gets the CPU device on the task of device.
|
||||
Status CPUDeviceOnTask(const Device* device, Device** cpu_device) const;
|
||||
@ -569,6 +576,13 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted {
|
||||
std::function<Rendezvous*(const int64)> rendezvous_creator_;
|
||||
std::unordered_map<string, std::unique_ptr<CustomDevice>> custom_devices_;
|
||||
|
||||
mutable mutex composite_devices_mu_;
|
||||
// Maps from the fingerprint of a set of device names to a virtual
|
||||
// CompositeDevice.
|
||||
// TODO(b/145922293): Consider taking device names as keys.
|
||||
absl::flat_hash_map<uint64, std::unique_ptr<CompositeDevice>>
|
||||
composite_devices_ GUARDED_BY(composite_devices_mu_);
|
||||
|
||||
FunctionLibraryDefinition func_lib_def_{OpRegistry::Global(), {}};
|
||||
|
||||
std::unique_ptr<thread::ThreadPool> thread_pool_;
|
||||
|
@ -170,5 +170,27 @@ TEST_F(EagerContextTest, SelectDeviceExplicitSoftPlacement) {
|
||||
EXPECT_EQ(dev->device_type(), DEVICE_CPU);
|
||||
}
|
||||
|
||||
TEST_F(EagerContextTest, CompositeDevice) {
|
||||
InitContext(SessionOptions(), DEVICE_PLACEMENT_EXPLICIT);
|
||||
std::vector<string> underlying_devices = {
|
||||
"/job:worker/replica:0/task:0/device:CPU:0",
|
||||
"/job:worker/replica:0/task:0/device:CPU:1"};
|
||||
CompositeDevice* composite_device_0 = nullptr;
|
||||
TF_ASSERT_OK(context()->FindOrCreateCompositeDevice(underlying_devices,
|
||||
&composite_device_0));
|
||||
EXPECT_EQ(composite_device_0->name(),
|
||||
"/job:worker/replica:0/task:0/device:COMPOSITE:0");
|
||||
CompositeDevice* composite_device_1 = nullptr;
|
||||
TF_ASSERT_OK(context()->FindOrCreateCompositeDevice(underlying_devices,
|
||||
&composite_device_1));
|
||||
EXPECT_EQ(composite_device_1, composite_device_0);
|
||||
underlying_devices.push_back("/job:worker/replica:0/task:0/device:CPU:2");
|
||||
CompositeDevice* composite_device_2 = nullptr;
|
||||
TF_ASSERT_OK(context()->FindOrCreateCompositeDevice(underlying_devices,
|
||||
&composite_device_2));
|
||||
EXPECT_EQ(composite_device_2->name(),
|
||||
"/job:worker/replica:0/task:0/device:COMPOSITE:1");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -48,7 +48,7 @@ AbstractTensorInterface* TensorHandle::Resolve(Status* status) {
|
||||
}
|
||||
}
|
||||
|
||||
if (IsRemote()) {
|
||||
if (Type() == REMOTE) {
|
||||
const tensorflow::Tensor* t = nullptr;
|
||||
TensorHandle* h_cpu = nullptr;
|
||||
*status = EagerCopyToDevice(this, ctx_, &ctx_->Executor(), ctx_->HostCPU(),
|
||||
@ -68,7 +68,7 @@ AbstractTensorInterface* TensorHandle::Resolve(Status* status) {
|
||||
h_cpu->Unref();
|
||||
delete tf_tensor;
|
||||
return retval;
|
||||
} else {
|
||||
} else if (Type() == LOCAL) {
|
||||
tensorflow::Tensor tensor;
|
||||
if (IsCPU(device()) || HasLocalMirror(nullptr)) {
|
||||
const tensorflow::Tensor* src = nullptr;
|
||||
@ -95,6 +95,10 @@ AbstractTensorInterface* TensorHandle::Resolve(Status* status) {
|
||||
AbstractTensorInterface* retval = tf_tensor->tensor;
|
||||
delete tf_tensor;
|
||||
return retval;
|
||||
} else {
|
||||
*status = errors::InvalidArgument(
|
||||
"Resolve() is not supoorted on packed TensorHandles.");
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -200,7 +200,8 @@ Status ValidateInputTypeAndPlacement(
|
||||
"together.");
|
||||
}
|
||||
Device* handle_device = absl::get<Device*>(handle_device_variant);
|
||||
const bool maybe_copy = !skip_remote_copy || !handle->IsRemote();
|
||||
const bool maybe_copy =
|
||||
!skip_remote_copy || handle->Type() != TensorHandle::REMOTE;
|
||||
// If the input is already on the right device, then nothing to do.
|
||||
if (expected_device != handle_device && maybe_copy) {
|
||||
TF_RETURN_IF_ERROR(CopyInputToExpectedDevice(ctx, op, kernel->device(),
|
||||
@ -258,7 +259,7 @@ Status GetDeviceForInput(const EagerContext& ctx, TensorHandle* tensor_handle,
|
||||
}
|
||||
Device* cpu_device = ctx.HostCPU();
|
||||
string device_name;
|
||||
if (tensor_handle->IsRemote()) {
|
||||
if (tensor_handle->Type() != TensorHandle::LOCAL) {
|
||||
Device* device = absl::get<Device*>(tensor_handle->device());
|
||||
device_name = device != nullptr ? device->name() : cpu_device->name();
|
||||
*result = (device == nullptr ? cpu_device : device);
|
||||
@ -392,7 +393,8 @@ Status GetOrCreateKernelAndDevice(
|
||||
// which doesn't accept remote inputs.
|
||||
for (int i = 0; i < op->Inputs().size(); i++) {
|
||||
TensorHandle* input = op->Inputs()[i];
|
||||
if (!ctx.LazyCopyFunctionRemoteInputs() && input->IsRemote()) {
|
||||
if (!ctx.LazyCopyFunctionRemoteInputs() &&
|
||||
input->Type() == TensorHandle::REMOTE) {
|
||||
TensorHandle* handle = nullptr;
|
||||
TF_RETURN_IF_ERROR(
|
||||
EagerCopyToDevice(input, &ctx, &op->Executor(),
|
||||
|
@ -36,7 +36,8 @@ Status ExecuteNodeArgs::Init(
|
||||
if (!s.ok()) {
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
uint64 context_view_id = ctx->GetContextViewId();
|
||||
if (in->IsRemote() || in->HasRemoteMirror(d, context_view_id)) {
|
||||
if (in->Type() == TensorHandle::REMOTE ||
|
||||
in->HasRemoteMirror(d, context_view_id)) {
|
||||
if (!has_remote_inputs_) {
|
||||
has_remote_inputs_ = true;
|
||||
}
|
||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
||||
|
||||
#include "absl/types/variant.h"
|
||||
#include "tensorflow/c/tf_tensor_internal.h"
|
||||
#include "tensorflow/core/common_runtime/composite_device.h"
|
||||
#include "tensorflow/core/common_runtime/copy_tensor.h"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
|
||||
@ -47,6 +48,91 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
TensorHandle::PackedTensorHandleData::PackedTensorHandleData(
|
||||
std::vector<TensorHandle*>&& handles, const TensorShape& shape)
|
||||
: handles_(std::move(handles)), shape_(shape) {
|
||||
for (auto* handle : handles_) {
|
||||
handle->Ref();
|
||||
}
|
||||
}
|
||||
|
||||
TensorHandle::PackedTensorHandleData::~PackedTensorHandleData() {
|
||||
for (auto* handle : handles_) {
|
||||
handle->Unref();
|
||||
}
|
||||
}
|
||||
|
||||
Status TensorHandle::PackedTensorHandleData::Shape(TensorShape* shape) const {
|
||||
*shape = shape_;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TensorHandle::PackedTensorHandleData::NumDims(int* num_dims) const {
|
||||
*num_dims = shape_.dims();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TensorHandle::PackedTensorHandleData::Dim(int dim_index,
|
||||
int64* dim) const {
|
||||
*dim = shape_.dim_size(dim_index);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TensorHandle::PackedTensorHandleData::NumElements(
|
||||
int64* num_elements) const {
|
||||
*num_elements = shape_.num_elements();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TensorHandle::PackedTensorHandleData::Unprotect() {
|
||||
for (auto* handle : handles_) {
|
||||
TF_RETURN_IF_ERROR(absl::visit([](auto& data) { return data.Unprotect(); },
|
||||
handle->data_));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool TensorHandle::PackedTensorHandleData::IsReady() const {
|
||||
{
|
||||
tf_shared_lock l(mu_);
|
||||
if (!is_poisoned_.ok()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
for (auto* handle : handles_) {
|
||||
if (!handle->IsReady()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void TensorHandle::PackedTensorHandleData::Poison(Status status) {
|
||||
mutex_lock l(mu_);
|
||||
is_poisoned_ = status;
|
||||
}
|
||||
|
||||
string TensorHandle::PackedTensorHandleData::DebugString() const {
|
||||
string debug_str = "PackedTensorHandleData: ";
|
||||
for (const auto* handle : handles_) {
|
||||
debug_str.append(
|
||||
absl::StrCat(absl::visit([](auto& data) { return data.DebugString(); },
|
||||
handle->data_),
|
||||
"; "));
|
||||
}
|
||||
return debug_str;
|
||||
}
|
||||
|
||||
Status TensorHandle::PackedTensorHandleData::ExtractPackedHandle(
|
||||
const int index, TensorHandle** handle) const {
|
||||
if (index < 0 || index >= handles_.size()) {
|
||||
return errors::InvalidArgument("Expect an index within [0, ",
|
||||
handles_.size(), "), but got ", index);
|
||||
}
|
||||
*handle = handles_.at(index);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void TensorHandle::SetResourceHandleInfo(
|
||||
ResourceHandleInfo&& resource_handle_info) {
|
||||
resource_handle_info_ = std::move(resource_handle_info);
|
||||
@ -61,7 +147,7 @@ Status TensorHandle::GetResourceHandleInfoImpl(
|
||||
dtype);
|
||||
}
|
||||
|
||||
if (IsRemote()) {
|
||||
if (Type() != LOCAL) {
|
||||
set_resource_info();
|
||||
return Status::OK();
|
||||
}
|
||||
@ -98,6 +184,16 @@ Status TensorHandle::GetResourceAllowedDevices(std::vector<string>* result) {
|
||||
return GetResourceHandleInfoImpl(get_resource_info);
|
||||
}
|
||||
|
||||
Status TensorHandle::ExtractPackedHandle(const int index,
|
||||
TensorHandle** handle) const {
|
||||
if (Type() != PACKED) {
|
||||
return errors::Internal("Invalid ExtractPackedHandleOnDevice call on a",
|
||||
TypeString(), " handle: ", this);
|
||||
}
|
||||
return absl::get<PackedTensorHandleData>(data_).ExtractPackedHandle(index,
|
||||
handle);
|
||||
}
|
||||
|
||||
TensorHandle* TensorHandle::CreateLocalHandle(const tensorflow::Tensor& t) {
|
||||
// TODO(b/136608821): Move away from nullptr
|
||||
tensorflow::Tensor tensor = t;
|
||||
@ -195,6 +291,64 @@ TensorHandle::TensorHandle(Device* d, Device* op_device,
|
||||
<< " device: " << VariantDeviceDebugString(device_);
|
||||
}
|
||||
|
||||
Status TensorHandle::CreatePackedHandle(std::vector<TensorHandle*>&& handles,
|
||||
EagerContext* ctx,
|
||||
TensorHandle** packed_handle) {
|
||||
if (handles.empty()) {
|
||||
return errors::InvalidArgument("Handles should not be empty.");
|
||||
}
|
||||
|
||||
// Get the dtype and shape from the fisrt handle since all handles have the
|
||||
// same dtype and shape.
|
||||
tensorflow::DataType dtype = handles.at(0)->dtype;
|
||||
tensorflow::TensorShape shape;
|
||||
TF_RETURN_IF_ERROR(handles.at(0)->Shape(&shape));
|
||||
ResourceHandleInfo resource_handle_info;
|
||||
if (dtype == DT_RESOURCE) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
handles.at(0)->GetResourceHandleInfo(&resource_handle_info));
|
||||
}
|
||||
std::vector<string> devices;
|
||||
for (auto* handle : handles) {
|
||||
if (VariantDeviceIsCustom(handle->device())) {
|
||||
return errors::InvalidArgument(
|
||||
"CustomDevice is not supported for packing.");
|
||||
} else {
|
||||
devices.push_back(
|
||||
absl::get<Device*>(handle->DeviceOrHostCPU(*ctx))->name());
|
||||
}
|
||||
}
|
||||
|
||||
Device* device;
|
||||
if (devices.size() == 1) {
|
||||
device = absl::get<Device*>(handles.at(0)->DeviceOrHostCPU(*ctx));
|
||||
} else {
|
||||
CompositeDevice* composite_device = nullptr;
|
||||
TF_RETURN_IF_ERROR(
|
||||
ctx->FindOrCreateCompositeDevice(devices, &composite_device));
|
||||
device = composite_device;
|
||||
}
|
||||
*packed_handle =
|
||||
new TensorHandle(std::move(handles), device, dtype, shape, ctx);
|
||||
(*packed_handle)->SetResourceHandleInfo(std::move(resource_handle_info));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
TensorHandle::TensorHandle(std::vector<TensorHandle*>&& handles, Device* device,
|
||||
const tensorflow::DataType dtype,
|
||||
const tensorflow::TensorShape& shape,
|
||||
EagerContext* ctx)
|
||||
: dtype(dtype),
|
||||
device_(device),
|
||||
op_device_(device),
|
||||
resource_device_(dtype == DT_RESOURCE ? device : nullptr),
|
||||
ctx_(ctx),
|
||||
data_(absl::in_place_type<PackedTensorHandleData>, std::move(handles),
|
||||
shape) {
|
||||
DVLOG(3) << "Creating a packed TensorHandle: " << this
|
||||
<< " device: " << VariantDeviceDebugString(device_);
|
||||
}
|
||||
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
TensorHandle* TensorHandle::CreateUnshapedRemoteHandle(
|
||||
int64 op_id, int32 output_num, const string& remote_task,
|
||||
@ -253,19 +407,32 @@ bool TensorHandle::IsReady() const {
|
||||
return absl::visit([](auto& data) { return data.IsReady(); }, data_);
|
||||
}
|
||||
|
||||
bool TensorHandle::IsRemote() const {
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
return data_.index() == 1;
|
||||
#else
|
||||
return false;
|
||||
#endif
|
||||
TensorHandle::HandleType TensorHandle::Type() const {
|
||||
if (data_.index() == 0) {
|
||||
return LOCAL;
|
||||
} else if (data_.index() == 1) {
|
||||
return PACKED;
|
||||
} else {
|
||||
return REMOTE;
|
||||
}
|
||||
}
|
||||
|
||||
string TensorHandle::TypeString() const {
|
||||
if (data_.index() == 0) {
|
||||
return "LOCAL";
|
||||
} else if (data_.index() == 1) {
|
||||
return "PACKED";
|
||||
} else {
|
||||
return "REMOTE";
|
||||
}
|
||||
}
|
||||
|
||||
Status TensorHandle::Tensor(const tensorflow::Tensor** t) const {
|
||||
DVLOG(3) << "Tensor on TensorHandle: " << this;
|
||||
|
||||
if (IsRemote()) {
|
||||
return errors::Internal("Invalid Tensor call on remote handle: ", this);
|
||||
if (Type() != LOCAL) {
|
||||
return errors::Internal("Invalid Tensor call on a ", TypeString(),
|
||||
" handle: ", this);
|
||||
}
|
||||
|
||||
auto& data = absl::get<LocalTensorHandleData>(data_);
|
||||
@ -277,8 +444,9 @@ Status TensorHandle::TensorFromDevice(const Device* d,
|
||||
DVLOG(3) << "TensorFromDevice on TensorHandle: " << this << " device: " << d;
|
||||
|
||||
if (d == absl::get<Device*>(device_)) {
|
||||
if (IsRemote()) {
|
||||
return errors::Internal("Invalid Tensor call on remote handle: ", this);
|
||||
if (Type() != LOCAL) {
|
||||
return errors::Internal("Invalid Tensor call on a ", TypeString(),
|
||||
" handle: ", this);
|
||||
}
|
||||
|
||||
auto& data = absl::get<LocalTensorHandleData>(data_);
|
||||
@ -306,9 +474,9 @@ Status TensorHandle::TensorValue(const Device* d, tensorflow::TensorValue* t) {
|
||||
VariantDeviceDebugString(device_),
|
||||
", requested device: ", d != nullptr ? d->name() : "(nil)");
|
||||
} else if (d == absl::get<Device*>(device_)) {
|
||||
if (IsRemote()) {
|
||||
return errors::Internal("Invalid TensorValue call on remote handle: ",
|
||||
this);
|
||||
if (Type() != LOCAL) {
|
||||
return errors::Internal("Invalid TensorValue call on a ", TypeString(),
|
||||
" handle: ", this);
|
||||
}
|
||||
|
||||
auto& data = absl::get<LocalTensorHandleData>(data_);
|
||||
@ -449,9 +617,8 @@ Status TensorHandle::NumElements(int64* num_elements) const {
|
||||
Status TensorHandle::Unprotect(const Device* d) {
|
||||
DVLOG(3) << "Unprotect on TensorHandle: " << this << " device: " << d;
|
||||
|
||||
if (!IsRemote() && (d == absl::get<Device*>(device_))) {
|
||||
auto& data = absl::get<LocalTensorHandleData>(data_);
|
||||
return data.Unprotect();
|
||||
if (d == absl::get<Device*>(device_)) {
|
||||
return absl::visit([](auto& data) { return data.Unprotect(); }, data_);
|
||||
}
|
||||
|
||||
tf_shared_lock l(mu_);
|
||||
@ -511,7 +678,7 @@ Status TensorHandle::RemoteAddress(const Device* d, int64* op_id,
|
||||
"Could not find remote mirror for specified device");
|
||||
}
|
||||
|
||||
if (!IsRemote()) {
|
||||
if (Type() != REMOTE) {
|
||||
return errors::InvalidArgument("Primary device is not remote");
|
||||
}
|
||||
|
||||
@ -623,7 +790,8 @@ Status TensorHandle::SetRemoteShape(const TensorShape& shape, const Device* d,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
DCHECK(IsRemote()) << "SetRemoteShape is only called on remote handles.";
|
||||
DCHECK(Type() == REMOTE)
|
||||
<< "SetRemoteShape is only called on remote handles.";
|
||||
|
||||
auto& data = absl::get<RemoteTensorHandleData>(data_);
|
||||
// context_view_id is currently used to validate mirrors. The shape of
|
||||
@ -643,7 +811,8 @@ void TensorHandle::PoisonRemote(Status status, const Device* d,
|
||||
<< " " << d->name();
|
||||
|
||||
if (!VariantDeviceIsCustom(device_) && d == absl::get<Device*>(device_)) {
|
||||
DCHECK(IsRemote()) << "Poison can only be on remote handles: " << this;
|
||||
DCHECK(Type() == REMOTE)
|
||||
<< "Poison can only be on remote handles: " << this;
|
||||
|
||||
auto& data = absl::get<RemoteTensorHandleData>(data_);
|
||||
data.Poison(status);
|
||||
@ -681,7 +850,7 @@ Status TensorHandle::SetTensor(tensorflow::Tensor&& t, const Device* d) {
|
||||
DVLOG(3) << "SetTensor on TensorHandle: " << this << " device: " << d;
|
||||
|
||||
if (d == absl::get<Device*>(device_)) {
|
||||
DCHECK(!IsRemote()) << "SetTensor is not called on remote handles.";
|
||||
DCHECK(Type() == LOCAL) << "SetTensor is not called on local handles.";
|
||||
|
||||
if (t.dtype() == DT_RESOURCE && t.NumElements() > 0) {
|
||||
auto& resource_handle = t.flat<class ResourceHandle>()(0);
|
||||
@ -709,10 +878,8 @@ void TensorHandle::Poison(Status status, const Device* d) {
|
||||
DVLOG(3) << "Poison on TensorHandle: " << this << " device: " << d;
|
||||
|
||||
if (!VariantDeviceIsCustom(device_) && d == absl::get<Device*>(device_)) {
|
||||
DCHECK(!IsRemote()) << "Poison can only be on local handles: " << this;
|
||||
|
||||
auto& data = absl::get<LocalTensorHandleData>(data_);
|
||||
data.Poison(status);
|
||||
DCHECK(Type() != REMOTE) << "Poison can only be on local handles: " << this;
|
||||
absl::visit([status](auto& data) { data.Poison(status); }, data_);
|
||||
} else {
|
||||
tf_shared_lock l(mu_);
|
||||
auto elem = local_mirrors_.find(d);
|
||||
|
@ -87,6 +87,14 @@ class TensorHandle : public AbstractTensorHandleInterface,
|
||||
Device* resource_device,
|
||||
tensorflow::DataType dtype,
|
||||
EagerContext* ctx);
|
||||
|
||||
// Create a handle which packs the given handles of the same dtype and shape.
|
||||
// If handles are on different devices, assign the packed handle to a
|
||||
// CompositeDevice.
|
||||
static Status CreatePackedHandle(std::vector<TensorHandle*>&& handles,
|
||||
EagerContext* ctx,
|
||||
TensorHandle** packed_handle);
|
||||
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
static TensorHandle* CreateUnshapedRemoteHandle(int64 op_id, int32 output_num,
|
||||
const string& remote_task,
|
||||
@ -200,7 +208,10 @@ class TensorHandle : public AbstractTensorHandleInterface,
|
||||
// ready.
|
||||
const tensorflow::DataType dtype;
|
||||
|
||||
bool IsRemote() const;
|
||||
enum HandleType { LOCAL = 0, PACKED = 1, REMOTE = 2 };
|
||||
|
||||
HandleType Type() const;
|
||||
string TypeString() const;
|
||||
|
||||
string DebugString() const;
|
||||
|
||||
@ -218,7 +229,17 @@ class TensorHandle : public AbstractTensorHandleInterface,
|
||||
std::vector<DtypeAndPartialTensorShape>* result);
|
||||
Status GetResourceAllowedDevices(std::vector<string>* result);
|
||||
|
||||
// It's called on a packed TensorHandle. Extract a handle with the given
|
||||
// index.
|
||||
Status ExtractPackedHandle(const int index, TensorHandle** handle) const;
|
||||
|
||||
private:
|
||||
friend class PackedTensorHandleTest;
|
||||
|
||||
TensorHandle(std::vector<TensorHandle*>&& handles, Device* device,
|
||||
const tensorflow::DataType dtype,
|
||||
const tensorflow::TensorShape& shape, EagerContext* ctx);
|
||||
|
||||
~TensorHandle() override;
|
||||
|
||||
// The TensorHandleData can either represent a local or remote tensor handle.
|
||||
@ -275,12 +296,43 @@ class TensorHandle : public AbstractTensorHandleInterface,
|
||||
// devices for the underlying resource.
|
||||
ResourceHandleInfo resource_handle_info_;
|
||||
|
||||
// A handle data which refers to multiple TensorHandles of the same dtype and
|
||||
// shape.
|
||||
class PackedTensorHandleData {
|
||||
public:
|
||||
PackedTensorHandleData(std::vector<TensorHandle*>&& handles,
|
||||
const TensorShape& shape);
|
||||
|
||||
~PackedTensorHandleData();
|
||||
|
||||
Status Shape(TensorShape* shape) const;
|
||||
Status NumDims(int* num_dims) const;
|
||||
Status Dim(int dim_index, int64* dim) const;
|
||||
Status NumElements(int64* num_elements) const;
|
||||
Status Unprotect();
|
||||
bool IsReady() const;
|
||||
void Poison(Status status);
|
||||
string DebugString() const;
|
||||
|
||||
// Extract a handle on the given index.
|
||||
Status ExtractPackedHandle(const int index, TensorHandle** handle) const;
|
||||
|
||||
private:
|
||||
const std::vector<TensorHandle*> handles_;
|
||||
const TensorShape shape_;
|
||||
|
||||
mutable mutex mu_;
|
||||
Status is_poisoned_ TF_GUARDED_BY(mu_);
|
||||
};
|
||||
|
||||
// Does not need synchronization because it can be accessed only after
|
||||
// WaitReady() has returned. At that point, data_ is immutable.
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
absl::variant<LocalTensorHandleData, RemoteTensorHandleData> data_;
|
||||
absl::variant<LocalTensorHandleData, PackedTensorHandleData,
|
||||
RemoteTensorHandleData>
|
||||
data_;
|
||||
#else
|
||||
absl::variant<LocalTensorHandleData> data_;
|
||||
absl::variant<LocalTensorHandleData, PackedTensorHandleData> data_;
|
||||
#endif
|
||||
|
||||
PartialTensorShape inference_shape_;
|
||||
|
@ -15,8 +15,10 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
|
||||
|
||||
#include "tensorflow/core/common_runtime/composite_device.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -64,5 +66,141 @@ TEST(TensorHandle_ShapeTest, AsyncShape) {
|
||||
ctx->Unref();
|
||||
}
|
||||
|
||||
static Device* CreateDevice(const char* type, const char* name) {
|
||||
class FakeDevice : public Device {
|
||||
public:
|
||||
explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {}
|
||||
Status Sync() override { return Status::OK(); }
|
||||
Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; }
|
||||
};
|
||||
DeviceAttributes attr;
|
||||
attr.set_name(name);
|
||||
attr.set_device_type(type);
|
||||
return new FakeDevice(attr);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
class PackedTensorHandleTest : public ::testing::Test {
|
||||
public:
|
||||
PackedTensorHandleTest() {
|
||||
std::vector<std::unique_ptr<Device>> devices;
|
||||
for (const char* name : device_names_) {
|
||||
devices.emplace_back(CreateDevice("GPU", name));
|
||||
}
|
||||
device_mgr_ = new StaticDeviceMgr(std::move(devices));
|
||||
|
||||
context_ = new EagerContext(
|
||||
SessionOptions(),
|
||||
tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
|
||||
tensorflow::ContextMirroringPolicy::MIRRORING_NONE, /* async= */ false,
|
||||
/* lazy_copy_function_remote_inputs= */ false, device_mgr_,
|
||||
/* device_mgr_owned= */ false, /* rendezvous= */ nullptr,
|
||||
/* custom_kernel_creator= */ nullptr,
|
||||
/* cluster_flr= */ nullptr);
|
||||
}
|
||||
|
||||
~PackedTensorHandleTest() override {
|
||||
delete device_mgr_;
|
||||
context_->Unref();
|
||||
}
|
||||
|
||||
EagerContext* context() { return context_; }
|
||||
|
||||
std::vector<Device*> ListDevices() const {
|
||||
return device_mgr_->ListDevices();
|
||||
}
|
||||
|
||||
bool IsReady(TensorHandle* handle) const { return handle->IsReady(); }
|
||||
|
||||
private:
|
||||
const std::vector<const char*> device_names_ = {
|
||||
"/job:worker/replica:0/task:0/device:GPU:0",
|
||||
"/job:worker/replica:0/task:0/device:GPU:1",
|
||||
"/job:worker/replica:0/task:1/device:GPU:0",
|
||||
"/job:worker/replica:0/task:1/device:GPU:1"};
|
||||
|
||||
StaticDeviceMgr* device_mgr_;
|
||||
EagerContext* context_;
|
||||
};
|
||||
|
||||
TEST_F(PackedTensorHandleTest, PackedHandle) {
|
||||
tensorflow::DataType dtype = DT_RESOURCE;
|
||||
TensorShape shape = {};
|
||||
DtypeAndPartialTensorShape dtype_and_shape = {DT_FLOAT, {2, 2}};
|
||||
|
||||
// Create 2 local TensorHandles (ready)
|
||||
std::vector<TensorHandle*> handles;
|
||||
Tensor t0(dtype, shape);
|
||||
Device* d0 = ListDevices().at(0);
|
||||
TensorHandle* h0 =
|
||||
TensorHandle::CreateLocalHandle(std::move(t0), d0, d0, d0, context());
|
||||
h0->SetResourceHandleInfo({{dtype_and_shape}, {}});
|
||||
handles.push_back(h0);
|
||||
Tensor t1(dtype, shape);
|
||||
Device* d1 = ListDevices().at(1);
|
||||
TensorHandle* h1 =
|
||||
TensorHandle::CreateLocalHandle(std::move(t1), d1, d1, d1, context());
|
||||
h1->SetResourceHandleInfo({{dtype_and_shape}, {}});
|
||||
handles.push_back(h1);
|
||||
|
||||
// Create 2 remote TensorHandles (not ready).
|
||||
const string remote_task = "/job:worker/replica:0/task:1";
|
||||
Device* d2 = ListDevices().at(2);
|
||||
TensorHandle* h2 = TensorHandle::CreateUnshapedRemoteHandle(
|
||||
/*op_id=*/0, /*output_num=*/0, remote_task, dtype, d2, context());
|
||||
handles.push_back(h2);
|
||||
Device* d3 = ListDevices().at(3);
|
||||
TensorHandle* h3 = TensorHandle::CreateUnshapedRemoteHandle(
|
||||
/*op_id=*/1, /*output_num=*/0, remote_task, dtype, d3, context());
|
||||
handles.push_back(h3);
|
||||
|
||||
TensorHandle* packed_handle = nullptr;
|
||||
TF_EXPECT_OK(TensorHandle::CreatePackedHandle(std::move(handles), context(),
|
||||
&packed_handle));
|
||||
|
||||
h0->Unref();
|
||||
h1->Unref();
|
||||
h2->Unref();
|
||||
h3->Unref();
|
||||
|
||||
EXPECT_EQ(packed_handle->Type(), TensorHandle::PACKED);
|
||||
EXPECT_EQ(packed_handle->dtype, dtype);
|
||||
TensorShape packed_shape;
|
||||
TF_ASSERT_OK(packed_handle->Shape(&packed_shape));
|
||||
EXPECT_EQ(packed_shape, shape);
|
||||
TensorHandle::ResourceHandleInfo resource_handle_info;
|
||||
TF_ASSERT_OK(packed_handle->GetResourceHandleInfo(&resource_handle_info));
|
||||
EXPECT_EQ(resource_handle_info.dtypes_and_shapes.size(), 1);
|
||||
EXPECT_EQ(resource_handle_info.dtypes_and_shapes.at(0).dtype, DT_FLOAT);
|
||||
EXPECT_EQ(
|
||||
resource_handle_info.dtypes_and_shapes.at(0).shape.IsIdenticalTo({2, 2}),
|
||||
true);
|
||||
|
||||
CompositeDevice* device = reinterpret_cast<CompositeDevice*>(
|
||||
absl::get<Device*>(packed_handle->device()));
|
||||
EXPECT_EQ(device->name(), "/job:worker/replica:0/task:0/device:COMPOSITE:0");
|
||||
EXPECT_EQ(device->underlying_devices()->size(), 4);
|
||||
|
||||
const std::vector<TensorHandle::HandleType> expected_handle_types = {
|
||||
TensorHandle::LOCAL, TensorHandle::LOCAL, TensorHandle::REMOTE,
|
||||
TensorHandle::REMOTE};
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
TensorHandle* h = nullptr;
|
||||
TF_ASSERT_OK(packed_handle->ExtractPackedHandle(i, &h));
|
||||
EXPECT_EQ(absl::get<Device*>(h->device()), ListDevices().at(i));
|
||||
EXPECT_EQ(h->Type(), expected_handle_types.at(i));
|
||||
}
|
||||
EXPECT_FALSE(IsReady(packed_handle));
|
||||
|
||||
TF_ASSERT_OK(h2->SetRemoteShape(shape, ListDevices().at(2),
|
||||
context()->GetContextViewId()));
|
||||
EXPECT_FALSE(IsReady(packed_handle));
|
||||
TF_ASSERT_OK(h3->SetRemoteShape(shape, ListDevices().at(3),
|
||||
context()->GetContextViewId()));
|
||||
EXPECT_TRUE(IsReady(packed_handle));
|
||||
|
||||
packed_handle->Unref();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -38,6 +38,7 @@ class RemoteTensorHandleData {
|
||||
Status NumDims(int* num_dims) const;
|
||||
Status Dim(int dim_index, int64* dim) const;
|
||||
Status NumElements(int64* num_elements) const;
|
||||
Status Unprotect() { return Status::OK(); }
|
||||
|
||||
bool IsReady() const;
|
||||
Status SetShape(const TensorShape& shape);
|
||||
|
Loading…
Reference in New Issue
Block a user