[FLR] Switch to using LocalHandle in FLRImpl::GetOrCreateItem()
.
This enables us to avoid two lock-protected lookups in the parent ProcFLR's handle map (the first for `parent_->IsInstantiatedOnDevice()`, and the second in `GetOrCreateItem()`) every time we invoke a function. PiperOrigin-RevId: 220209369
This commit is contained in:
parent
a7faefeb5c
commit
be4dc89e8c
@ -382,8 +382,8 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
|
||||
Status FunctionDefToBody(const FunctionDef& fdef, AttrSlice attrs,
|
||||
const FunctionLibraryDefinition* lib_def,
|
||||
FunctionBody** fbody);
|
||||
Status CreateItem(Handle handle, Item** item);
|
||||
Status GetOrCreateItem(Handle handle, Item** item);
|
||||
Status CreateItem(Item** item);
|
||||
Status GetOrCreateItem(LocalHandle local_handle, Item** item);
|
||||
Status InstantiateSymbolicGradient(const NameAttrList& func,
|
||||
const FunctionLibraryDefinition* lib_def,
|
||||
FunctionBody** g_body);
|
||||
@ -691,13 +691,14 @@ Status FunctionLibraryRuntimeImpl::Instantiate(
|
||||
TF_RETURN_IF_ERROR(FunctionDefToBody(*fdef, attrs, lib_def, &fbody));
|
||||
}
|
||||
|
||||
LocalHandle local_handle;
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
*handle = parent_->GetHandle(key);
|
||||
if (*handle != kInvalidHandle) {
|
||||
delete fbody;
|
||||
++items_[parent_->GetHandleOnDevice(device_name_, *handle)]
|
||||
->instantiation_counter;
|
||||
local_handle = parent_->GetHandleOnDevice(device_name_, *handle);
|
||||
++items_[local_handle]->instantiation_counter;
|
||||
} else {
|
||||
*handle = parent_->AddHandle(key, device_name_, next_handle_);
|
||||
Item* item = new Item;
|
||||
@ -709,26 +710,24 @@ Status FunctionLibraryRuntimeImpl::Instantiate(
|
||||
item->overlay_flr =
|
||||
new FunctionLibraryRuntimeOverlay(this, options.overlay_lib);
|
||||
}
|
||||
items_.emplace(next_handle_, std::unique_ptr<Item>(item));
|
||||
next_handle_++;
|
||||
local_handle = next_handle_++;
|
||||
items_.emplace(local_handle, std::unique_ptr<Item>(item));
|
||||
}
|
||||
}
|
||||
|
||||
if (options.create_kernels_eagerly) {
|
||||
Item* item;
|
||||
TF_RETURN_IF_ERROR(GetOrCreateItem(*handle, &item));
|
||||
TF_RETURN_IF_ERROR(GetOrCreateItem(local_handle, &item));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status FunctionLibraryRuntimeImpl::ReleaseHandle(Handle handle) {
|
||||
if (!parent_->IsInstantiatedOnDevice(device_name_, handle)) {
|
||||
LocalHandle h = parent_->GetHandleOnDevice(device_name_, handle);
|
||||
if (h == kInvalidLocalHandle) {
|
||||
return parent_->ReleaseHandle(handle);
|
||||
}
|
||||
|
||||
LocalHandle h = parent_->GetHandleOnDevice(device_name_, handle);
|
||||
CHECK_NE(h, kInvalidLocalHandle);
|
||||
mutex_lock l(mu_);
|
||||
CHECK_EQ(1, items_.count(h));
|
||||
std::unique_ptr<Item>& item = items_[h];
|
||||
@ -789,7 +788,7 @@ void PruneFunctionBody(Graph* g) {
|
||||
}
|
||||
} // namespace
|
||||
|
||||
Status FunctionLibraryRuntimeImpl::CreateItem(Handle handle, Item** item) {
|
||||
Status FunctionLibraryRuntimeImpl::CreateItem(Item** item) {
|
||||
const FunctionBody* fbody;
|
||||
const FunctionLibraryDefinition* lib_def;
|
||||
string executor_type;
|
||||
@ -843,13 +842,13 @@ Status FunctionLibraryRuntimeImpl::CreateItem(Handle handle, Item** item) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status FunctionLibraryRuntimeImpl::GetOrCreateItem(Handle handle, Item** item) {
|
||||
LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle);
|
||||
Status FunctionLibraryRuntimeImpl::GetOrCreateItem(LocalHandle local_handle,
|
||||
Item** item) {
|
||||
{
|
||||
tf_shared_lock l(mu_);
|
||||
auto iter = items_.find(local_handle);
|
||||
if (iter == items_.end()) {
|
||||
return errors::NotFound("Function handle ", handle,
|
||||
return errors::Internal("Local function handle ", local_handle,
|
||||
" is not valid. Likely an internal error.");
|
||||
}
|
||||
*item = iter->second.get();
|
||||
@ -859,7 +858,7 @@ Status FunctionLibraryRuntimeImpl::GetOrCreateItem(Handle handle, Item** item) {
|
||||
}
|
||||
// NOTE: We need to call CreateItem out of mu_ because creating an
|
||||
// executor needs to call CreateKernel.
|
||||
return CreateItem(handle, item);
|
||||
return CreateItem(item);
|
||||
}
|
||||
|
||||
void FunctionLibraryRuntimeImpl::ExecutorArgsFromOptions(
|
||||
@ -994,7 +993,8 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
|
||||
};
|
||||
}
|
||||
|
||||
if (!parent_->IsInstantiatedOnDevice(device_name_, handle)) {
|
||||
LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle);
|
||||
if (local_handle == kInvalidLocalHandle) {
|
||||
parent_->Run(run_opts, handle, args, rets, done);
|
||||
return;
|
||||
}
|
||||
@ -1005,7 +1005,7 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
|
||||
DCHECK(run_opts.runner != nullptr);
|
||||
|
||||
Item* item = nullptr;
|
||||
Status s = GetOrCreateItem(handle, &item);
|
||||
Status s = GetOrCreateItem(local_handle, &item);
|
||||
if (!s.ok()) {
|
||||
done(s);
|
||||
return;
|
||||
@ -1052,8 +1052,8 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
|
||||
done(errors::Cancelled(""));
|
||||
return;
|
||||
}
|
||||
if (!parent_->IsInstantiatedOnDevice(device_name_, handle) ||
|
||||
opts.remote_execution) {
|
||||
LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle);
|
||||
if (local_handle == kInvalidLocalHandle || opts.remote_execution) {
|
||||
done(errors::Unimplemented("Remote calling with CallFrameInterface"));
|
||||
return;
|
||||
}
|
||||
@ -1074,7 +1074,7 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
|
||||
}
|
||||
|
||||
Item* item = nullptr;
|
||||
Status s = GetOrCreateItem(handle, &item);
|
||||
Status s = GetOrCreateItem(local_handle, &item);
|
||||
if (!s.ok()) {
|
||||
done(s);
|
||||
return;
|
||||
@ -1097,7 +1097,8 @@ bool FunctionLibraryRuntimeImpl::IsStateful(const string& func) {
|
||||
|
||||
string FunctionLibraryRuntimeImpl::DebugString(Handle handle) {
|
||||
Item* item = nullptr;
|
||||
Status s = GetOrCreateItem(handle, &item);
|
||||
LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle);
|
||||
Status s = GetOrCreateItem(local_handle, &item);
|
||||
if (s.ok()) {
|
||||
return tensorflow::DebugString(item->graph);
|
||||
} else {
|
||||
|
Loading…
Reference in New Issue
Block a user