Introduce PackedTensorHandleData to TensorHandle. A PackedTensorHandleData refers to a list of TensorHandles of the same dtype and shape.

PiperOrigin-RevId: 308702161
Change-Id: Ide047f4cde1c17e7be9e0d64f78f499a022a430e
This commit is contained in:
Yujing Zhang 2020-04-27 14:45:12 -07:00 committed by TensorFlower Gardener
parent ea74b01d80
commit 4430ba27bb
15 changed files with 490 additions and 42 deletions

View File

@ -1025,10 +1025,10 @@ void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) {
return t->data();
}
if (handle->IsRemote()) {
if (handle->Type() != tensorflow::TensorHandle::LOCAL) {
status->status = tensorflow::errors::InvalidArgument(
"TFE_TensorHandleDevicePointer may not be called on a remote tensor "
"handle.");
"TFE_TensorHandleDevicePointer may not be called on a ",
handle->TypeString(), " tensor handle.");
return nullptr;
}
tensorflow::Device* device(absl::get<tensorflow::Device*>(handle->device()));
@ -1099,10 +1099,10 @@ size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
}
tensorflow::TensorHandle* handle =
tensorflow::TensorHandleFromInterface(h->handle);
if (handle->IsRemote()) {
if (handle->Type() != tensorflow::TensorHandle::LOCAL) {
status->status = tensorflow::errors::InvalidArgument(
"TFE_TensorHandleDeviceMemorySize may not be called on a remote tensor "
"handle.");
"TFE_TensorHandleDeviceMemorySize may not be called on a ",
handle->TypeString(), " tensor handle.");
return 0;
}
const tensorflow::Tensor* tensor;

View File

@ -47,9 +47,9 @@ const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) {
}
tensorflow::TensorHandle* handle =
tensorflow::TensorHandleFromInterface(h->handle);
if (handle->IsRemote()) {
if (handle->Type() != TensorHandle::LOCAL) {
status->status = tensorflow::errors::InvalidArgument(
"DLPack doesn't support remote tensor");
"DLPack doesn't support ", handle->TypeString(), " tensor");
return nullptr;
}
const tensorflow::Tensor* tensor;

View File

