diff --git a/tensorflow/c/eager/c_api_cluster_test.cc b/tensorflow/c/eager/c_api_cluster_test.cc
index 8f585d6f02c..252a0408758 100644
--- a/tensorflow/c/eager/c_api_cluster_test.cc
+++ b/tensorflow/c/eager/c_api_cluster_test.cc
@@ -50,6 +50,13 @@ tensorflow::ServerDef GetServerDef(int num_tasks) {
   return GetServerDef("localhost", num_tasks);
 }
 
+void ReplaceTaskInServerDef(tensorflow::ServerDef* server_def, int task_index) {
+  tensorflow::JobDef* job_def = server_def->mutable_cluster()->mutable_job(0);
+  int port = tensorflow::testing::PickUnusedPortOrDie();
+  job_def->mutable_tasks()->at(task_index) =
+      tensorflow::strings::StrCat("localhost:", port);
+}
+
 void CheckTFE_TensorHandleHasFloats(TFE_TensorHandle* handle,
                                     const std::vector<float>& expected_values) {
   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
@@ -101,6 +108,22 @@ void CheckRemoteMatMulExecutesOK(TFE_Context* ctx,
   TF_DeleteStatus(status);
 }
 
+// Read the value of variable `var` and save it into `out_value`.
+void ReadVariable(TFE_Context* ctx, TFE_TensorHandle* var,
+                  TFE_TensorHandle** out_value) {
+  TF_Status* status = TF_NewStatus();
+  TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status);
+  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
+  TFE_OpAddInput(op, var, status);
+  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  int num_retvals = 1;
+  TFE_Execute(op, out_value, &num_retvals, status);
+  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_DeleteOp(op);
+  TF_DeleteStatus(status);
+}
+
 void TestRemoteExecuteChangeServerDef(bool async) {
   tensorflow::ServerDef server_def = GetServerDef(2);
 
@@ -243,6 +266,102 @@ TEST(CAPI, RemoteExecuteUpdateServerDefAsync) {
   TestRemoteExecuteUpdateServerDef(true);
 }
 
+void TestRemoteExecuteUpdateServerDefResourceAccess(bool async) {
+  tensorflow::ServerDef server_def = GetServerDef(2);
+  // This server def has the task index set to 0.
+  string serialized = server_def.SerializeAsString();
+
+  server_def.set_task_index(1);
+  std::unique_ptr<tensorflow::GrpcServer> worker_server;
+  ASSERT_TRUE(tensorflow::GrpcServer::Create(
+                  server_def, tensorflow::Env::Default(), &worker_server)
+                  .ok());
+  ASSERT_TRUE(worker_server->Start().ok());
+
+  TF_Status* status = TF_NewStatus();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
+  TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
+  TFE_Context* ctx = TFE_NewContext(opts, status);
+  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_DeleteContextOptions(opts);
+
+  TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
+  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+  const char dev0_name[] = "/job:localhost/replica:0/task:0/device:CPU:0";
+  const char dev1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
+
+  TFE_TensorHandle* var_handle0 = TestVariable(ctx, 1.0, dev0_name);
+  EXPECT_NE(var_handle0, nullptr);
+  TFE_TensorHandle* var_handle1 = TestVariable(ctx, 2.0, dev1_name);
+  EXPECT_NE(var_handle1, nullptr);
+
+  TFE_TensorHandle* value_handle = nullptr;
+  ReadVariable(ctx, var_handle1, &value_handle);
+  CheckTFE_TensorHandleHasFloats(value_handle, {2});
+  TFE_DeleteTensorHandle(value_handle);
+
+  // Start a new worker to replace task:1
+  ReplaceTaskInServerDef(&server_def, 1);
+  server_def.set_task_index(1);
+  // TODO(b/136478427): Figure out how to correctly shut the server down.
+  worker_server.release();
+  ASSERT_TRUE(tensorflow::GrpcServer::Create(
+                  server_def, tensorflow::Env::Default(), &worker_server)
+                  .ok());
+  ASSERT_TRUE(worker_server->Start().ok());
+
+  // Update server def to replace the remote device with the device info on the
+  // new worker (different incarnation ID).
+  server_def.set_task_index(0);
+  string serialized_update = server_def.SerializeAsString();
+  TFE_ContextUpdateServerDef(ctx, 0, serialized_update.data(),
+                             serialized_update.size(), status);
+  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+  // The device of var_handle0 is a local device, which stays the same before
+  // and after the cluster update. Destroying a resource with a valid device
+  // should succeed.
+  TFE_Op* op = TFE_NewOp(ctx, "DestroyResourceOp", status);
+  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_OpAddInput(op, var_handle0, status);
+  TFE_OpSetDevice(op, dev0_name, status);
+  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  int num_retvals = 0;
+  TFE_Execute(op, nullptr, &num_retvals, status);
+  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_DeleteOp(op);
+
+  // The device of var_handle1 is a remote device, which was replaced during
+  // the cluster update. Destroying a resource with an invalid device should
+  // fail gracefully (i.e., with an error status) instead of crashing with a
+  // segfault.
+  op = TFE_NewOp(ctx, "DestroyResourceOp", status);
+  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_OpAddInput(op, var_handle1, status);
+  TFE_OpSetDevice(op, dev1_name, status);
+  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  num_retvals = 0;
+  TFE_Execute(op, nullptr, &num_retvals, status);
+  EXPECT_NE(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_DeleteOp(op);
+
+  TFE_DeleteTensorHandle(var_handle0);
+  TFE_DeleteTensorHandle(var_handle1);
+
+  TFE_DeleteContext(ctx);
+  TF_DeleteStatus(status);
+
+  // TODO(b/136478427): Figure out how to correctly shut the server down.
+  worker_server.release();
+}
+
+TEST(CAPI, TestRemoteExecuteUpdateServerDefResourceAccess) {
+  TestRemoteExecuteUpdateServerDefResourceAccess(false);
+}
+
+TEST(CAPI, TestRemoteExecuteUpdateServerDefResourceAccessAsync) {
+  TestRemoteExecuteUpdateServerDefResourceAccess(true);
+}
+
 void TestRemoteExecuteUpdateServerDefWithFailures(bool async) {
   // Fail fast on GetStatus requests so we can get errors instead of timeout
   // when updating cluster with non-existent worker
@@ -282,6 +401,7 @@ void TestRemoteExecuteUpdateServerDefWithFailures(bool async) {
   int port = tensorflow::testing::PickUnusedPortOrDie();
   job_def->mutable_tasks()->insert(
       {2, tensorflow::strings::StrCat("localhost:", port)});
+  server_def.set_task_index(0);
   string serialized_update = server_def.SerializeAsString();
   TFE_ContextUpdateServerDef(ctx, 0, serialized_update.data(),
                              serialized_update.size(), status);
diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc
index 548bf1337bb..724176505ba 100644
--- a/tensorflow/c/eager/c_api_test.cc
+++ b/tensorflow/c/eager/c_api_test.cc
@@ -1203,6 +1203,8 @@ void BM_ReadVariable(int iters) {
     CHECK_EQ(0, TFE_TensorHandleNumDims(h, status));
     CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
     h = nullptr;
+    TFE_OpAddInput(op, var_handle, status);
+    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
   }
   tensorflow::testing::StopTiming();
   TFE_DeleteOp(op);
diff --git a/tensorflow/c/eager/c_api_test_util.cc b/tensorflow/c/eager/c_api_test_util.cc
index bbdc4c8f410..29b624b8537 100644
--- a/tensorflow/c/eager/c_api_test_util.cc
+++ b/tensorflow/c/eager/c_api_test_util.cc
@@ -150,6 +150,7 @@ TFE_TensorHandle* TestVariable(TFE_Context* ctx, float value,
   TFE_TensorHandle* var_handle = nullptr;
   int num_retvals = 1;
   TFE_Execute(op, &var_handle, &num_retvals, status);
+  if (TF_GetCode(status) != TF_OK) return nullptr;
   TFE_DeleteOp(op);
   if (TF_GetCode(status) != TF_OK) return nullptr;
   CHECK_EQ(1, num_retvals);
diff --git a/tensorflow/core/common_runtime/device_mgr.cc b/tensorflow/core/common_runtime/device_mgr.cc
index c7583c374f2..0b693085da3 100644
--- a/tensorflow/core/common_runtime/device_mgr.cc
+++ b/tensorflow/core/common_runtime/device_mgr.cc
@@ -45,6 +45,7 @@ StaticDeviceMgr::StaticDeviceMgr(std::vector<std::unique_ptr<Device>> devices)
     }
     const auto& t = d->device_type();
     device_type_counts_[t]++;
+    device_incarnation_set_.insert(d->attributes().incarnation());
     if (cpu_device_ == nullptr && t == "CPU" && d->parsed_name().id == 0) {
       cpu_device_ = d.get();
     }
@@ -123,6 +124,10 @@ Status StaticDeviceMgr::LookupDevice(StringPiece name, Device** device) const {
   return Status::OK();
 }
 
+bool StaticDeviceMgr::ContainsDevice(int64 device_incarnation) const {
+  return device_incarnation_set_.contains(device_incarnation);
+}
+
 void StaticDeviceMgr::ClearContainers(
     gtl::ArraySlice<string> containers) const {
   Status s;
diff --git a/tensorflow/core/common_runtime/device_mgr.h b/tensorflow/core/common_runtime/device_mgr.h
index 56248b39078..83a0d0cc29c 100644
--- a/tensorflow/core/common_runtime/device_mgr.h
+++ b/tensorflow/core/common_runtime/device_mgr.h
@@ -22,6 +22,7 @@ limitations under the License.
 #include <unordered_map>
 #include <vector>
 
+#include "absl/container/flat_hash_set.h"
 #include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/lib/core/arena.h"
 #include "tensorflow/core/lib/core/status.h"
@@ -56,6 +57,11 @@ class DeviceMgr {
   // Accepts either a full device name, or just the replica-local suffix.
   virtual Status LookupDevice(StringPiece name, Device** device) const = 0;
 
+  // Checks whether the device manager contains a device with the given
+  // incarnation ID. We look up by incarnation ID because incarnation IDs are
+  // randomly generated and not intentionally reused (unlike device pointers).
+  virtual bool ContainsDevice(int64 device_incarnation) const = 0;
+
   // Clears given containers of all devices if 'container' is
   // non-empty. Otherwise, clears default containers of all devices.
   virtual void ClearContainers(gtl::ArraySlice<string> containers) const = 0;
@@ -86,6 +92,7 @@ class StaticDeviceMgr : public DeviceMgr {
   string DebugString() const override;
   string DeviceMappingString() const override;
   Status LookupDevice(StringPiece name, Device** device) const override;
+  bool ContainsDevice(int64 device_incarnation) const override;
   void ClearContainers(gtl::ArraySlice<string> containers) const override;
   int NumDeviceType(const string& type) const override;
   Device* HostCPU() const override;
@@ -95,6 +102,7 @@ class StaticDeviceMgr : public DeviceMgr {
 
   StringPiece CopyToBackingStore(StringPiece s);
 
+  absl::flat_hash_set<int64> device_incarnation_set_;
   std::unordered_map<StringPiece, Device*, StringPieceHasher> device_map_;
   core::Arena name_backing_store_;  // Storage for keys in device_map_
   std::unordered_map<string, int> device_type_counts_;
@@ -117,6 +125,7 @@ class DynamicDeviceMgr : public DeviceMgr {
   string DebugString() const override;
   string DeviceMappingString() const override;
   Status LookupDevice(StringPiece name, Device** device) const override;
+  bool ContainsDevice(int64 device_incarnation) const override;
   void ClearContainers(gtl::ArraySlice<string> containers) const override;
   int NumDeviceType(const string& type) const override;
   Device* HostCPU() const override;
@@ -140,6 +149,7 @@ class DynamicDeviceMgr : public DeviceMgr {
   std::unordered_map<Device*, std::unique_ptr<Device>> dynamic_devices_
       TF_GUARDED_BY(devices_mu_);
 
+  absl::flat_hash_set<int64> device_incarnation_set_ TF_GUARDED_BY(devices_mu_);
   std::unordered_map<string, Device*> device_map_ TF_GUARDED_BY(devices_mu_);
 
   std::unordered_map<string, int> device_type_counts_
diff --git a/tensorflow/core/common_runtime/dynamic_device_mgr.cc b/tensorflow/core/common_runtime/dynamic_device_mgr.cc
index f35fa7e416a..f47de47c5b9 100644
--- a/tensorflow/core/common_runtime/dynamic_device_mgr.cc
+++ b/tensorflow/core/common_runtime/dynamic_device_mgr.cc
@@ -92,6 +92,11 @@ Status DynamicDeviceMgr::LookupDevice(StringPiece name, Device** device) const {
   return Status::OK();
 }
 
+bool DynamicDeviceMgr::ContainsDevice(int64 device_incarnation) const {
+  tf_shared_lock l(devices_mu_);
+  return device_incarnation_set_.contains(device_incarnation);
+}
+
 void DynamicDeviceMgr::ClearContainers(
     gtl::ArraySlice<string> containers) const {
   Status s;
@@ -138,6 +143,7 @@ Status DynamicDeviceMgr::AddDevices(
       device_map_[name] = d.get();
     }
     device_type_counts_[d->device_type()]++;
+    device_incarnation_set_.insert(d->attributes().incarnation());
     dynamic_devices_.emplace(d.get(), std::move(d));
   }
   return Status::OK();
@@ -171,6 +177,7 @@ Status DynamicDeviceMgr::RemoveDevices(std::vector<Device*> devices) {
       device_map_.erase(name);
     }
     device_type_counts_[d->device_type()]--;
+    device_incarnation_set_.erase(d->attributes().incarnation());
     dynamic_devices_.erase(it);
   }
   return Status::OK();
diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc
index 35dd9990054..3036e6d7989 100644
--- a/tensorflow/core/common_runtime/eager/execute.cc
+++ b/tensorflow/core/common_runtime/eager/execute.cc
@@ -874,6 +874,19 @@ bool IsPinnableOp(const string& op_type) {
          !absl::StartsWith(op_type, "XRT");
 }
 
+// Checks whether the remote device with the given incarnation is still known
+// to the remote device manager of the current eager context.
+Status ValidateTensorHandleRemoteDevice(EagerContext* ctx,
+                                        int64 device_incarnation) {
+  if (ctx->remote_device_mgr()->ContainsDevice(device_incarnation)) {
+    return Status::OK();
+  }
+  return errors::InvalidArgument(
+      "Resource input tensor contains an invalid device. This might happen "
+      "when the client has connected to a different cluster, or some remote "
+      "workers have been restarted.");
+}
+
 // The Op device may be updated if:
 // - A resource touching input is specified: all resource-touching ops run in
 //   the device the resource is, regardless of anything else that has been
@@ -935,6 +948,10 @@ Status MaybeUpdateOpDevice(EagerOperation* op) {
   for (int i = 0; i < op->Inputs().size(); ++i) {
     TensorHandle* tensor_handle = op->Inputs()[i];
     if (tensor_handle->dtype == DT_RESOURCE) {
+      if (tensor_handle->resource_remote_device_incarnation() != 0) {
+        TF_RETURN_IF_ERROR(ValidateTensorHandleRemoteDevice(
+            &ctx, tensor_handle->resource_remote_device_incarnation()));
+      }
       Device* resource_device = tensor_handle->resource_device();
       DVLOG(2) << "for op " << op->Name() << " input " << i << " "
                << DataTypeString(tensor_handle->dtype)
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.cc b/tensorflow/core/common_runtime/eager/tensor_handle.cc
index eef46b691ce..dfe3e4a1426 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle.cc
+++ b/tensorflow/core/common_runtime/eager/tensor_handle.cc
@@ -49,6 +49,13 @@ limitations under the License.
 
 namespace tensorflow {
 
+namespace {
+int64 GetRemoteDeviceIncarnation(Device* device) {
+  if (device == nullptr || device->IsLocal()) return 0;
+  return device->attributes().incarnation();
+}
+}  // namespace
+
 TensorHandle::PackedTensorHandleData::PackedTensorHandleData(
     std::vector<TensorHandle*>&& handles, const TensorShape& shape)
     : handles_(std::move(handles)), shape_(shape) {
@@ -244,6 +251,8 @@ TensorHandle::TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device,
       device_((!ctx || d == ctx->HostCPU()) ? nullptr : d),
       op_device_(op_device),
       resource_device_(resource_device),
+      resource_remote_device_incarnation_(
+          GetRemoteDeviceIncarnation(resource_device_)),
       ctx_(ctx),
       data_(absl::in_place_type<LocalTensorHandleData>, std::move(t)) {
   DVLOG(3) << "Creating Local TensorHandle: " << this
@@ -258,6 +267,8 @@ TensorHandle::TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device,
       op_device_(op_device),
       resource_device_(
           GetResourceDevice(t.flat<class ResourceHandle>()(0), ctx)),
+      resource_remote_device_incarnation_(
+          GetRemoteDeviceIncarnation(resource_device_)),
       ctx_(ctx),
       resource_handle_info_(
           {t.flat<class ResourceHandle>()(0).dtypes_and_shapes(),
@@ -274,6 +285,7 @@ TensorHandle::TensorHandle(tensorflow::Tensor&& t, CustomDevice* d,
       device_(d),
       op_device_(nullptr),
       resource_device_(nullptr),
+      resource_remote_device_incarnation_(0),
       ctx_(ctx),
       data_(absl::in_place_type<LocalTensorHandleData>, std::move(t)) {
   // TODO(allenl): Figure out a better op_device story for custom devices,
@@ -297,6 +309,8 @@ TensorHandle::TensorHandle(Device* d, Device* op_device,
       device_((d == ctx->HostCPU()) ? nullptr : d),
       op_device_(op_device),
       resource_device_(resource_device),
+      resource_remote_device_incarnation_(
+          GetRemoteDeviceIncarnation(resource_device_)),
       ctx_(ctx),
       data_(absl::in_place_type<LocalTensorHandleData>) {
   DVLOG(3) << "Creating empty Local TensorHandle: " << this
@@ -354,6 +368,8 @@ TensorHandle::TensorHandle(std::vector<TensorHandle*>&& handles, Device* device,
       device_(device),
       op_device_(device),
       resource_device_(dtype == DT_RESOURCE ? device : nullptr),
+      resource_remote_device_incarnation_(
+          GetRemoteDeviceIncarnation(resource_device_)),
       ctx_(ctx),
       data_(absl::in_place_type<PackedTensorHandleData>, std::move(handles),
             shape) {
@@ -376,6 +392,8 @@ TensorHandle::TensorHandle(int64 op_id, int32 output_num,
       device_(d),
       op_device_(d),
       resource_device_(dtype == DT_RESOURCE ? d : nullptr),
+      resource_remote_device_incarnation_(
+          GetRemoteDeviceIncarnation(resource_device_)),
       ctx_(ctx),
       data_(absl::in_place_type<RemoteTensorHandleData>, op_id, output_num,
             remote_task, ctx) {
@@ -398,6 +416,8 @@ TensorHandle::TensorHandle(int64 op_id, int32 output_num,
       device_(d),
       op_device_(d),
       resource_device_(dtype == DT_RESOURCE ? d : nullptr),
+      resource_remote_device_incarnation_(
+          GetRemoteDeviceIncarnation(resource_device_)),
       ctx_(ctx),
       data_(absl::in_place_type<RemoteTensorHandleData>, op_id, output_num,
             ctx->GetContextViewId()) {
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.h b/tensorflow/core/common_runtime/eager/tensor_handle.h
index 45e7a3815a8..25d7fea3200 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle.h
+++ b/tensorflow/core/common_runtime/eager/tensor_handle.h
@@ -133,6 +133,9 @@ class TensorHandle : public AbstractTensorHandleInterface,
   VariantDevice device() const { return device_; }
   Device* op_device() const { return op_device_; }
   Device* resource_device() const { return resource_device_; }
+  int64 resource_remote_device_incarnation() const {
+    return resource_remote_device_incarnation_;
+  }
 
   VariantDevice DeviceOrHostCPU(const EagerContext& ctx) const;
 
@@ -265,6 +268,9 @@ class TensorHandle : public AbstractTensorHandleInterface,
   // If the tensor dtype is DT_RESOURCE, resource_device_ holds the device
   // backing the resource. Else resource_device_ is nullptr.
   tensorflow::Device* const resource_device_;
+  // Incarnation ID of the resource device if it is remote, or 0 if the
+  // resource device is local.
+  const int64 resource_remote_device_incarnation_;
 
   mutable mutex mu_;
 
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle_test.cc b/tensorflow/core/common_runtime/eager/tensor_handle_test.cc
index 2bcde7dce5b..779158375de 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle_test.cc
+++ b/tensorflow/core/common_runtime/eager/tensor_handle_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/random.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -66,17 +67,28 @@ TEST(TensorHandle_ShapeTest, AsyncShape) { ctx->Unref(); } -static Device* CreateDevice(const char* type, const char* name) { +static Device* CreateDevice(const char* type, const char* name, + bool is_local = true) { class FakeDevice : public Device { public: - explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {} + explicit FakeDevice(const DeviceAttributes& attr, bool is_local) + : Device(nullptr, attr), is_local_(is_local) {} Status Sync() override { return Status::OK(); } Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; } + bool IsLocal() const override { return is_local_; } + + private: + const bool is_local_; }; DeviceAttributes attr; attr.set_name(name); attr.set_device_type(type); - return new FakeDevice(attr); + int64 incarnation = random::New64(); + while (incarnation == 0) { + incarnation = random::New64(); + } + attr.set_incarnation(incarnation); + return new FakeDevice(attr, is_local); } } // namespace @@ -204,4 +216,87 @@ TEST_F(PackedTensorHandleTest, PackedHandle) { packed_handle->Unref(); } +TEST(TensorHandle_ResourceDeviceTest, OnLocalDevice) { + std::unique_ptr d0( + CreateDevice("CPU", "/job:localhost/replica:0/task:0/device:CPU:0")); + StaticDeviceMgr local_device_mgr(std::move(d0)); + auto ctx = new EagerContext( + SessionOptions(), + tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, + tensorflow::ContextMirroringPolicy::MIRRORING_NONE, false, false, + &local_device_mgr, false, nullptr, nullptr, nullptr); + + tensorflow::DataType dtype = DT_RESOURCE; + TensorShape shape = {2}; + Tensor t(dtype, shape); + + Device* d = local_device_mgr.ListDevices()[0]; + TensorHandle* th = + TensorHandle::CreateLocalHandle(std::move(t), d, d, d, ctx); + // Remote device incarnation for local resource should be 0 (invalid) + EXPECT_EQ(0, th->resource_remote_device_incarnation()); + // Local device manager must contain the resource device. 
+  EXPECT_TRUE(local_device_mgr.ContainsDevice(
+      th->resource_device()->attributes().incarnation()));
+
+  std::unique_ptr<Device> d1(
+      CreateDevice("CPU", "/job:localhost/replica:0/task:0/device:CPU:0"));
+  StaticDeviceMgr new_device_mgr(std::move(d1));
+  EXPECT_FALSE(new_device_mgr.ContainsDevice(
+      th->resource_device()->attributes().incarnation()));
+
+  th->Unref();
+  ctx->Unref();
+}
+
+TEST(TensorHandle_ResourceDeviceTest, OnRemoteDevice) {
+  std::unique_ptr<Device> d_local(
+      CreateDevice("CPU", "/job:localhost/replica:0/task:0/device:CPU:0"));
+  StaticDeviceMgr local_device_mgr(std::move(d_local));
+  auto ctx = new EagerContext(
+      SessionOptions(),
+      tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
+      tensorflow::ContextMirroringPolicy::MIRRORING_NONE, false, false,
+      &local_device_mgr, false, nullptr, nullptr, nullptr);
+
+  std::unique_ptr<Device> d0(
+      CreateDevice("CPU", "/job:worker/task:0/device:CPU:0", false));
+  Device* d0_ptr = d0.get();
+  std::unique_ptr<Device> d1(
+      CreateDevice("CPU", "/job:worker/task:1/device:CPU:0", false));
+  Device* d1_ptr = d1.get();
+
+  DynamicDeviceMgr remote_device_mgr;
+  std::vector<std::unique_ptr<Device>> vector_d0;
+  vector_d0.emplace_back(std::move(d0));
+  TF_ASSERT_OK(remote_device_mgr.AddDevices(std::move(vector_d0)));
+
+  TensorHandle* th0 = TensorHandle::CreateUnshapedRemoteHandle(
+      0, 0, "", DT_RESOURCE, d0_ptr, ctx);
+  EXPECT_TRUE(remote_device_mgr.ContainsDevice(
+      th0->resource_remote_device_incarnation()));
+
+  std::vector<std::unique_ptr<Device>> vector_d1;
+  vector_d1.emplace_back(std::move(d1));
+  TF_ASSERT_OK(remote_device_mgr.AddDevices(std::move(vector_d1)));
+  EXPECT_TRUE(remote_device_mgr.ContainsDevice(
+      th0->resource_remote_device_incarnation()));
+
+  TensorHandle* th1 = TensorHandle::CreateUnshapedRemoteHandle(
+      0, 0, "", DT_RESOURCE, d1_ptr, ctx);
+  EXPECT_TRUE(remote_device_mgr.ContainsDevice(
+      th1->resource_remote_device_incarnation()));
+
+  std::vector<Device*> remove_d1{d1_ptr};
+  TF_ASSERT_OK(remote_device_mgr.RemoveDevices(std::move(remove_d1)));
+  EXPECT_FALSE(remote_device_mgr.ContainsDevice(
+      th1->resource_remote_device_incarnation()));
+  EXPECT_TRUE(remote_device_mgr.ContainsDevice(
+      th0->resource_remote_device_incarnation()));
+
+  th0->Unref();
+  th1->Unref();
+  ctx->Unref();
+}
+
 }  // namespace tensorflow
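For reviewers who want the mechanism in isolation: every device carries a randomly generated incarnation ID, the device manager keeps a set of the incarnations it currently owns, and a resource handle whose recorded incarnation is no longer in that set is rejected with an error instead of being dereferenced. The following standalone C++ sketch restates that idea; it is not part of the patch, and SimpleDeviceMgr, ValidateResourceHandle, and the sample incarnation values are simplified stand-ins, not TensorFlow APIs.

// Illustrative sketch only -- simplified stand-in types, not TensorFlow code.
#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_set>

struct Status {
  bool ok;
  std::string message;
};

class SimpleDeviceMgr {
 public:
  void AddDevice(int64_t incarnation) { live_incarnations_.insert(incarnation); }
  void RemoveDevice(int64_t incarnation) { live_incarnations_.erase(incarnation); }
  bool ContainsDevice(int64_t incarnation) const {
    return live_incarnations_.count(incarnation) != 0;
  }

 private:
  std::unordered_set<int64_t> live_incarnations_;
};

// Mirrors the role of ValidateTensorHandleRemoteDevice: a handle whose device
// incarnation is unknown to the manager is treated as stale.
Status ValidateResourceHandle(const SimpleDeviceMgr& mgr, int64_t incarnation) {
  if (incarnation == 0 || mgr.ContainsDevice(incarnation)) {
    return {true, ""};  // Local handle (0) or still-live remote device.
  }
  return {false, "Resource input tensor contains an invalid device."};
}

int main() {
  SimpleDeviceMgr mgr;
  const int64_t old_worker = 12345;  // hypothetical incarnation ID
  mgr.AddDevice(old_worker);
  std::cout << ValidateResourceHandle(mgr, old_worker).ok << "\n";  // prints 1
  mgr.RemoveDevice(old_worker);  // worker restarted / cluster updated
  mgr.AddDevice(67890);          // new incarnation for the same task
  std::cout << ValidateResourceHandle(mgr, old_worker).ok << "\n";  // prints 0
  return 0;
}

Incarnation IDs suit this check better than device names or pointers: a restarted worker reuses the same device name but receives a fresh random incarnation, so stale handles are detected even when the name still matches.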