Validate remote resource devices before safe access of resources.
Cluster updates (due to recreated distribution strategies, remote worker failures, etc.) can lead to crashing failures with segfaults when accessing resources created before the update. Some common patterns are: * Accessing datasets created on old remote workers; * Accessing variables created on failed workers; * Garbage collecting datasets/iterators created on old remote workers; This CL validates the remote devices to make sure the access is safe before executing the ops by looking up the device in a set of device pointers and checking its incarnation ID. Devices on restarted remote workers will have different incarnation IDs, and accessing resources on those devices will fail gracefully. PiperOrigin-RevId: 311261000 Change-Id: Ifc07862229b06301e0275fe80975565d9df28152
This commit is contained in:
parent
f690a054c5
commit
1c74b32aa2
@ -50,6 +50,13 @@ tensorflow::ServerDef GetServerDef(int num_tasks) {
|
||||
return GetServerDef("localhost", num_tasks);
|
||||
}
|
||||
|
||||
// Rewrites the address of `task_index` in the first job of `server_def` so it
// points at a freshly picked, unused local port (simulating a restarted
// worker).
void ReplaceTaskInServerDef(tensorflow::ServerDef* server_def, int task_index) {
  const int new_port = tensorflow::testing::PickUnusedPortOrDie();
  tensorflow::JobDef* job_def = server_def->mutable_cluster()->mutable_job(0);
  job_def->mutable_tasks()->at(task_index) =
      tensorflow::strings::StrCat("localhost:", new_port);
}
|
||||
|
||||
void CheckTFE_TensorHandleHasFloats(TFE_TensorHandle* handle,
|
||||
const std::vector<float>& expected_values) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
@ -101,6 +108,22 @@ void CheckRemoteMatMulExecutesOK(TFE_Context* ctx,
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
|
||||
// Read the value of variable `var` and save it into `out_value`.
|
||||
void ReadVariable(TFE_Context* ctx, TFE_TensorHandle* var,
|
||||
TFE_TensorHandle** out_value) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
|
||||
TFE_OpAddInput(op, var, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(op, out_value, &num_retvals, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteOp(op);
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
|
||||
void TestRemoteExecuteChangeServerDef(bool async) {
|
||||
tensorflow::ServerDef server_def = GetServerDef(2);
|
||||
|
||||
@ -243,6 +266,102 @@ TEST(CAPI, RemoteExecuteUpdateServerDefAsync) {
|
||||
TestRemoteExecuteUpdateServerDef(true);
|
||||
}
|
||||
|
||||
void TestRemoteExecuteUpdateServerDefResourceAccess(bool async) {
|
||||
tensorflow::ServerDef server_def = GetServerDef(2);
|
||||
// This server def has the task index set to 0.
|
||||
string serialized = server_def.SerializeAsString();
|
||||
|
||||
server_def.set_task_index(1);
|
||||
std::unique_ptr<tensorflow::GrpcServer> worker_server;
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server->Start().ok());
|
||||
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
|
||||
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
const char dev0_name[] = "/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
const char dev1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
|
||||
|
||||
TFE_TensorHandle* var_handle0 = TestVariable(ctx, 1.0, dev0_name);
|
||||
EXPECT_NE(var_handle0, nullptr);
|
||||
TFE_TensorHandle* var_handle1 = TestVariable(ctx, 2.0, dev1_name);
|
||||
EXPECT_NE(var_handle1, nullptr);
|
||||
|
||||
TFE_TensorHandle* value_handle = nullptr;
|
||||
ReadVariable(ctx, var_handle1, &value_handle);
|
||||
CheckTFE_TensorHandleHasFloats(value_handle, {2});
|
||||
TFE_DeleteTensorHandle(value_handle);
|
||||
|
||||
// Start a new worker to replace task:1
|
||||
ReplaceTaskInServerDef(&server_def, 1);
|
||||
server_def.set_task_index(1);
|
||||
// TODO(b/136478427): Figure out how to correctly shut the server down.
|
||||
worker_server.release();
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server->Start().ok());
|
||||
|
||||
// Update server def to replace the remote device with the device info on the
|
||||
// new worker (different incarnation ID).
|
||||
server_def.set_task_index(0);
|
||||
string serialized_update = server_def.SerializeAsString();
|
||||
TFE_ContextUpdateServerDef(ctx, 0, serialized_update.data(),
|
||||
serialized_update.size(), status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
// The device of var_handle0 is local device which is the same before and
|
||||
// after cluster update. Remove resource with valid device should succeed.
|
||||
TFE_Op* op = TFE_NewOp(ctx, "DestroyResourceOp", status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_OpAddInput(op, var_handle0, status);
|
||||
TFE_OpSetDevice(op, dev0_name, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
int num_retvals = 0;
|
||||
TFE_Execute(op, nullptr, &num_retvals, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteOp(op);
|
||||
|
||||
// The device of var_handle1 is remote device, which was replaced during
|
||||
// cluster update. Removing resource with invalid device should fail
|
||||
// gracefully (i.e., with error status) instead of crashing with segfaults.
|
||||
op = TFE_NewOp(ctx, "DestroyResourceOp", status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_OpAddInput(op, var_handle1, status);
|
||||
TFE_OpSetDevice(op, dev1_name, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
num_retvals = 0;
|
||||
TFE_Execute(op, nullptr, &num_retvals, status);
|
||||
EXPECT_NE(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteOp(op);
|
||||
|
||||
TFE_DeleteTensorHandle(var_handle0);
|
||||
TFE_DeleteTensorHandle(var_handle1);
|
||||
|
||||
TFE_DeleteContext(ctx);
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
// TODO(b/136478427): Figure out how to correctly shut the server down.
|
||||
worker_server.release();
|
||||
}
|
||||
|
||||
// Synchronous variant: resource access after a cluster update must fail
// gracefully rather than segfault.
TEST(CAPI, TestRemoteExecuteUpdateServerDefResourceAccess) {
  TestRemoteExecuteUpdateServerDefResourceAccess(/*async=*/false);
}
|
||||
|
||||
// Asynchronous variant of the resource-access-after-cluster-update test.
TEST(CAPI, TestRemoteExecuteUpdateServerDefResourceAccessAsync) {
  TestRemoteExecuteUpdateServerDefResourceAccess(/*async=*/true);
}
|
||||
|
||||
void TestRemoteExecuteUpdateServerDefWithFailures(bool async) {
|
||||
// Fail fast on GetStatus requests so we can get errors instead of timeout
|
||||
// when updating cluster with non-exsitent worker
|
||||
@ -282,6 +401,7 @@ void TestRemoteExecuteUpdateServerDefWithFailures(bool async) {
|
||||
int port = tensorflow::testing::PickUnusedPortOrDie();
|
||||
job_def->mutable_tasks()->insert(
|
||||
{2, tensorflow::strings::StrCat("localhost:", port)});
|
||||
server_def.set_task_index(0);
|
||||
string serialized_update = server_def.SerializeAsString();
|
||||
TFE_ContextUpdateServerDef(ctx, 0, serialized_update.data(),
|
||||
serialized_update.size(), status);
|
||||
|
@ -1203,6 +1203,8 @@ void BM_ReadVariable(int iters) {
|
||||
CHECK_EQ(0, TFE_TensorHandleNumDims(h, status));
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
h = nullptr;
|
||||
TFE_OpAddInput(op, var_handle, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
}
|
||||
tensorflow::testing::StopTiming();
|
||||
TFE_DeleteOp(op);
|
||||
|
@ -150,6 +150,7 @@ TFE_TensorHandle* TestVariable(TFE_Context* ctx, float value,
|
||||
TFE_TensorHandle* var_handle = nullptr;
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(op, &var_handle, &num_retvals, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_DeleteOp(op);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
CHECK_EQ(1, num_retvals);
|
||||
|
@ -45,6 +45,7 @@ StaticDeviceMgr::StaticDeviceMgr(std::vector<std::unique_ptr<Device>> devices)
|
||||
}
|
||||
const auto& t = d->device_type();
|
||||
device_type_counts_[t]++;
|
||||
device_incarnation_set_.insert(d->attributes().incarnation());
|
||||
if (cpu_device_ == nullptr && t == "CPU" && d->parsed_name().id == 0) {
|
||||
cpu_device_ = d.get();
|
||||
}
|
||||
@ -123,6 +124,10 @@ Status StaticDeviceMgr::LookupDevice(StringPiece name, Device** device) const {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Returns true iff a device with the given incarnation ID is registered with
// this (immutable) device manager.
bool StaticDeviceMgr::ContainsDevice(int64 device_incarnation) const {
  return device_incarnation_set_.find(device_incarnation) !=
         device_incarnation_set_.end();
}
|
||||
|
||||
void StaticDeviceMgr::ClearContainers(
|
||||
gtl::ArraySlice<string> containers) const {
|
||||
Status s;
|
||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/lib/core/arena.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
@ -56,6 +57,11 @@ class DeviceMgr {
|
||||
// Accepts either a full device name, or just the replica-local suffix.
|
||||
virtual Status LookupDevice(StringPiece name, Device** device) const = 0;
|
||||
|
||||
// Check if the current device manager contains device with the given
|
||||
// incarnation ID. Looking up by incarnation IDs because they are randomly
|
||||
// generated and not intentionally reused (unlike device pointers).
|
||||
virtual bool ContainsDevice(int64 device_incarnation) const = 0;
|
||||
|
||||
// Clears given containers of all devices if 'container' is
|
||||
// non-empty. Otherwise, clears default containers of all devices.
|
||||
virtual void ClearContainers(gtl::ArraySlice<string> containers) const = 0;
|
||||
@ -86,6 +92,7 @@ class StaticDeviceMgr : public DeviceMgr {
|
||||
string DebugString() const override;
|
||||
string DeviceMappingString() const override;
|
||||
Status LookupDevice(StringPiece name, Device** device) const override;
|
||||
bool ContainsDevice(int64 device_incarnation) const override;
|
||||
void ClearContainers(gtl::ArraySlice<string> containers) const override;
|
||||
int NumDeviceType(const string& type) const override;
|
||||
Device* HostCPU() const override;
|
||||
@ -95,6 +102,7 @@ class StaticDeviceMgr : public DeviceMgr {
|
||||
|
||||
StringPiece CopyToBackingStore(StringPiece s);
|
||||
|
||||
absl::flat_hash_set<int64> device_incarnation_set_;
|
||||
std::unordered_map<StringPiece, Device*, StringPieceHasher> device_map_;
|
||||
core::Arena name_backing_store_; // Storage for keys in device_map_
|
||||
std::unordered_map<string, int> device_type_counts_;
|
||||
@ -117,6 +125,7 @@ class DynamicDeviceMgr : public DeviceMgr {
|
||||
string DebugString() const override;
|
||||
string DeviceMappingString() const override;
|
||||
Status LookupDevice(StringPiece name, Device** device) const override;
|
||||
bool ContainsDevice(int64 device_incarnation) const override;
|
||||
void ClearContainers(gtl::ArraySlice<string> containers) const override;
|
||||
int NumDeviceType(const string& type) const override;
|
||||
Device* HostCPU() const override;
|
||||
@ -140,6 +149,7 @@ class DynamicDeviceMgr : public DeviceMgr {
|
||||
std::unordered_map<Device*, std::unique_ptr<Device>> dynamic_devices_
|
||||
TF_GUARDED_BY(devices_mu_);
|
||||
|
||||
absl::flat_hash_set<int64> device_incarnation_set_ TF_GUARDED_BY(devices_mu_);
|
||||
std::unordered_map<string, Device*> device_map_ TF_GUARDED_BY(devices_mu_);
|
||||
|
||||
std::unordered_map<string, int> device_type_counts_
|
||||
|
@ -92,6 +92,11 @@ Status DynamicDeviceMgr::LookupDevice(StringPiece name, Device** device) const {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Returns true iff a device with the given incarnation ID is currently
// registered. Takes a shared lock, as the device set may change concurrently.
bool DynamicDeviceMgr::ContainsDevice(int64 device_incarnation) const {
  tf_shared_lock l(devices_mu_);
  return device_incarnation_set_.find(device_incarnation) !=
         device_incarnation_set_.end();
}
|
||||
|
||||
void DynamicDeviceMgr::ClearContainers(
|
||||
gtl::ArraySlice<string> containers) const {
|
||||
Status s;
|
||||
@ -138,6 +143,7 @@ Status DynamicDeviceMgr::AddDevices(
|
||||
device_map_[name] = d.get();
|
||||
}
|
||||
device_type_counts_[d->device_type()]++;
|
||||
device_incarnation_set_.insert(d->attributes().incarnation());
|
||||
dynamic_devices_.emplace(d.get(), std::move(d));
|
||||
}
|
||||
return Status::OK();
|
||||
@ -171,6 +177,7 @@ Status DynamicDeviceMgr::RemoveDevices(std::vector<Device*> devices) {
|
||||
device_map_.erase(name);
|
||||
}
|
||||
device_type_counts_[d->device_type()]--;
|
||||
device_incarnation_set_.erase(d->attributes().incarnation());
|
||||
dynamic_devices_.erase(it);
|
||||
}
|
||||
return Status::OK();
|
||||
|
@ -874,6 +874,19 @@ bool IsPinnableOp(const string& op_type) {
|
||||
!absl::StartsWith(op_type, "XRT");
|
||||
}
|
||||
|
||||
// Validate if the remote device with the given incarnation is valid in the
|
||||
// remote device manager of the current eager context.
|
||||
Status ValidateTensorHandleRemoteDevice(EagerContext* ctx,
|
||||
int64 device_incarnation) {
|
||||
if (ctx->remote_device_mgr()->ContainsDevice(device_incarnation)) {
|
||||
return Status::OK();
|
||||
}
|
||||
return errors::InvalidArgument(
|
||||
"Resource input tensor contains an invalid device. This might happen "
|
||||
"when the client has connected to a different cluster, or some remote "
|
||||
"workers have been restarted.");
|
||||
}
|
||||
|
||||
// The Op device may be updated if:
|
||||
// - A resource touching input is specified: all resource-touching ops run in
|
||||
// the device the resource is, regardless of anything else that has been
|
||||
@ -935,6 +948,10 @@ Status MaybeUpdateOpDevice(EagerOperation* op) {
|
||||
for (int i = 0; i < op->Inputs().size(); ++i) {
|
||||
TensorHandle* tensor_handle = op->Inputs()[i];
|
||||
if (tensor_handle->dtype == DT_RESOURCE) {
|
||||
if (tensor_handle->resource_remote_device_incarnation() != 0) {
|
||||
TF_RETURN_IF_ERROR(ValidateTensorHandleRemoteDevice(
|
||||
&ctx, tensor_handle->resource_remote_device_incarnation()));
|
||||
}
|
||||
Device* resource_device = tensor_handle->resource_device();
|
||||
DVLOG(2) << "for op " << op->Name() << " input " << i << " "
|
||||
<< DataTypeString(tensor_handle->dtype)
|
||||
|
@ -49,6 +49,13 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
// Returns the incarnation ID of `device` when it is a remote device, or 0
// when `device` is null or local (local resources need no remote validation).
int64 GetRemoteDeviceIncarnation(Device* device) {
  const bool is_remote = device != nullptr && !device->IsLocal();
  return is_remote ? device->attributes().incarnation() : 0;
}
|
||||
} // namespace
|
||||
|
||||
TensorHandle::PackedTensorHandleData::PackedTensorHandleData(
|
||||
std::vector<TensorHandle*>&& handles, const TensorShape& shape)
|
||||
: handles_(std::move(handles)), shape_(shape) {
|
||||
@ -244,6 +251,8 @@ TensorHandle::TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device,
|
||||
device_((!ctx || d == ctx->HostCPU()) ? nullptr : d),
|
||||
op_device_(op_device),
|
||||
resource_device_(resource_device),
|
||||
resource_remote_device_incarnation_(
|
||||
GetRemoteDeviceIncarnation(resource_device_)),
|
||||
ctx_(ctx),
|
||||
data_(absl::in_place_type<LocalTensorHandleData>, std::move(t)) {
|
||||
DVLOG(3) << "Creating Local TensorHandle: " << this
|
||||
@ -258,6 +267,8 @@ TensorHandle::TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device,
|
||||
op_device_(op_device),
|
||||
resource_device_(
|
||||
GetResourceDevice(t.flat<class ResourceHandle>()(0), ctx)),
|
||||
resource_remote_device_incarnation_(
|
||||
GetRemoteDeviceIncarnation(resource_device_)),
|
||||
ctx_(ctx),
|
||||
resource_handle_info_(
|
||||
{t.flat<class ResourceHandle>()(0).dtypes_and_shapes(),
|
||||
@ -274,6 +285,7 @@ TensorHandle::TensorHandle(tensorflow::Tensor&& t, CustomDevice* d,
|
||||
device_(d),
|
||||
op_device_(nullptr),
|
||||
resource_device_(nullptr),
|
||||
resource_remote_device_incarnation_(0),
|
||||
ctx_(ctx),
|
||||
data_(absl::in_place_type<LocalTensorHandleData>, std::move(t)) {
|
||||
// TODO(allenl): Figure out a better op_device story for custom devices,
|
||||
@ -297,6 +309,8 @@ TensorHandle::TensorHandle(Device* d, Device* op_device,
|
||||
device_((d == ctx->HostCPU()) ? nullptr : d),
|
||||
op_device_(op_device),
|
||||
resource_device_(resource_device),
|
||||
resource_remote_device_incarnation_(
|
||||
GetRemoteDeviceIncarnation(resource_device_)),
|
||||
ctx_(ctx),
|
||||
data_(absl::in_place_type<LocalTensorHandleData>) {
|
||||
DVLOG(3) << "Creating empty Local TensorHandle: " << this
|
||||
@ -354,6 +368,8 @@ TensorHandle::TensorHandle(std::vector<TensorHandle*>&& handles, Device* device,
|
||||
device_(device),
|
||||
op_device_(device),
|
||||
resource_device_(dtype == DT_RESOURCE ? device : nullptr),
|
||||
resource_remote_device_incarnation_(
|
||||
GetRemoteDeviceIncarnation(resource_device_)),
|
||||
ctx_(ctx),
|
||||
data_(absl::in_place_type<PackedTensorHandleData>, std::move(handles),
|
||||
shape) {
|
||||
@ -376,6 +392,8 @@ TensorHandle::TensorHandle(int64 op_id, int32 output_num,
|
||||
device_(d),
|
||||
op_device_(d),
|
||||
resource_device_(dtype == DT_RESOURCE ? d : nullptr),
|
||||
resource_remote_device_incarnation_(
|
||||
GetRemoteDeviceIncarnation(resource_device_)),
|
||||
ctx_(ctx),
|
||||
data_(absl::in_place_type<RemoteTensorHandleData>, op_id, output_num,
|
||||
remote_task, ctx) {
|
||||
@ -398,6 +416,8 @@ TensorHandle::TensorHandle(int64 op_id, int32 output_num,
|
||||
device_(d),
|
||||
op_device_(d),
|
||||
resource_device_(dtype == DT_RESOURCE ? d : nullptr),
|
||||
resource_remote_device_incarnation_(
|
||||
GetRemoteDeviceIncarnation(resource_device_)),
|
||||
ctx_(ctx),
|
||||
data_(absl::in_place_type<RemoteTensorHandleData>, op_id, output_num,
|
||||
ctx->GetContextViewId()) {
|
||||
|
@ -133,6 +133,9 @@ class TensorHandle : public AbstractTensorHandleInterface,
|
||||
VariantDevice device() const { return device_; }
|
||||
Device* op_device() const { return op_device_; }
|
||||
Device* resource_device() const { return resource_device_; }
|
||||
// Incarnation ID of the remote device backing this DT_RESOURCE handle, or 0
// when the resource lives on a local device.
int64 resource_remote_device_incarnation() const {
  return resource_remote_device_incarnation_;
}
|
||||
|
||||
VariantDevice DeviceOrHostCPU(const EagerContext& ctx) const;
|
||||
|
||||
@ -265,6 +268,9 @@ class TensorHandle : public AbstractTensorHandleInterface,
|
||||
// If the tensor dtype is DT_RESOURCE, resource_device_ holds the device
|
||||
// backing the resource. Else resource_device_ is nullptr.
|
||||
tensorflow::Device* const resource_device_;
|
||||
// Incarnation ID of the resource device if it locates on a remote device, or
|
||||
// 0 if it locates on a local device.
|
||||
const int64 resource_remote_device_incarnation_;
|
||||
|
||||
mutable mutex mu_;
|
||||
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/random.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -66,17 +67,28 @@ TEST(TensorHandle_ShapeTest, AsyncShape) {
|
||||
ctx->Unref();
|
||||
}
|
||||
|
||||
static Device* CreateDevice(const char* type, const char* name) {
|
||||
static Device* CreateDevice(const char* type, const char* name,
|
||||
bool is_local = true) {
|
||||
class FakeDevice : public Device {
|
||||
public:
|
||||
explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {}
|
||||
explicit FakeDevice(const DeviceAttributes& attr, bool is_local)
|
||||
: Device(nullptr, attr), is_local_(is_local) {}
|
||||
Status Sync() override { return Status::OK(); }
|
||||
Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; }
|
||||
bool IsLocal() const override { return is_local_; }
|
||||
|
||||
private:
|
||||
const bool is_local_;
|
||||
};
|
||||
DeviceAttributes attr;
|
||||
attr.set_name(name);
|
||||
attr.set_device_type(type);
|
||||
return new FakeDevice(attr);
|
||||
int64 incarnation = random::New64();
|
||||
while (incarnation == 0) {
|
||||
incarnation = random::New64();
|
||||
}
|
||||
attr.set_incarnation(incarnation);
|
||||
return new FakeDevice(attr, is_local);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@ -204,4 +216,87 @@ TEST_F(PackedTensorHandleTest, PackedHandle) {
|
||||
packed_handle->Unref();
|
||||
}
|
||||
|
||||
// A resource handle on a local device reports remote-device incarnation 0,
// and its device's actual incarnation is only known to the device manager
// that owns the device.
TEST(TensorHandle_ResourceDeviceTest, OnLocalDevice) {
  std::unique_ptr<Device> d0(
      CreateDevice("CPU", "/job:localhost/replica:0/task:0/device:CPU:0"));
  StaticDeviceMgr local_device_mgr(std::move(d0));
  auto ctx = new EagerContext(
      SessionOptions(),
      tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
      tensorflow::ContextMirroringPolicy::MIRRORING_NONE, false, false,
      &local_device_mgr, false, nullptr, nullptr, nullptr);

  tensorflow::DataType dtype = DT_RESOURCE;
  TensorShape shape = {2};
  Tensor tensor(dtype, shape);

  Device* device = local_device_mgr.ListDevices()[0];
  TensorHandle* handle =
      TensorHandle::CreateLocalHandle(std::move(tensor), device, device,
                                      device, ctx);
  // Remote device incarnation for local resource should be 0 (invalid).
  EXPECT_EQ(0, handle->resource_remote_device_incarnation());
  // The owning device manager must contain the resource device.
  EXPECT_TRUE(local_device_mgr.ContainsDevice(
      handle->resource_device()->attributes().incarnation()));

  // A second manager with a same-named device has a different (random)
  // incarnation, so the lookup must fail there.
  std::unique_ptr<Device> d1(
      CreateDevice("CPU", "/job:localhost/replica:0/task:0/device:CPU:0"));
  StaticDeviceMgr new_device_mgr(std::move(d1));
  EXPECT_FALSE(new_device_mgr.ContainsDevice(
      handle->resource_device()->attributes().incarnation()));

  handle->Unref();
  ctx->Unref();
}
|
||||
|
||||
// A handle on a remote device records that device's incarnation ID; the
// dynamic device manager reports the ID as contained until the device is
// removed from the cluster.
TEST(TensorHandle_ResourceDeviceTest, OnRemoteDevice) {
  std::unique_ptr<Device> d_local(
      CreateDevice("CPU", "/job:localhost/replica:0/task:0/device:CPU:0"));
  StaticDeviceMgr local_device_mgr(std::move(d_local));
  auto ctx = new EagerContext(
      SessionOptions(),
      tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
      tensorflow::ContextMirroringPolicy::MIRRORING_NONE, false, false,
      &local_device_mgr, false, nullptr, nullptr, nullptr);

  // Two non-local (remote) devices on different worker tasks.
  std::unique_ptr<Device> d0(
      CreateDevice("CPU", "/job:worker/task:0/device:CPU:0", false));
  Device* d0_ptr = d0.get();
  std::unique_ptr<Device> d1(
      CreateDevice("CPU", "/job:worker/task:1/device:CPU:0", false));
  Device* d1_ptr = d1.get();

  DynamicDeviceMgr remote_device_mgr;
  std::vector<std::unique_ptr<Device>> first_batch;
  first_batch.emplace_back(std::move(d0));
  TF_ASSERT_OK(remote_device_mgr.AddDevices(std::move(first_batch)));

  TensorHandle* handle0 = TensorHandle::CreateUnshapedRemoteHandle(
      0, 0, "", DT_RESOURCE, d0_ptr, ctx);
  EXPECT_TRUE(remote_device_mgr.ContainsDevice(
      handle0->resource_remote_device_incarnation()));

  // Adding another device must not disturb the existing registration.
  std::vector<std::unique_ptr<Device>> second_batch;
  second_batch.emplace_back(std::move(d1));
  TF_ASSERT_OK(remote_device_mgr.AddDevices(std::move(second_batch)));
  EXPECT_TRUE(remote_device_mgr.ContainsDevice(
      handle0->resource_remote_device_incarnation()));

  TensorHandle* handle1 = TensorHandle::CreateUnshapedRemoteHandle(
      0, 0, "", DT_RESOURCE, d1_ptr, ctx);
  EXPECT_TRUE(remote_device_mgr.ContainsDevice(
      handle1->resource_remote_device_incarnation()));

  // After removing d1, only handle0's device remains known to the manager.
  std::vector<Device*> devices_to_remove{d1_ptr};
  TF_ASSERT_OK(remote_device_mgr.RemoveDevices(std::move(devices_to_remove)));
  EXPECT_FALSE(remote_device_mgr.ContainsDevice(
      handle1->resource_remote_device_incarnation()));
  EXPECT_TRUE(remote_device_mgr.ContainsDevice(
      handle0->resource_remote_device_incarnation()));

  handle0->Unref();
  handle1->Unref();
  ctx->Unref();
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user