diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc index 515477cd16a..aee482d92da 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc @@ -1667,6 +1667,10 @@ Status ProcessFunctionLibraryRuntime::Clone( device_mgr_, env, config_ ? &(*config_) : nullptr, graph_def_version, out_lib_def->get(), optimizer_options, default_thread_pool_, parent_, custom_kernel_creator, session_metadata_, rendezvous_factory_); + { + tf_shared_lock l(mu_); + for (auto* d : composite_devices_) (*out_pflr)->AddCompositeDevice(d); + } return Status::OK(); } diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.h b/tensorflow/core/common_runtime/process_function_library_runtime.h index bc68c9c2807..0bd85c62df5 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.h +++ b/tensorflow/core/common_runtime/process_function_library_runtime.h @@ -221,6 +221,7 @@ class ProcessFunctionLibraryRuntime { void AddCompositeDevice(CompositeDevice* d) TF_LOCKS_EXCLUDED(mu_) { mutex_lock l(mu_); device_set_->AddDevice(d); + composite_devices_.push_back(d); } protected: @@ -452,6 +453,9 @@ class ProcessFunctionLibraryRuntime { // fail if it spans the changed remote devices. std::shared_ptr device_set_ TF_GUARDED_BY(mu_); + // Composite devices owned by a EagerContext. + std::vector composite_devices_ TF_GUARDED_BY(mu_); + // Holds all the function instantiations. Maps function_keys to handles. std::unordered_map table_ TF_GUARDED_BY(mu_); diff --git a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc index 19c33a53d20..be279c84d1a 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc @@ -1188,6 +1188,28 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, SessionMetadataPresent) { EXPECT_EQ(session_metadata.version(), read_metadata.version()); } +TEST_F(ProcessFunctionLibraryRuntimeTest, CompositeDevicesAfterCloning) { + Init({AddVarAcrossDevices()}); + + Status s; + std::unique_ptr composite_device = + CompositeDevice::MakeDevice({device0_->name(), device1_->name()}, + /*unique_device_id=*/0, + device_mgr_->HostCPU()->parsed_name(), &s); + TF_ASSERT_OK(s); + AddCompositeDevice(composite_device.get()); + + auto* flr = proc_flr_->GetFLR("/job:a/replica:0/task:0/cpu:0"); + ASSERT_NE(nullptr, flr); + std::unique_ptr cloned_lib_def; + std::unique_ptr cloned_proc_flr; + FunctionLibraryRuntime* cloned_flr; + TF_ASSERT_OK(flr->Clone(&cloned_lib_def, &cloned_proc_flr, &cloned_flr)); + EXPECT_EQ( + cloned_proc_flr->device_set()->FindDeviceByName(composite_device->name()), + composite_device.get()); +} + TEST_F(ProcessFunctionLibraryRuntimeTest, SessionMetadataPresentAfterCloning) { const SessionMetadata session_metadata = GenerateSessionMetadata(); Init({SessionMetadataReaderOpFn()}, &session_metadata);