@ -30,6 +30,14 @@ std::unique_ptr<CompositeDevice> CompositeDevice::MakeDevice(
errors::InvalidArgument("underlying_devices should not be empty."));
return nullptr;
}
std::set<string> unique_devices;
for (const string& device : underlying_devices) {
if (!unique_devices.insert(device).second) {
status->Update(errors::InvalidArgument(
"Got a duplicated device in underlying_devices: ", device));
return nullptr;
}
}
DeviceNameUtils::ParsedName parsed_name;
if (!DeviceNameUtils::ParseFullName(underlying_devices.at(0), &parsed_name)) {
status->Update(tensorflow::errors::InvalidArgument(

View File

@ -47,6 +47,21 @@ TEST(CompositeDeviceTest, Basic) {
EXPECT_EQ(underlying_devices, *composite_device->underlying_devices());
}
{
Status status;
underlying_devices.push_back(
"/job:localhost/replica:0/task:0/device:CPU:0");
std::unique_ptr<CompositeDevice> composite_device =
CompositeDevice::MakeDevice(underlying_devices, /*unique_device_id=*/1,
&status);
EXPECT_EQ(composite_device, nullptr);
EXPECT_EQ(error::INVALID_ARGUMENT, status.code());
EXPECT_TRUE(
absl::StrContains(status.error_message(), "Got a duplicated device"))
<< status.ToString();
underlying_devices.pop_back();
}
{
Status status;
underlying_devices.push_back(

View File

@ -72,6 +72,7 @@ tf_cuda_library(
deps = [
":eager_executor",
":kernel_and_device",
"@com_google_absl//absl/container:flat_hash_map",
"//tensorflow/c:tf_tensor_internal",
"//tensorflow/c/eager:context_interface",
"//tensorflow/c/eager:tensor_handle_interface",

View File

@ -883,6 +883,29 @@ Status EagerContext::RegisterCustomDevice(
return Status::OK();
}
// Returns (via `composite_device`) a CompositeDevice for the given ordered
// list of underlying device names, creating and caching one on first use.
// The cache key is a fingerprint of the comma-joined device names, so the
// same ordered list always maps to the same CompositeDevice instance.
Status EagerContext::FindOrCreateCompositeDevice(
    const std::vector<string>& underlying_devices,
    CompositeDevice** composite_device) {
  const uint64 hash_key = Fingerprint64(absl::StrJoin(underlying_devices, ","));

  mutex_lock l(composite_devices_mu_);
  auto iter = composite_devices_.find(hash_key);
  if (iter != composite_devices_.end()) {
    // Cache hit: reuse the previously created device.
    *composite_device = iter->second.get();
    return Status::OK();
  }

  Status s;
  // Use the current cache size as the unique device id; the map only grows
  // and composite_devices_mu_ is held, so ids are never reused.
  auto device = CompositeDevice::MakeDevice(underlying_devices,
                                            composite_devices_.size(), &s);
  TF_RETURN_IF_ERROR(s);
  *composite_device = device.get();
  // TODO(b/145922293): Add the composite device to the device set of pflr in
  // order to make placer recognize it.
  composite_devices_.emplace(hash_key, std::move(device));
  return Status::OK();
}
bool EagerContext::OnSameTask(const Device* first, const Device* second) const {
if (first == nullptr) first = HostCPU();
if (second == nullptr) second = HostCPU();

View File

@ -32,8 +32,10 @@ limitations under the License.
// clang-format on
#include "absl/types/optional.h"
#include "absl/container/flat_hash_map.h"
#include "tensorflow/c/eager/context_interface.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
#include "tensorflow/core/common_runtime/composite_device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
@ -487,6 +489,11 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted {
Status RegisterCustomDevice(const string& name,
std::unique_ptr<CustomDevice> device);
// Find or create a composite device with the given `underlying_devices`.
Status FindOrCreateCompositeDevice(
const std::vector<string>& underlying_devices,
CompositeDevice** composite_device);
bool OnSameTask(const Device* first, const Device* second) const;
// Gets the CPU device on the task of device.
Status CPUDeviceOnTask(const Device* device, Device** cpu_device) const;
@ -569,6 +576,13 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted {
std::function<Rendezvous*(const int64)> rendezvous_creator_;
std::unordered_map<string, std::unique_ptr<CustomDevice>> custom_devices_;
mutable mutex composite_devices_mu_;
// Maps from the fingerprint of a set of device names to a virtual
// CompositeDevice.
// TODO(b/145922293): Consider taking device names as keys.
absl::flat_hash_map<uint64, std::unique_ptr<CompositeDevice>>
composite_devices_ GUARDED_BY(composite_devices_mu_);
FunctionLibraryDefinition func_lib_def_{OpRegistry::Global(), {}};
std::unique_ptr<thread::ThreadPool> thread_pool_;

View File

@ -170,5 +170,27 @@ TEST_F(EagerContextTest, SelectDeviceExplicitSoftPlacement) {
EXPECT_EQ(dev->device_type(), DEVICE_CPU);
}
// Verifies that FindOrCreateCompositeDevice caches by device list: the same
// list returns the same device, and a different list creates a new one with
// the next sequential id.
TEST_F(EagerContextTest, CompositeDevice) {
  InitContext(SessionOptions(), DEVICE_PLACEMENT_EXPLICIT);
  std::vector<string> underlying_devices = {
      "/job:worker/replica:0/task:0/device:CPU:0",
      "/job:worker/replica:0/task:0/device:CPU:1"};
  CompositeDevice* composite_device_0 = nullptr;
  TF_ASSERT_OK(context()->FindOrCreateCompositeDevice(underlying_devices,
                                                      &composite_device_0));
  EXPECT_EQ(composite_device_0->name(),
            "/job:worker/replica:0/task:0/device:COMPOSITE:0");
  // Looking up the same device list must return the cached instance.
  CompositeDevice* composite_device_1 = nullptr;
  TF_ASSERT_OK(context()->FindOrCreateCompositeDevice(underlying_devices,
                                                      &composite_device_1));
  EXPECT_EQ(composite_device_1, composite_device_0);
  // A different device list creates a distinct device with id 1.
  underlying_devices.push_back("/job:worker/replica:0/task:0/device:CPU:2");
  CompositeDevice* composite_device_2 = nullptr;
  TF_ASSERT_OK(context()->FindOrCreateCompositeDevice(underlying_devices,
                                                      &composite_device_2));
  EXPECT_EQ(composite_device_2->name(),
            "/job:worker/replica:0/task:0/device:COMPOSITE:1");
}
} // namespace
} // namespace tensorflow

View File

@ -48,7 +48,7 @@ AbstractTensorInterface* TensorHandle::Resolve(Status* status) {
}
}
if (IsRemote()) {
if (Type() == REMOTE) {
const tensorflow::Tensor* t = nullptr;
TensorHandle* h_cpu = nullptr;
*status = EagerCopyToDevice(this, ctx_, &ctx_->Executor(), ctx_->HostCPU(),
@ -68,7 +68,7 @@ AbstractTensorInterface* TensorHandle::Resolve(Status* status) {
h_cpu->Unref();
delete tf_tensor;
return retval;
} else {
} else if (Type() == LOCAL) {
tensorflow::Tensor tensor;
if (IsCPU(device()) || HasLocalMirror(nullptr)) {
const tensorflow::Tensor* src = nullptr;
@ -95,6 +95,10 @@ AbstractTensorInterface* TensorHandle::Resolve(Status* status) {
AbstractTensorInterface* retval = tf_tensor->tensor;
delete tf_tensor;
return retval;
} else {
*status = errors::InvalidArgument(
"Resolve() is not supoorted on packed TensorHandles.");
return nullptr;
}
}

View File

@ -200,7 +200,8 @@ Status ValidateInputTypeAndPlacement(
"together.");
}
Device* handle_device = absl::get<Device*>(handle_device_variant);
const bool maybe_copy = !skip_remote_copy || !handle->IsRemote();
const bool maybe_copy =
!skip_remote_copy || handle->Type() != TensorHandle::REMOTE;
// If the input is already on the right device, then nothing to do.
if (expected_device != handle_device && maybe_copy) {
TF_RETURN_IF_ERROR(CopyInputToExpectedDevice(ctx, op, kernel->device(),
@ -258,7 +259,7 @@ Status GetDeviceForInput(const EagerContext& ctx, TensorHandle* tensor_handle,
}
Device* cpu_device = ctx.HostCPU();
string device_name;
if (tensor_handle->IsRemote()) {
if (tensor_handle->Type() != TensorHandle::LOCAL) {
Device* device = absl::get<Device*>(tensor_handle->device());
device_name = device != nullptr ? device->name() : cpu_device->name();
*result = (device == nullptr ? cpu_device : device);
@ -392,7 +393,8 @@ Status GetOrCreateKernelAndDevice(
// which doesn't accept remote inputs.
for (int i = 0; i < op->Inputs().size(); i++) {
TensorHandle* input = op->Inputs()[i];
if (!ctx.LazyCopyFunctionRemoteInputs() && input->IsRemote()) {
if (!ctx.LazyCopyFunctionRemoteInputs() &&
input->Type() == TensorHandle::REMOTE) {
TensorHandle* handle = nullptr;
TF_RETURN_IF_ERROR(
EagerCopyToDevice(input, &ctx, &op->Executor(),

View File

@ -36,7 +36,8 @@ Status ExecuteNodeArgs::Init(
if (!s.ok()) {
#if !defined(IS_MOBILE_PLATFORM)
uint64 context_view_id = ctx->GetContextViewId();
if (in->IsRemote() || in->HasRemoteMirror(d, context_view_id)) {
if (in->Type() == TensorHandle::REMOTE ||
in->HasRemoteMirror(d, context_view_id)) {
if (!has_remote_inputs_) {
has_remote_inputs_ = true;
}

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "absl/types/variant.h"
#include "tensorflow/c/tf_tensor_internal.h"
#include "tensorflow/core/common_runtime/composite_device.h"
#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
@ -47,6 +48,91 @@ limitations under the License.
namespace tensorflow {
// Takes ownership of `handles` and holds a reference on each component handle
// for the lifetime of this data object.
TensorHandle::PackedTensorHandleData::PackedTensorHandleData(
    std::vector<TensorHandle*>&& handles, const TensorShape& shape)
    : handles_(std::move(handles)), shape_(shape) {
  for (TensorHandle* component : handles_) {
    component->Ref();
  }
}
// Releases the references taken on the component handles in the constructor.
TensorHandle::PackedTensorHandleData::~PackedTensorHandleData() {
  for (TensorHandle* component : handles_) {
    component->Unref();
  }
}
// The shape is fixed at construction time, so this always succeeds.
Status TensorHandle::PackedTensorHandleData::Shape(TensorShape* shape) const {
  *shape = shape_;
  return Status::OK();
}
// Number of dimensions of the (construction-time) shape; always succeeds.
Status TensorHandle::PackedTensorHandleData::NumDims(int* num_dims) const {
  *num_dims = shape_.dims();
  return Status::OK();
}
// Size of dimension `dim_index` of the (construction-time) shape.
Status TensorHandle::PackedTensorHandleData::Dim(int dim_index,
                                                 int64* dim) const {
  *dim = shape_.dim_size(dim_index);
  return Status::OK();
}
// Total element count of the (construction-time) shape; always succeeds.
Status TensorHandle::PackedTensorHandleData::NumElements(
    int64* num_elements) const {
  *num_elements = shape_.num_elements();
  return Status::OK();
}
// Forwards Unprotect to every component handle's underlying data (local,
// packed, or remote), failing fast on the first error.
Status TensorHandle::PackedTensorHandleData::Unprotect() {
  for (auto* handle : handles_) {
    TF_RETURN_IF_ERROR(absl::visit([](auto& data) { return data.Unprotect(); },
                                   handle->data_));
  }
  return Status::OK();
}
// A packed handle is ready once every component handle is ready. A poisoned
// handle also counts as ready, so that waiters wake up and then observe the
// recorded error instead of blocking forever.
bool TensorHandle::PackedTensorHandleData::IsReady() const {
  {
    tf_shared_lock l(mu_);
    // is_poisoned_ holds a non-OK status iff Poison() was called.
    if (!is_poisoned_.ok()) {
      return true;
    }
  }
  for (auto* handle : handles_) {
    if (!handle->IsReady()) {
      return false;
    }
  }
  return true;
}
// Records an error status on the packed data. After this, IsReady() returns
// true so that waiters observe the failure.
void TensorHandle::PackedTensorHandleData::Poison(Status status) {
  mutex_lock l(mu_);
  is_poisoned_ = status;
}
// Concatenates the debug strings of every component handle's data, separated
// by "; ", after a fixed prefix.
string TensorHandle::PackedTensorHandleData::DebugString() const {
  string result = "PackedTensorHandleData: ";
  for (const auto* component : handles_) {
    result.append(absl::StrCat(
        absl::visit([](auto& data) { return data.DebugString(); },
                    component->data_),
        "; "));
  }
  return result;
}
// Returns (via `handle`) the component handle stored at `index`; fails with
// InvalidArgument when `index` is out of range. Does not add a reference.
Status TensorHandle::PackedTensorHandleData::ExtractPackedHandle(
    const int index, TensorHandle** handle) const {
  if (index >= 0 && index < handles_.size()) {
    *handle = handles_[index];
    return Status::OK();
  }
  return errors::InvalidArgument("Expect an index within [0, ", handles_.size(),
                                 "), but got ", index);
}
void TensorHandle::SetResourceHandleInfo(
ResourceHandleInfo&& resource_handle_info) {
resource_handle_info_ = std::move(resource_handle_info);
@ -61,7 +147,7 @@ Status TensorHandle::GetResourceHandleInfoImpl(
dtype);
}
if (IsRemote()) {
if (Type() != LOCAL) {
set_resource_info();
return Status::OK();
}
@ -98,6 +184,16 @@ Status TensorHandle::GetResourceAllowedDevices(std::vector<string>* result) {
return GetResourceHandleInfoImpl(get_resource_info);
}
// Extracts the component handle at `index` from a PACKED handle. Returns an
// Internal error when called on a non-packed handle.
Status TensorHandle::ExtractPackedHandle(const int index,
                                         TensorHandle** handle) const {
  if (Type() != PACKED) {
    // Fixed: the message previously named a nonexistent function
    // ("ExtractPackedHandleOnDevice") and was missing a space after "a",
    // producing e.g. "call on aLOCAL handle".
    return errors::Internal("Invalid ExtractPackedHandle call on a ",
                            TypeString(), " handle: ", this);
  }
  return absl::get<PackedTensorHandleData>(data_).ExtractPackedHandle(index,
                                                                      handle);
}
TensorHandle* TensorHandle::CreateLocalHandle(const tensorflow::Tensor& t) {
// TODO(b/136608821): Move away from nullptr
tensorflow::Tensor tensor = t;
@ -195,6 +291,64 @@ TensorHandle::TensorHandle(Device* d, Device* op_device,
<< " device: " << VariantDeviceDebugString(device_);
}
// Packs `handles` (expected to share a dtype and shape) into a single new
// PACKED TensorHandle, returned via `packed_handle`. When the handles live on
// more than one device, the packed handle is placed on a CompositeDevice
// found or created on `ctx`.
Status TensorHandle::CreatePackedHandle(std::vector<TensorHandle*>&& handles,
                                        EagerContext* ctx,
                                        TensorHandle** packed_handle) {
  if (handles.empty()) {
    return errors::InvalidArgument("Handles should not be empty.");
  }

  // Get the dtype and shape from the first handle since all handles have the
  // same dtype and shape.
  tensorflow::DataType dtype = handles.at(0)->dtype;
  tensorflow::TensorShape shape;
  TF_RETURN_IF_ERROR(handles.at(0)->Shape(&shape));
  // For resource tensors, propagate the first handle's resource metadata to
  // the packed handle. NOTE(review): assumes all component handles carry
  // identical resource info -- confirm with callers.
  ResourceHandleInfo resource_handle_info;
  if (dtype == DT_RESOURCE) {
    TF_RETURN_IF_ERROR(
        handles.at(0)->GetResourceHandleInfo(&resource_handle_info));
  }
  // Collect each component's device name; custom devices cannot be packed.
  std::vector<string> devices;
  for (auto* handle : handles) {
    if (VariantDeviceIsCustom(handle->device())) {
      return errors::InvalidArgument(
          "CustomDevice is not supported for packing.");
    } else {
      devices.push_back(
          absl::get<Device*>(handle->DeviceOrHostCPU(*ctx))->name());
    }
  }

  Device* device;
  if (devices.size() == 1) {
    // Single underlying device: no composite device is needed.
    device = absl::get<Device*>(handles.at(0)->DeviceOrHostCPU(*ctx));
  } else {
    CompositeDevice* composite_device = nullptr;
    TF_RETURN_IF_ERROR(
        ctx->FindOrCreateCompositeDevice(devices, &composite_device));
    device = composite_device;
  }
  *packed_handle =
      new TensorHandle(std::move(handles), device, dtype, shape, ctx);
  (*packed_handle)->SetResourceHandleInfo(std::move(resource_handle_info));
  return Status::OK();
}
// Constructs a PACKED TensorHandle directly from component handles. Only
// reached from CreatePackedHandle, which validates the inputs and chooses the
// (possibly composite) device.
TensorHandle::TensorHandle(std::vector<TensorHandle*>&& handles, Device* device,
                           const tensorflow::DataType dtype,
                           const tensorflow::TensorShape& shape,
                           EagerContext* ctx)
    : dtype(dtype),
      device_(device),
      op_device_(device),
      // Resource tensors record the packed handle's own device as the
      // resource device; non-resource tensors have none.
      resource_device_(dtype == DT_RESOURCE ? device : nullptr),
      ctx_(ctx),
      // Emplace PackedTensorHandleData as the active variant alternative.
      data_(absl::in_place_type<PackedTensorHandleData>, std::move(handles),
            shape) {
  DVLOG(3) << "Creating a packed TensorHandle: " << this
           << " device: " << VariantDeviceDebugString(device_);
}
#if !defined(IS_MOBILE_PLATFORM)
TensorHandle* TensorHandle::CreateUnshapedRemoteHandle(
int64 op_id, int32 output_num, const string& remote_task,
@ -253,19 +407,32 @@ bool TensorHandle::IsReady() const {
return absl::visit([](auto& data) { return data.IsReady(); }, data_);
}
bool TensorHandle::IsRemote() const {
#if !defined(IS_MOBILE_PLATFORM)
return data_.index() == 1;
#else
return false;
#endif
// The handle type is determined by which alternative of the data_ variant is
// active: 0 = LocalTensorHandleData, 1 = PackedTensorHandleData, anything
// else = RemoteTensorHandleData.
TensorHandle::HandleType TensorHandle::Type() const {
  switch (data_.index()) {
    case 0:
      return LOCAL;
    case 1:
      return PACKED;
    default:
      return REMOTE;
  }
}
// Human-readable name of Type(), derived from the active data_ alternative.
string TensorHandle::TypeString() const {
  switch (data_.index()) {
    case 0:
      return "LOCAL";
    case 1:
      return "PACKED";
    default:
      return "REMOTE";
  }
}
Status TensorHandle::Tensor(const tensorflow::Tensor** t) const {
DVLOG(3) << "Tensor on TensorHandle: " << this;
if (IsRemote()) {
return errors::Internal("Invalid Tensor call on remote handle: ", this);
if (Type() != LOCAL) {
return errors::Internal("Invalid Tensor call on a ", TypeString(),
" handle: ", this);
}
auto& data = absl::get<LocalTensorHandleData>(data_);
@ -277,8 +444,9 @@ Status TensorHandle::TensorFromDevice(const Device* d,
DVLOG(3) << "TensorFromDevice on TensorHandle: " << this << " device: " << d;
if (d == absl::get<Device*>(device_)) {
if (IsRemote()) {
return errors::Internal("Invalid Tensor call on remote handle: ", this);
if (Type() != LOCAL) {
return errors::Internal("Invalid Tensor call on a ", TypeString(),
" handle: ", this);
}
auto& data = absl::get<LocalTensorHandleData>(data_);
@ -306,9 +474,9 @@ Status TensorHandle::TensorValue(const Device* d, tensorflow::TensorValue* t) {
VariantDeviceDebugString(device_),
", requested device: ", d != nullptr ? d->name() : "(nil)");
} else if (d == absl::get<Device*>(device_)) {
if (IsRemote()) {
return errors::Internal("Invalid TensorValue call on remote handle: ",
this);
if (Type() != LOCAL) {
return errors::Internal("Invalid TensorValue call on a ", TypeString(),
" handle: ", this);
}
auto& data = absl::get<LocalTensorHandleData>(data_);
@ -449,9 +617,8 @@ Status TensorHandle::NumElements(int64* num_elements) const {
Status TensorHandle::Unprotect(const Device* d) {
DVLOG(3) << "Unprotect on TensorHandle: " << this << " device: " << d;
if (!IsRemote() && (d == absl::get<Device*>(device_))) {
auto& data = absl::get<LocalTensorHandleData>(data_);
return data.Unprotect();
if (d == absl::get<Device*>(device_)) {
return absl::visit([](auto& data) { return data.Unprotect(); }, data_);
}
tf_shared_lock l(mu_);
@ -511,7 +678,7 @@ Status TensorHandle::RemoteAddress(const Device* d, int64* op_id,
"Could not find remote mirror for specified device");
}
if (!IsRemote()) {
if (Type() != REMOTE) {
return errors::InvalidArgument("Primary device is not remote");
}
@ -623,7 +790,8 @@ Status TensorHandle::SetRemoteShape(const TensorShape& shape, const Device* d,
return Status::OK();
}
DCHECK(IsRemote()) << "SetRemoteShape is only called on remote handles.";
DCHECK(Type() == REMOTE)
<< "SetRemoteShape is only called on remote handles.";
auto& data = absl::get<RemoteTensorHandleData>(data_);
// context_view_id is currently used to validate mirrors. The shape of
@ -643,7 +811,8 @@ void TensorHandle::PoisonRemote(Status status, const Device* d,
<< " " << d->name();
if (!VariantDeviceIsCustom(device_) && d == absl::get<Device*>(device_)) {
DCHECK(IsRemote()) << "Poison can only be on remote handles: " << this;
DCHECK(Type() == REMOTE)
<< "Poison can only be on remote handles: " << this;
auto& data = absl::get<RemoteTensorHandleData>(data_);
data.Poison(status);
@ -681,7 +850,7 @@ Status TensorHandle::SetTensor(tensorflow::Tensor&& t, const Device* d) {
DVLOG(3) << "SetTensor on TensorHandle: " << this << " device: " << d;
if (d == absl::get<Device*>(device_)) {
DCHECK(!IsRemote()) << "SetTensor is not called on remote handles.";
DCHECK(Type() == LOCAL) << "SetTensor is not called on local handles.";
if (t.dtype() == DT_RESOURCE && t.NumElements() > 0) {
auto& resource_handle = t.flat<class ResourceHandle>()(0);
@ -709,10 +878,8 @@ void TensorHandle::Poison(Status status, const Device* d) {
DVLOG(3) << "Poison on TensorHandle: " << this << " device: " << d;
if (!VariantDeviceIsCustom(device_) && d == absl::get<Device*>(device_)) {
DCHECK(!IsRemote()) << "Poison can only be on local handles: " << this;
auto& data = absl::get<LocalTensorHandleData>(data_);
data.Poison(status);
DCHECK(Type() != REMOTE) << "Poison can only be on local handles: " << this;
absl::visit([status](auto& data) { data.Poison(status); }, data_);
} else {
tf_shared_lock l(mu_);
auto elem = local_mirrors_.find(d);

View File

@ -87,6 +87,14 @@ class TensorHandle : public AbstractTensorHandleInterface,
Device* resource_device,
tensorflow::DataType dtype,
EagerContext* ctx);
// Create a handle which packs the given handles of the same dtype and shape.
// If handles are on different devices, assign the packed handle to a
// CompositeDevice.
static Status CreatePackedHandle(std::vector<TensorHandle*>&& handles,
EagerContext* ctx,
TensorHandle** packed_handle);
#if !defined(IS_MOBILE_PLATFORM)
static TensorHandle* CreateUnshapedRemoteHandle(int64 op_id, int32 output_num,
const string& remote_task,
@ -200,7 +208,10 @@ class TensorHandle : public AbstractTensorHandleInterface,
// ready.
const tensorflow::DataType dtype;
bool IsRemote() const;
enum HandleType { LOCAL = 0, PACKED = 1, REMOTE = 2 };
HandleType Type() const;
string TypeString() const;
string DebugString() const;
@ -218,7 +229,17 @@ class TensorHandle : public AbstractTensorHandleInterface,
std::vector<DtypeAndPartialTensorShape>* result);
Status GetResourceAllowedDevices(std::vector<string>* result);
// It's called on a packed TensorHandle. Extract a handle with the given
// index.
Status ExtractPackedHandle(const int index, TensorHandle** handle) const;
private:
friend class PackedTensorHandleTest;
TensorHandle(std::vector<TensorHandle*>&& handles, Device* device,
const tensorflow::DataType dtype,
const tensorflow::TensorShape& shape, EagerContext* ctx);
~TensorHandle() override;
// The TensorHandleData can either represent a local or remote tensor handle.
@ -275,12 +296,43 @@ class TensorHandle : public AbstractTensorHandleInterface,
// devices for the underlying resource.
ResourceHandleInfo resource_handle_info_;
// A handle data which refers to multiple TensorHandles of the same dtype and
// shape.
class PackedTensorHandleData {
public:
PackedTensorHandleData(std::vector<TensorHandle*>&& handles,
const TensorShape& shape);
~PackedTensorHandleData();
Status Shape(TensorShape* shape) const;
Status NumDims(int* num_dims) const;
Status Dim(int dim_index, int64* dim) const;
Status NumElements(int64* num_elements) const;
Status Unprotect();
bool IsReady() const;
void Poison(Status status);
string DebugString() const;
// Extract a handle on the given index.
Status ExtractPackedHandle(const int index, TensorHandle** handle) const;
private:
const std::vector<TensorHandle*> handles_;
const TensorShape shape_;
mutable mutex mu_;
Status is_poisoned_ TF_GUARDED_BY(mu_);
};
// Does not need synchronization because it can be accessed only after
// WaitReady() has returned. At that point, data_ is immutable.
#if !defined(IS_MOBILE_PLATFORM)
absl::variant<LocalTensorHandleData, RemoteTensorHandleData> data_;
absl::variant<LocalTensorHandleData, PackedTensorHandleData,
RemoteTensorHandleData>
data_;
#else
absl::variant<LocalTensorHandleData> data_;
absl::variant<LocalTensorHandleData, PackedTensorHandleData> data_;
#endif
PartialTensorShape inference_shape_;

View File

@ -15,8 +15,10 @@ limitations under the License.
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/common_runtime/composite_device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
@ -64,5 +66,141 @@ TEST(TensorHandle_ShapeTest, AsyncShape) {
ctx->Unref();
}
// Builds a minimal fake Device for tests: Sync() always succeeds and no
// allocator is provided. Caller owns the returned pointer.
static Device* CreateDevice(const char* type, const char* name) {
  class FakeDevice : public Device {
   public:
    explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {}
    Status Sync() override { return Status::OK(); }
    Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; }
  };
  DeviceAttributes attr;
  attr.set_device_type(type);
  attr.set_name(name);
  return new FakeDevice(attr);
}
} // namespace
// Test fixture that builds an EagerContext over four fake GPU devices
// (two tasks, two devices each) for exercising packed TensorHandles.
class PackedTensorHandleTest : public ::testing::Test {
 public:
  PackedTensorHandleTest() {
    std::vector<std::unique_ptr<Device>> devices;
    for (const char* name : device_names_) {
      devices.emplace_back(CreateDevice("GPU", name));
    }
    device_mgr_ = new StaticDeviceMgr(std::move(devices));

    context_ = new EagerContext(
        SessionOptions(),
        tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
        tensorflow::ContextMirroringPolicy::MIRRORING_NONE, /* async= */ false,
        /* lazy_copy_function_remote_inputs= */ false, device_mgr_,
        /* device_mgr_owned= */ false, /* rendezvous= */ nullptr,
        /* custom_kernel_creator= */ nullptr,
        /* cluster_flr= */ nullptr);
  }

  ~PackedTensorHandleTest() override {
    // NOTE(review): device_mgr_ is deleted before the context is unref'ed,
    // yet the context holds an unowned pointer to it (device_mgr_owned=false).
    // Confirm the context does not touch the device manager on destruction.
    delete device_mgr_;
    context_->Unref();
  }

  EagerContext* context() { return context_; }

  std::vector<Device*> ListDevices() const {
    return device_mgr_->ListDevices();
  }

  // Exposes the private TensorHandle::IsReady to the tests (this fixture is a
  // friend of TensorHandle).
  bool IsReady(TensorHandle* handle) const { return handle->IsReady(); }

 private:
  const std::vector<const char*> device_names_ = {
      "/job:worker/replica:0/task:0/device:GPU:0",
      "/job:worker/replica:0/task:0/device:GPU:1",
      "/job:worker/replica:0/task:1/device:GPU:0",
      "/job:worker/replica:0/task:1/device:GPU:1"};
  StaticDeviceMgr* device_mgr_;
  EagerContext* context_;
};
// End-to-end check of CreatePackedHandle: packs two ready local handles and
// two not-ready remote handles, then verifies the packed handle's type,
// dtype/shape, resource metadata, composite-device placement, component
// extraction, and readiness propagation.
TEST_F(PackedTensorHandleTest, PackedHandle) {
  tensorflow::DataType dtype = DT_RESOURCE;
  TensorShape shape = {};
  DtypeAndPartialTensorShape dtype_and_shape = {DT_FLOAT, {2, 2}};

  // Create 2 local TensorHandles (ready)
  std::vector<TensorHandle*> handles;
  Tensor t0(dtype, shape);
  Device* d0 = ListDevices().at(0);
  TensorHandle* h0 =
      TensorHandle::CreateLocalHandle(std::move(t0), d0, d0, d0, context());
  h0->SetResourceHandleInfo({{dtype_and_shape}, {}});
  handles.push_back(h0);
  Tensor t1(dtype, shape);
  Device* d1 = ListDevices().at(1);
  TensorHandle* h1 =
      TensorHandle::CreateLocalHandle(std::move(t1), d1, d1, d1, context());
  h1->SetResourceHandleInfo({{dtype_and_shape}, {}});
  handles.push_back(h1);

  // Create 2 remote TensorHandles (not ready).
  const string remote_task = "/job:worker/replica:0/task:1";
  Device* d2 = ListDevices().at(2);
  TensorHandle* h2 = TensorHandle::CreateUnshapedRemoteHandle(
      /*op_id=*/0, /*output_num=*/0, remote_task, dtype, d2, context());
  handles.push_back(h2);
  Device* d3 = ListDevices().at(3);
  TensorHandle* h3 = TensorHandle::CreateUnshapedRemoteHandle(
      /*op_id=*/1, /*output_num=*/0, remote_task, dtype, d3, context());
  handles.push_back(h3);

  TensorHandle* packed_handle = nullptr;
  TF_EXPECT_OK(TensorHandle::CreatePackedHandle(std::move(handles), context(),
                                                &packed_handle));

  // The packed handle holds its own references; drop the test's references.
  h0->Unref();
  h1->Unref();
  h2->Unref();
  h3->Unref();

  EXPECT_EQ(packed_handle->Type(), TensorHandle::PACKED);
  EXPECT_EQ(packed_handle->dtype, dtype);
  TensorShape packed_shape;
  TF_ASSERT_OK(packed_handle->Shape(&packed_shape));
  EXPECT_EQ(packed_shape, shape);
  // Resource metadata is propagated from the first component handle.
  TensorHandle::ResourceHandleInfo resource_handle_info;
  TF_ASSERT_OK(packed_handle->GetResourceHandleInfo(&resource_handle_info));
  EXPECT_EQ(resource_handle_info.dtypes_and_shapes.size(), 1);
  EXPECT_EQ(resource_handle_info.dtypes_and_shapes.at(0).dtype, DT_FLOAT);
  EXPECT_EQ(
      resource_handle_info.dtypes_and_shapes.at(0).shape.IsIdenticalTo({2, 2}),
      true);

  // Four distinct devices => the packed handle lands on a CompositeDevice.
  CompositeDevice* device = reinterpret_cast<CompositeDevice*>(
      absl::get<Device*>(packed_handle->device()));
  EXPECT_EQ(device->name(), "/job:worker/replica:0/task:0/device:COMPOSITE:0");
  EXPECT_EQ(device->underlying_devices()->size(), 4);

  // Components keep their original device and handle type.
  const std::vector<TensorHandle::HandleType> expected_handle_types = {
      TensorHandle::LOCAL, TensorHandle::LOCAL, TensorHandle::REMOTE,
      TensorHandle::REMOTE};
  for (int i = 0; i < 4; ++i) {
    TensorHandle* h = nullptr;
    TF_ASSERT_OK(packed_handle->ExtractPackedHandle(i, &h));
    EXPECT_EQ(absl::get<Device*>(h->device()), ListDevices().at(i));
    EXPECT_EQ(h->Type(), expected_handle_types.at(i));
  }

  // The packed handle only becomes ready once every component is ready.
  EXPECT_FALSE(IsReady(packed_handle));
  TF_ASSERT_OK(h2->SetRemoteShape(shape, ListDevices().at(2),
                                  context()->GetContextViewId()));
  EXPECT_FALSE(IsReady(packed_handle));
  TF_ASSERT_OK(h3->SetRemoteShape(shape, ListDevices().at(3),
                                  context()->GetContextViewId()));
  EXPECT_TRUE(IsReady(packed_handle));

  packed_handle->Unref();
}
} // namespace tensorflow

View File

@ -38,6 +38,7 @@ class RemoteTensorHandleData {
Status NumDims(int* num_dims) const;
Status Dim(int dim_index, int64* dim) const;
Status NumElements(int64* num_elements) const;
Status Unprotect() { return Status::OK(); }
bool IsReady() const;
Status SetShape(const TensorShape& shape);