Preserve composite devices when cloning a ProcessFunctionLibraryRuntime.
PiperOrigin-RevId: 324919856 Change-Id: Ibfe2df7e511593730ebe54e57204b136b0fff5a9
This commit is contained in:
parent
d4a8e6515e
commit
79594069bb
@ -1667,6 +1667,10 @@ Status ProcessFunctionLibraryRuntime::Clone(
|
|||||||
device_mgr_, env, config_ ? &(*config_) : nullptr, graph_def_version,
|
device_mgr_, env, config_ ? &(*config_) : nullptr, graph_def_version,
|
||||||
out_lib_def->get(), optimizer_options, default_thread_pool_, parent_,
|
out_lib_def->get(), optimizer_options, default_thread_pool_, parent_,
|
||||||
custom_kernel_creator, session_metadata_, rendezvous_factory_);
|
custom_kernel_creator, session_metadata_, rendezvous_factory_);
|
||||||
|
{
|
||||||
|
tf_shared_lock l(mu_);
|
||||||
|
for (auto* d : composite_devices_) (*out_pflr)->AddCompositeDevice(d);
|
||||||
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -221,6 +221,7 @@ class ProcessFunctionLibraryRuntime {
|
|||||||
void AddCompositeDevice(CompositeDevice* d) TF_LOCKS_EXCLUDED(mu_) {
|
void AddCompositeDevice(CompositeDevice* d) TF_LOCKS_EXCLUDED(mu_) {
|
||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
device_set_->AddDevice(d);
|
device_set_->AddDevice(d);
|
||||||
|
// Also remember the pointer: Clone() iterates composite_devices_ and calls
// AddCompositeDevice on the cloned runtime for each entry.
composite_devices_.push_back(d);
|
||||||
}
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
@ -452,6 +453,9 @@ class ProcessFunctionLibraryRuntime {
|
|||||||
// fail if it spans the changed remote devices.
|
// fail if it spans the changed remote devices.
|
||||||
std::shared_ptr<DeviceSet> device_set_ TF_GUARDED_BY(mu_);
|
std::shared_ptr<DeviceSet> device_set_ TF_GUARDED_BY(mu_);
|
||||||
|
|
||||||
|
// Composite devices owned by an EagerContext.
|
||||||
|
std::vector<CompositeDevice*> composite_devices_ TF_GUARDED_BY(mu_);
|
||||||
|
|
||||||
// Holds all the function instantiations. Maps function_keys to handles.
|
// Holds all the function instantiations. Maps function_keys to handles.
|
||||||
std::unordered_map<string, FunctionLibraryRuntime::Handle> table_
|
std::unordered_map<string, FunctionLibraryRuntime::Handle> table_
|
||||||
TF_GUARDED_BY(mu_);
|
TF_GUARDED_BY(mu_);
|
||||||
|
@ -1188,6 +1188,28 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, SessionMetadataPresent) {
|
|||||||
EXPECT_EQ(session_metadata.version(), read_metadata.version());
|
EXPECT_EQ(session_metadata.version(), read_metadata.version());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Checks that a composite device added before cloning can be found, by name,
// in the device set of the ProcessFunctionLibraryRuntime produced by Clone().
TEST_F(ProcessFunctionLibraryRuntimeTest, CompositeDevicesAfterCloning) {
|
||||||
|
Init({AddVarAcrossDevices()});
|
||||||
|
|
||||||
|
Status s;
|
||||||
|
// Build a composite device from the names of device0_ and device1_; `s`
// receives the creation status and is asserted OK below.
std::unique_ptr<CompositeDevice> composite_device =
|
||||||
|
CompositeDevice::MakeDevice({device0_->name(), device1_->name()},
|
||||||
|
/*unique_device_id=*/0,
|
||||||
|
device_mgr_->HostCPU()->parsed_name(), &s);
|
||||||
|
TF_ASSERT_OK(s);
|
||||||
|
// NOTE(review): presumably a fixture helper that forwards to proc_flr_ —
// its definition is not visible in this hunk; confirm.
AddCompositeDevice(composite_device.get());
|
||||||
|
|
||||||
|
// Clone via the FLR looked up for cpu:0; Clone fills in the three outputs
// declared below.
auto* flr = proc_flr_->GetFLR("/job:a/replica:0/task:0/cpu:0");
|
||||||
|
ASSERT_NE(nullptr, flr);
|
||||||
|
std::unique_ptr<FunctionLibraryDefinition> cloned_lib_def;
|
||||||
|
std::unique_ptr<ProcessFunctionLibraryRuntime> cloned_proc_flr;
|
||||||
|
FunctionLibraryRuntime* cloned_flr;
|
||||||
|
TF_ASSERT_OK(flr->Clone(&cloned_lib_def, &cloned_proc_flr, &cloned_flr));
|
||||||
|
// The cloned runtime's device set must resolve the composite device's name
// to the very same object that was registered on the original.
EXPECT_EQ(
|
||||||
|
cloned_proc_flr->device_set()->FindDeviceByName(composite_device->name()),
|
||||||
|
composite_device.get());
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(ProcessFunctionLibraryRuntimeTest, SessionMetadataPresentAfterCloning) {
|
TEST_F(ProcessFunctionLibraryRuntimeTest, SessionMetadataPresentAfterCloning) {
|
||||||
const SessionMetadata session_metadata = GenerateSessionMetadata();
|
const SessionMetadata session_metadata = GenerateSessionMetadata();
|
||||||
Init({SessionMetadataReaderOpFn()}, &session_metadata);
|
Init({SessionMetadataReaderOpFn()}, &session_metadata);
|
||||||
|
Loading…
Reference in New Issue
Block a user