From 1c74b32aa27dc0d40a9ce1f883ea632d399a7b9a Mon Sep 17 00:00:00 2001
From: Haoyu Zhang
Date: Tue, 12 May 2020 21:23:08 -0700
Subject: [PATCH] Validate remote resource devices to ensure safe access to resources.

Cluster updates (due to recreated distribution strategies, remote worker
failures, etc.) can lead to crashes with segfaults when accessing resources
created before the update. Some common patterns are:

* Accessing datasets created on old remote workers;
* Accessing variables created on failed workers;
* Garbage collecting datasets/iterators created on old remote workers.

This CL validates the remote devices before executing such ops: the resource's
device incarnation ID is looked up in the set of incarnation IDs known to the
device manager. Devices on restarted remote workers get new incarnation IDs,
so accessing resources on replaced devices now fails gracefully with an error
status instead of crashing.

PiperOrigin-RevId: 311261000
Change-Id: Ifc07862229b06301e0275fe80975565d9df28152
---
 tensorflow/c/eager/c_api_cluster_test.cc      | 120 ++++++++++++++++++
 tensorflow/c/eager/c_api_test.cc              |   2 +
 tensorflow/c/eager/c_api_test_util.cc         |   1 +
 tensorflow/core/common_runtime/device_mgr.cc  |   5 +
 tensorflow/core/common_runtime/device_mgr.h   |  10 ++
 .../core/common_runtime/dynamic_device_mgr.cc |   7 +
 .../core/common_runtime/eager/execute.cc      |  17 +++
 .../common_runtime/eager/tensor_handle.cc     |  20 +++
 .../core/common_runtime/eager/tensor_handle.h |   6 +
 .../eager/tensor_handle_test.cc               | 101 ++++++++++++++-
 10 files changed, 286 insertions(+), 3 deletions(-)

diff --git a/tensorflow/c/eager/c_api_cluster_test.cc b/tensorflow/c/eager/c_api_cluster_test.cc
index 8f585d6f02c..252a0408758 100644
--- a/tensorflow/c/eager/c_api_cluster_test.cc
+++ b/tensorflow/c/eager/c_api_cluster_test.cc
@@ -50,6 +50,13 @@ tensorflow::ServerDef GetServerDef(int num_tasks) {
   return GetServerDef("localhost", num_tasks);
 }
 
+void ReplaceTaskInServerDef(tensorflow::ServerDef* server_def, int task_index) {
+  tensorflow::JobDef* job_def = server_def->mutable_cluster()->mutable_job(0);
+  int port = tensorflow::testing::PickUnusedPortOrDie();
+  job_def->mutable_tasks()->at(task_index) =
+      tensorflow::strings::StrCat("localhost:", port);
+}
+
 void CheckTFE_TensorHandleHasFloats(TFE_TensorHandle* handle,
                                     const std::vector<float>& expected_values) {
   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
@@ -101,6 +108,22 @@ void CheckRemoteMatMulExecutesOK(TFE_Context* ctx,
   TF_DeleteStatus(status);
 }
 
+// Read the value of variable `var` and save it into `out_value`.
+void ReadVariable(TFE_Context* ctx, TFE_TensorHandle* var,
+                  TFE_TensorHandle** out_value) {
+  TF_Status* status = TF_NewStatus();
+  TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status);
+  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
+  TFE_OpAddInput(op, var, status);
+  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  int num_retvals = 1;
+  TFE_Execute(op, out_value, &num_retvals, status);
+  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_DeleteOp(op);
+  TF_DeleteStatus(status);
+}
+
 void TestRemoteExecuteChangeServerDef(bool async) {
   tensorflow::ServerDef server_def = GetServerDef(2);
 
@@ -243,6 +266,102 @@ TEST(CAPI, RemoteExecuteUpdateServerDefAsync) {
   TestRemoteExecuteUpdateServerDef(true);
 }
 
+void TestRemoteExecuteUpdateServerDefResourceAccess(bool async) {
+  tensorflow::ServerDef server_def = GetServerDef(2);
+  // This server def has the task index set to 0.
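+  // Keep the serialized server def around: it is passed to
+  // TFE_ContextSetServerDef below, after the worker for task:1 is started.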
+  string serialized = server_def.SerializeAsString();
+
+  server_def.set_task_index(1);
+  std::unique_ptr<tensorflow::GrpcServer> worker_server;
+  ASSERT_TRUE(tensorflow::GrpcServer::Create(
+                  server_def, tensorflow::Env::Default(), &worker_server)
+                  .ok());
+  ASSERT_TRUE(worker_server->Start().ok());
+
+  TF_Status* status = TF_NewStatus();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
+  TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
+  TFE_Context* ctx = TFE_NewContext(opts, status);
+  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_DeleteContextOptions(opts);
+
+  TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
+  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  const char dev0_name[] = "/job:localhost/replica:0/task:0/device:CPU:0";
+  const char dev1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
+
+  TFE_TensorHandle* var_handle0 = TestVariable(ctx, 1.0, dev0_name);
+  EXPECT_NE(var_handle0, nullptr);
+  TFE_TensorHandle* var_handle1 = TestVariable(ctx, 2.0, dev1_name);
+  EXPECT_NE(var_handle1, nullptr);
+
+  TFE_TensorHandle* value_handle = nullptr;
+  ReadVariable(ctx, var_handle1, &value_handle);
+  CheckTFE_TensorHandleHasFloats(value_handle, {2});
+  TFE_DeleteTensorHandle(value_handle);
+
+  // Start a new worker to replace task:1
+  ReplaceTaskInServerDef(&server_def, 1);
+  server_def.set_task_index(1);
+  // TODO(b/136478427): Figure out how to correctly shut the server down.
+  worker_server.release();
+  ASSERT_TRUE(tensorflow::GrpcServer::Create(
+                  server_def, tensorflow::Env::Default(), &worker_server)
+                  .ok());
+  ASSERT_TRUE(worker_server->Start().ok());
+
+  // Update server def to replace the remote device with the device info on the
+  // new worker (different incarnation ID).
+  server_def.set_task_index(0);
+  string serialized_update = server_def.SerializeAsString();
+  TFE_ContextUpdateServerDef(ctx, 0, serialized_update.data(),
+                             serialized_update.size(), status);
+  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+  // The device of var_handle0 is the local device, which stays the same across
+  // the cluster update. Destroying a resource on a valid device should succeed.
+  TFE_Op* op = TFE_NewOp(ctx, "DestroyResourceOp", status);
+  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_OpAddInput(op, var_handle0, status);
+  TFE_OpSetDevice(op, dev0_name, status);
+  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  int num_retvals = 0;
+  TFE_Execute(op, nullptr, &num_retvals, status);
+  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_DeleteOp(op);
+
+  // The device of var_handle1 is a remote device that was replaced during the
+  // cluster update. Destroying a resource on an invalid device should fail
+  // gracefully (i.e., with an error status) instead of crashing with a segfault.
+  op = TFE_NewOp(ctx, "DestroyResourceOp", status);
+  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_OpAddInput(op, var_handle1, status);
+  TFE_OpSetDevice(op, dev1_name, status);
+  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  num_retvals = 0;
+  TFE_Execute(op, nullptr, &num_retvals, status);
+  EXPECT_NE(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_DeleteOp(op);
+
+  TFE_DeleteTensorHandle(var_handle0);
+  TFE_DeleteTensorHandle(var_handle1);
+
+  TFE_DeleteContext(ctx);
+  TF_DeleteStatus(status);
+
+  // TODO(b/136478427): Figure out how to correctly shut the server down.
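+  // Releasing (rather than deleting) the old GrpcServer intentionally leaks
+  // it, since a clean shutdown is not supported yet (see TODO above).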
+  worker_server.release();
+}
+
+TEST(CAPI, TestRemoteExecuteUpdateServerDefResourceAccess) {
+  TestRemoteExecuteUpdateServerDefResourceAccess(false);
+}
+
+TEST(CAPI, TestRemoteExecuteUpdateServerDefResourceAccessAsync) {
+  TestRemoteExecuteUpdateServerDefResourceAccess(true);
+}
+
 void TestRemoteExecuteUpdateServerDefWithFailures(bool async) {
   // Fail fast on GetStatus requests so we can get errors instead of timeout
   // when updating cluster with non-existent worker
@@ -282,6 +401,7 @@ void TestRemoteExecuteUpdateServerDefWithFailures(bool async) {
   int port = tensorflow::testing::PickUnusedPortOrDie();
   job_def->mutable_tasks()->insert(
       {2, tensorflow::strings::StrCat("localhost:", port)});
+  server_def.set_task_index(0);
   string serialized_update = server_def.SerializeAsString();
   TFE_ContextUpdateServerDef(ctx, 0, serialized_update.data(),
                              serialized_update.size(), status);
diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc
index 548bf1337bb..724176505ba 100644
--- a/tensorflow/c/eager/c_api_test.cc
+++ b/tensorflow/c/eager/c_api_test.cc
@@ -1203,6 +1203,8 @@ void BM_ReadVariable(int iters) {
     CHECK_EQ(0, TFE_TensorHandleNumDims(h, status));
     CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
     h = nullptr;
+    TFE_OpAddInput(op, var_handle, status);
+    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
   }
   tensorflow::testing::StopTiming();
   TFE_DeleteOp(op);
diff --git a/tensorflow/c/eager/c_api_test_util.cc b/tensorflow/c/eager/c_api_test_util.cc
index bbdc4c8f410..29b624b8537 100644
--- a/tensorflow/c/eager/c_api_test_util.cc
+++ b/tensorflow/c/eager/c_api_test_util.cc
@@ -150,6 +150,7 @@ TFE_TensorHandle* TestVariable(TFE_Context* ctx, float value,
   TFE_TensorHandle* var_handle = nullptr;
   int num_retvals = 1;
   TFE_Execute(op, &var_handle, &num_retvals, status);
+  if (TF_GetCode(status) != TF_OK) return nullptr;
   TFE_DeleteOp(op);
   if (TF_GetCode(status) != TF_OK) return nullptr;
   CHECK_EQ(1, num_retvals);
diff --git a/tensorflow/core/common_runtime/device_mgr.cc b/tensorflow/core/common_runtime/device_mgr.cc
index c7583c374f2..0b693085da3 100644
--- a/tensorflow/core/common_runtime/device_mgr.cc
+++ b/tensorflow/core/common_runtime/device_mgr.cc
@@ -45,6 +45,7 @@ StaticDeviceMgr::StaticDeviceMgr(std::vector<std::unique_ptr<Device>> devices)
     }
     const auto& t = d->device_type();
     device_type_counts_[t]++;
+    device_incarnation_set_.insert(d->attributes().incarnation());
     if (cpu_device_ == nullptr && t == "CPU" && d->parsed_name().id == 0) {
       cpu_device_ = d.get();
     }
@@ -123,6 +124,10 @@ Status StaticDeviceMgr::LookupDevice(StringPiece name, Device** device) const {
   return Status::OK();
 }
 
+bool StaticDeviceMgr::ContainsDevice(int64 device_incarnation) const {
+  return device_incarnation_set_.contains(device_incarnation);
+}
+
 void StaticDeviceMgr::ClearContainers(
     gtl::ArraySlice<string> containers) const {
   Status s;
diff --git a/tensorflow/core/common_runtime/device_mgr.h b/tensorflow/core/common_runtime/device_mgr.h
index 56248b39078..83a0d0cc29c 100644
--- a/tensorflow/core/common_runtime/device_mgr.h
+++ b/tensorflow/core/common_runtime/device_mgr.h
@@ -22,6 +22,7 @@ limitations under the License.
 #include <unordered_set>
 #include <vector>
 
+#include "absl/container/flat_hash_set.h"
 #include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/lib/core/arena.h"
 #include "tensorflow/core/lib/core/status.h"
@@ -56,6 +57,11 @@ class DeviceMgr {
   // Accepts either a full device name, or just the replica-local suffix.
   virtual Status LookupDevice(StringPiece name, Device** device) const = 0;
 
+  // Checks whether this device manager contains a device with the given
+  // incarnation ID. Lookups use incarnation IDs because they are randomly
+  // generated and not intentionally reused (unlike device pointers).
+  virtual bool ContainsDevice(int64 device_incarnation) const = 0;
+
   // Clears given containers of all devices if 'container' is
   // non-empty. Otherwise, clears default containers of all devices.
   virtual void ClearContainers(gtl::ArraySlice<string> containers) const = 0;
@@ -86,6 +92,7 @@ class StaticDeviceMgr : public DeviceMgr {
   string DebugString() const override;
   string DeviceMappingString() const override;
   Status LookupDevice(StringPiece name, Device** device) const override;
+  bool ContainsDevice(int64 device_incarnation) const override;
   void ClearContainers(gtl::ArraySlice<string> containers) const override;
   int NumDeviceType(const string& type) const override;
   Device* HostCPU() const override;
@@ -95,6 +102,7 @@
 
   StringPiece CopyToBackingStore(StringPiece s);
 
+  absl::flat_hash_set<int64> device_incarnation_set_;
   std::unordered_map<StringPiece, Device*, StringPieceHasher> device_map_;
   core::Arena name_backing_store_;  // Storage for keys in device_map_
   std::unordered_map<string, int> device_type_counts_;
@@ -117,6 +125,7 @@ class DynamicDeviceMgr : public DeviceMgr {
   string DebugString() const override;
   string DeviceMappingString() const override;
   Status LookupDevice(StringPiece name, Device** device) const override;
+  bool ContainsDevice(int64 device_incarnation) const override;
   void ClearContainers(gtl::ArraySlice<string> containers) const override;
   int NumDeviceType(const string& type) const override;
   Device* HostCPU() const override;
@@ -140,6 +149,7 @@
 
   std::unordered_map<Device*, std::unique_ptr<Device>> dynamic_devices_
       TF_GUARDED_BY(devices_mu_);
+  absl::flat_hash_set<int64> device_incarnation_set_ TF_GUARDED_BY(devices_mu_);
   std::unordered_map<string, Device*> device_map_ TF_GUARDED_BY(devices_mu_);
 
   std::unordered_map<string, int> device_type_counts_
diff --git a/tensorflow/core/common_runtime/dynamic_device_mgr.cc b/tensorflow/core/common_runtime/dynamic_device_mgr.cc
index f35fa7e416a..f47de47c5b9 100644
--- a/tensorflow/core/common_runtime/dynamic_device_mgr.cc
+++ b/tensorflow/core/common_runtime/dynamic_device_mgr.cc
@@ -92,6 +92,11 @@ Status DynamicDeviceMgr::LookupDevice(StringPiece name, Device** device) const {
   return Status::OK();
 }
 
+bool DynamicDeviceMgr::ContainsDevice(int64 device_incarnation) const {
+  tf_shared_lock l(devices_mu_);
+  return device_incarnation_set_.contains(device_incarnation);
+}
+
 void DynamicDeviceMgr::ClearContainers(
     gtl::ArraySlice<string> containers) const {
   Status s;
@@ -138,6 +143,7 @@ Status DynamicDeviceMgr::AddDevices(
       device_map_[name] = d.get();
     }
     device_type_counts_[d->device_type()]++;
+    device_incarnation_set_.insert(d->attributes().incarnation());
     dynamic_devices_.emplace(d.get(), std::move(d));
   }
   return Status::OK();
@@ -171,6 +177,7 @@ Status DynamicDeviceMgr::RemoveDevices(std::vector<Device*> devices) {
       device_map_.erase(name);
     }
     device_type_counts_[d->device_type()]--;
+    device_incarnation_set_.erase(d->attributes().incarnation());
     dynamic_devices_.erase(it);
   }
   return Status::OK();
diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc
index 35dd9990054..3036e6d7989 100644
--- a/tensorflow/core/common_runtime/eager/execute.cc
+++ b/tensorflow/core/common_runtime/eager/execute.cc
@@ -874,6 +874,19 @@ bool IsPinnableOp(const string& op_type) {
          !absl::StartsWith(op_type, "XRT");
 }
 
+// Checks whether the remote device with the given incarnation is known to the
+// remote device manager of the current eager context.
+Status ValidateTensorHandleRemoteDevice(EagerContext* ctx,
+                                        int64 device_incarnation) {
+  if (ctx->remote_device_mgr()->ContainsDevice(device_incarnation)) {
+    return Status::OK();
+  }
+  return errors::InvalidArgument(
+      "Resource input tensor contains an invalid device. This might happen "
+      "when the client has connected to a different cluster, or some remote "
+      "workers have been restarted.");
+}
+
 // The Op device may be updated if:
 // - A resource touching input is specified: all resource-touching ops run in
 //   the device the resource is, regardless of anything else that has been
@@ -935,6 +948,10 @@ Status MaybeUpdateOpDevice(EagerOperation* op) {
   for (int i = 0; i < op->Inputs().size(); ++i) {
     TensorHandle* tensor_handle = op->Inputs()[i];
     if (tensor_handle->dtype == DT_RESOURCE) {
+      if (tensor_handle->resource_remote_device_incarnation() != 0) {
+        TF_RETURN_IF_ERROR(ValidateTensorHandleRemoteDevice(
+            &ctx, tensor_handle->resource_remote_device_incarnation()));
+      }
       Device* resource_device = tensor_handle->resource_device();
       DVLOG(2) << "for op " << op->Name() << " input " << i << " "
                << DataTypeString(tensor_handle->dtype)
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.cc b/tensorflow/core/common_runtime/eager/tensor_handle.cc
index eef46b691ce..dfe3e4a1426 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle.cc
+++ b/tensorflow/core/common_runtime/eager/tensor_handle.cc
@@ -49,6 +49,13 @@ limitations under the License.
 
 namespace tensorflow {
 
+namespace {
+int64 GetRemoteDeviceIncarnation(Device* device) {
+  if (device == nullptr || device->IsLocal()) return 0;
+  return device->attributes().incarnation();
+}
+}  // namespace
+
 TensorHandle::PackedTensorHandleData::PackedTensorHandleData(
     std::vector<TensorHandle*>&& handles, const TensorShape& shape)
     : handles_(std::move(handles)), shape_(shape) {
@@ -244,6 +251,8 @@ TensorHandle::TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device,
       device_((!ctx || d == ctx->HostCPU()) ? nullptr : d),
       op_device_(op_device),
       resource_device_(resource_device),
+      resource_remote_device_incarnation_(
+          GetRemoteDeviceIncarnation(resource_device_)),
       ctx_(ctx),
       data_(absl::in_place_type<LocalTensorHandleData>, std::move(t)) {
   DVLOG(3) << "Creating Local TensorHandle: " << this
@@ -258,6 +267,8 @@ TensorHandle::TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device,
       op_device_(op_device),
       resource_device_(
          GetResourceDevice(t.flat<ResourceHandle>()(0), ctx)),
+      resource_remote_device_incarnation_(
+          GetRemoteDeviceIncarnation(resource_device_)),
       ctx_(ctx),
       resource_handle_info_(
           {t.flat<ResourceHandle>()(0).dtypes_and_shapes(),
@@ -274,6 +285,7 @@ TensorHandle::TensorHandle(tensorflow::Tensor&& t, CustomDevice* d,
       device_(d),
       op_device_(nullptr),
       resource_device_(nullptr),
+      resource_remote_device_incarnation_(0),
       ctx_(ctx),
       data_(absl::in_place_type<LocalTensorHandleData>, std::move(t)) {
   // TODO(allenl): Figure out a better op_device story for custom devices,
@@ -297,6 +309,8 @@ TensorHandle::TensorHandle(Device* d, Device* op_device,
       device_((d == ctx->HostCPU()) ? nullptr : d),
       op_device_(op_device),
       resource_device_(resource_device),
+      resource_remote_device_incarnation_(
+          GetRemoteDeviceIncarnation(resource_device_)),
       ctx_(ctx),
       data_(absl::in_place_type<LocalTensorHandleData>) {
   DVLOG(3) << "Creating empty Local TensorHandle: " << this
@@ -354,6 +368,8 @@ TensorHandle::TensorHandle(std::vector<TensorHandle*>&& handles, Device* device,
       device_(device),
       op_device_(device),
       resource_device_(dtype == DT_RESOURCE ? device : nullptr),
+      resource_remote_device_incarnation_(
+          GetRemoteDeviceIncarnation(resource_device_)),
       ctx_(ctx),
       data_(absl::in_place_type<PackedTensorHandleData>, std::move(handles),
             shape) {
@@ -376,6 +392,8 @@ TensorHandle::TensorHandle(int64 op_id, int32 output_num,
       device_(d),
       op_device_(d),
       resource_device_(dtype == DT_RESOURCE ? d : nullptr),
+      resource_remote_device_incarnation_(
+          GetRemoteDeviceIncarnation(resource_device_)),
       ctx_(ctx),
       data_(absl::in_place_type<RemoteTensorHandleData>, op_id, output_num,
             remote_task, ctx) {
@@ -398,6 +416,8 @@ TensorHandle::TensorHandle(int64 op_id, int32 output_num,
       device_(d),
       op_device_(d),
       resource_device_(dtype == DT_RESOURCE ? d : nullptr),
+      resource_remote_device_incarnation_(
+          GetRemoteDeviceIncarnation(resource_device_)),
       ctx_(ctx),
       data_(absl::in_place_type<RemoteTensorHandleData>, op_id, output_num,
             ctx->GetContextViewId()) {
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.h b/tensorflow/core/common_runtime/eager/tensor_handle.h
index 45e7a3815a8..25d7fea3200 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle.h
+++ b/tensorflow/core/common_runtime/eager/tensor_handle.h
@@ -133,6 +133,9 @@ class TensorHandle : public AbstractTensorHandleInterface,
   VariantDevice device() const { return device_; }
   Device* op_device() const { return op_device_; }
   Device* resource_device() const { return resource_device_; }
+  int64 resource_remote_device_incarnation() const {
+    return resource_remote_device_incarnation_;
+  }
 
   VariantDevice DeviceOrHostCPU(const EagerContext& ctx) const;
 
@@ -265,6 +268,9 @@ class TensorHandle : public AbstractTensorHandleInterface,
   // If the tensor dtype is DT_RESOURCE, resource_device_ holds the device
   // backing the resource. Else resource_device_ is nullptr.
   tensorflow::Device* const resource_device_;
+  // Incarnation ID of the resource device if the resource lives on a remote
+  // device, or 0 if it lives on a local device.
+  const int64 resource_remote_device_incarnation_;
 
   mutable mutex mu_;
 
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle_test.cc b/tensorflow/core/common_runtime/eager/tensor_handle_test.cc
index 2bcde7dce5b..779158375de 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle_test.cc
+++ b/tensorflow/core/common_runtime/eager/tensor_handle_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/random.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -66,17 +67,28 @@ TEST(TensorHandle_ShapeTest, AsyncShape) { ctx->Unref(); } -static Device* CreateDevice(const char* type, const char* name) { +static Device* CreateDevice(const char* type, const char* name, + bool is_local = true) { class FakeDevice : public Device { public: - explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {} + explicit FakeDevice(const DeviceAttributes& attr, bool is_local) + : Device(nullptr, attr), is_local_(is_local) {} Status Sync() override { return Status::OK(); } Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; } + bool IsLocal() const override { return is_local_; } + + private: + const bool is_local_; }; DeviceAttributes attr; attr.set_name(name); attr.set_device_type(type); - return new FakeDevice(attr); + int64 incarnation = random::New64(); + while (incarnation == 0) { + incarnation = random::New64(); + } + attr.set_incarnation(incarnation); + return new FakeDevice(attr, is_local); } } // namespace @@ -204,4 +216,87 @@ TEST_F(PackedTensorHandleTest, PackedHandle) { packed_handle->Unref(); } +TEST(TensorHandle_ResourceDeviceTest, OnLocalDevice) { + std::unique_ptr d0( + CreateDevice("CPU", "/job:localhost/replica:0/task:0/device:CPU:0")); + StaticDeviceMgr local_device_mgr(std::move(d0)); + auto ctx = new EagerContext( + SessionOptions(), + tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, + tensorflow::ContextMirroringPolicy::MIRRORING_NONE, false, false, + &local_device_mgr, false, nullptr, nullptr, nullptr); + + tensorflow::DataType dtype = DT_RESOURCE; + TensorShape shape = {2}; + Tensor t(dtype, shape); + + Device* d = local_device_mgr.ListDevices()[0]; + TensorHandle* th = + TensorHandle::CreateLocalHandle(std::move(t), d, d, d, ctx); + // Remote device incarnation for local resource should be 0 (invalid) + EXPECT_EQ(0, th->resource_remote_device_incarnation()); + // Local device manager must contain the resource device. 
+  EXPECT_TRUE(local_device_mgr.ContainsDevice(
+      th->resource_device()->attributes().incarnation()));
+
+  std::unique_ptr<Device> d1(
+      CreateDevice("CPU", "/job:localhost/replica:0/task:0/device:CPU:0"));
+  StaticDeviceMgr new_device_mgr(std::move(d1));
+  EXPECT_FALSE(new_device_mgr.ContainsDevice(
+      th->resource_device()->attributes().incarnation()));
+
+  th->Unref();
+  ctx->Unref();
+}
+
+TEST(TensorHandle_ResourceDeviceTest, OnRemoteDevice) {
+  std::unique_ptr<Device> d_local(
+      CreateDevice("CPU", "/job:localhost/replica:0/task:0/device:CPU:0"));
+  StaticDeviceMgr local_device_mgr(std::move(d_local));
+  auto ctx = new EagerContext(
+      SessionOptions(),
+      tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
+      tensorflow::ContextMirroringPolicy::MIRRORING_NONE, false, false,
+      &local_device_mgr, false, nullptr, nullptr, nullptr);
+
+  std::unique_ptr<Device> d0(
+      CreateDevice("CPU", "/job:worker/task:0/device:CPU:0", false));
+  Device* d0_ptr = d0.get();
+  std::unique_ptr<Device> d1(
+      CreateDevice("CPU", "/job:worker/task:1/device:CPU:0", false));
+  Device* d1_ptr = d1.get();
+
+  DynamicDeviceMgr remote_device_mgr;
+  std::vector<std::unique_ptr<Device>> vector_d0;
+  vector_d0.emplace_back(std::move(d0));
+  TF_ASSERT_OK(remote_device_mgr.AddDevices(std::move(vector_d0)));
+
+  TensorHandle* th0 = TensorHandle::CreateUnshapedRemoteHandle(
+      0, 0, "", DT_RESOURCE, d0_ptr, ctx);
+  EXPECT_TRUE(remote_device_mgr.ContainsDevice(
+      th0->resource_remote_device_incarnation()));
+
+  std::vector<std::unique_ptr<Device>> vector_d1;
+  vector_d1.emplace_back(std::move(d1));
+  TF_ASSERT_OK(remote_device_mgr.AddDevices(std::move(vector_d1)));
+  EXPECT_TRUE(remote_device_mgr.ContainsDevice(
+      th0->resource_remote_device_incarnation()));
+
+  TensorHandle* th1 = TensorHandle::CreateUnshapedRemoteHandle(
+      0, 0, "", DT_RESOURCE, d1_ptr, ctx);
+  EXPECT_TRUE(remote_device_mgr.ContainsDevice(
+      th1->resource_remote_device_incarnation()));
+
+  std::vector<Device*> remove_d1{d1_ptr};
+  TF_ASSERT_OK(remote_device_mgr.RemoveDevices(std::move(remove_d1)));
+  EXPECT_FALSE(remote_device_mgr.ContainsDevice(
+      th1->resource_remote_device_incarnation()));
+  EXPECT_TRUE(remote_device_mgr.ContainsDevice(
+      th0->resource_remote_device_incarnation()));
+
+  th0->Unref();
+  th1->Unref();
+  ctx->Unref();
+}
+
 }  // namespace tensorflow
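
For context, the guard this patch introduces can be pictured with a minimal,
standalone sketch. It is not part of the patch: SimpleDeviceMgr and
ValidateResourceDevice below are hypothetical stand-ins for
DeviceMgr::ContainsDevice and ValidateTensorHandleRemoteDevice, using only the
standard library.

// Minimal sketch of incarnation-based resource device validation
// (illustrative only; names are hypothetical stand-ins).
#include <cstdint>
#include <unordered_set>

// Stand-in for DeviceMgr: tracks devices by their incarnation IDs, which are
// randomly generated per device instance and change when a worker restarts.
class SimpleDeviceMgr {
 public:
  void AddDevice(int64_t incarnation) { incarnations_.insert(incarnation); }
  void RemoveDevice(int64_t incarnation) { incarnations_.erase(incarnation); }
  bool ContainsDevice(int64_t incarnation) const {
    return incarnations_.count(incarnation) != 0;
  }

 private:
  std::unordered_set<int64_t> incarnations_;
};

// Stand-in for ValidateTensorHandleRemoteDevice: a resource handle records the
// incarnation of its remote device (0 means the resource is local). Before an
// op touches the resource, that incarnation must still be known to the device
// manager; otherwise the access is rejected instead of dereferencing a stale
// device.
bool ValidateResourceDevice(const SimpleDeviceMgr& mgr,
                            int64_t resource_incarnation) {
  if (resource_incarnation == 0) return true;  // Local resources are safe.
  return mgr.ContainsDevice(resource_incarnation);
}

After a cluster update replaces a worker, the old incarnation is removed from
the manager, so validation fails with an error status rather than a segfault.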