Preserve composite devices when cloning a ProcessFunctionLibraryRuntime.

PiperOrigin-RevId: 324919856
Change-Id: Ibfe2df7e511593730ebe54e57204b136b0fff5a9
This commit is contained in:
Yujing Zhang 2020-08-04 16:57:42 -07:00 committed by TensorFlower Gardener
parent d4a8e6515e
commit 79594069bb
3 changed files with 30 additions and 0 deletions

View File

@ -1667,6 +1667,10 @@ Status ProcessFunctionLibraryRuntime::Clone(
device_mgr_, env, config_ ? &(*config_) : nullptr, graph_def_version,
out_lib_def->get(), optimizer_options, default_thread_pool_, parent_,
custom_kernel_creator, session_metadata_, rendezvous_factory_);
{
tf_shared_lock l(mu_);
for (auto* d : composite_devices_) (*out_pflr)->AddCompositeDevice(d);
}
return Status::OK();
}

View File

@ -221,6 +221,7 @@ class ProcessFunctionLibraryRuntime {
void AddCompositeDevice(CompositeDevice* d) TF_LOCKS_EXCLUDED(mu_) {
mutex_lock l(mu_);
device_set_->AddDevice(d);
composite_devices_.push_back(d);
}
protected:
@ -452,6 +453,9 @@ class ProcessFunctionLibraryRuntime {
// fail if it spans the changed remote devices.
std::shared_ptr<DeviceSet> device_set_ TF_GUARDED_BY(mu_);
// Composite devices owned by a EagerContext.
std::vector<CompositeDevice*> composite_devices_ TF_GUARDED_BY(mu_);
// Holds all the function instantiations. Maps function_keys to handles.
std::unordered_map<string, FunctionLibraryRuntime::Handle> table_
TF_GUARDED_BY(mu_);

View File

@ -1188,6 +1188,28 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, SessionMetadataPresent) {
EXPECT_EQ(session_metadata.version(), read_metadata.version());
}
TEST_F(ProcessFunctionLibraryRuntimeTest, CompositeDevicesAfterCloning) {
Init({AddVarAcrossDevices()});
Status s;
std::unique_ptr<CompositeDevice> composite_device =
CompositeDevice::MakeDevice({device0_->name(), device1_->name()},
/*unique_device_id=*/0,
device_mgr_->HostCPU()->parsed_name(), &s);
TF_ASSERT_OK(s);
AddCompositeDevice(composite_device.get());
auto* flr = proc_flr_->GetFLR("/job:a/replica:0/task:0/cpu:0");
ASSERT_NE(nullptr, flr);
std::unique_ptr<FunctionLibraryDefinition> cloned_lib_def;
std::unique_ptr<ProcessFunctionLibraryRuntime> cloned_proc_flr;
FunctionLibraryRuntime* cloned_flr;
TF_ASSERT_OK(flr->Clone(&cloned_lib_def, &cloned_proc_flr, &cloned_flr));
EXPECT_EQ(
cloned_proc_flr->device_set()->FindDeviceByName(composite_device->name()),
composite_device.get());
}
TEST_F(ProcessFunctionLibraryRuntimeTest, SessionMetadataPresentAfterCloning) {
const SessionMetadata session_metadata = GenerateSessionMetadata();
Init({SessionMetadataReaderOpFn()}, &session_metadata);