[FLR] Switch to using LocalHandle in FLRImpl::GetOrCreateItem().

This enables us to avoid two lock-protected lookups in the parent ProcFLR's handle map (the first for `parent_->IsInstantiatedOnDevice()`, and the second in `GetOrCreateItem()`) every time we invoke a function.

PiperOrigin-RevId: 220209369
This commit is contained in:
Derek Murray 2018-11-05 18:25:19 -08:00 committed by TensorFlower Gardener
parent a7faefeb5c
commit be4dc89e8c

View File

@ -382,8 +382,8 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
Status FunctionDefToBody(const FunctionDef& fdef, AttrSlice attrs, Status FunctionDefToBody(const FunctionDef& fdef, AttrSlice attrs,
const FunctionLibraryDefinition* lib_def, const FunctionLibraryDefinition* lib_def,
FunctionBody** fbody); FunctionBody** fbody);
Status CreateItem(Handle handle, Item** item); Status CreateItem(Item** item);
Status GetOrCreateItem(Handle handle, Item** item); Status GetOrCreateItem(LocalHandle local_handle, Item** item);
Status InstantiateSymbolicGradient(const NameAttrList& func, Status InstantiateSymbolicGradient(const NameAttrList& func,
const FunctionLibraryDefinition* lib_def, const FunctionLibraryDefinition* lib_def,
FunctionBody** g_body); FunctionBody** g_body);
@ -691,13 +691,14 @@ Status FunctionLibraryRuntimeImpl::Instantiate(
TF_RETURN_IF_ERROR(FunctionDefToBody(*fdef, attrs, lib_def, &fbody)); TF_RETURN_IF_ERROR(FunctionDefToBody(*fdef, attrs, lib_def, &fbody));
} }
LocalHandle local_handle;
{ {
mutex_lock l(mu_); mutex_lock l(mu_);
*handle = parent_->GetHandle(key); *handle = parent_->GetHandle(key);
if (*handle != kInvalidHandle) { if (*handle != kInvalidHandle) {
delete fbody; delete fbody;
++items_[parent_->GetHandleOnDevice(device_name_, *handle)] local_handle = parent_->GetHandleOnDevice(device_name_, *handle);
->instantiation_counter; ++items_[local_handle]->instantiation_counter;
} else { } else {
*handle = parent_->AddHandle(key, device_name_, next_handle_); *handle = parent_->AddHandle(key, device_name_, next_handle_);
Item* item = new Item; Item* item = new Item;
@ -709,26 +710,24 @@ Status FunctionLibraryRuntimeImpl::Instantiate(
item->overlay_flr = item->overlay_flr =
new FunctionLibraryRuntimeOverlay(this, options.overlay_lib); new FunctionLibraryRuntimeOverlay(this, options.overlay_lib);
} }
items_.emplace(next_handle_, std::unique_ptr<Item>(item)); local_handle = next_handle_++;
next_handle_++; items_.emplace(local_handle, std::unique_ptr<Item>(item));
} }
} }
if (options.create_kernels_eagerly) { if (options.create_kernels_eagerly) {
Item* item; Item* item;
TF_RETURN_IF_ERROR(GetOrCreateItem(*handle, &item)); TF_RETURN_IF_ERROR(GetOrCreateItem(local_handle, &item));
} }
return Status::OK(); return Status::OK();
} }
Status FunctionLibraryRuntimeImpl::ReleaseHandle(Handle handle) { Status FunctionLibraryRuntimeImpl::ReleaseHandle(Handle handle) {
if (!parent_->IsInstantiatedOnDevice(device_name_, handle)) { LocalHandle h = parent_->GetHandleOnDevice(device_name_, handle);
if (h == kInvalidLocalHandle) {
return parent_->ReleaseHandle(handle); return parent_->ReleaseHandle(handle);
} }
LocalHandle h = parent_->GetHandleOnDevice(device_name_, handle);
CHECK_NE(h, kInvalidLocalHandle);
mutex_lock l(mu_); mutex_lock l(mu_);
CHECK_EQ(1, items_.count(h)); CHECK_EQ(1, items_.count(h));
std::unique_ptr<Item>& item = items_[h]; std::unique_ptr<Item>& item = items_[h];
@ -789,7 +788,7 @@ void PruneFunctionBody(Graph* g) {
} }
} // namespace } // namespace
Status FunctionLibraryRuntimeImpl::CreateItem(Handle handle, Item** item) { Status FunctionLibraryRuntimeImpl::CreateItem(Item** item) {
const FunctionBody* fbody; const FunctionBody* fbody;
const FunctionLibraryDefinition* lib_def; const FunctionLibraryDefinition* lib_def;
string executor_type; string executor_type;
@ -843,13 +842,13 @@ Status FunctionLibraryRuntimeImpl::CreateItem(Handle handle, Item** item) {
return Status::OK(); return Status::OK();
} }
Status FunctionLibraryRuntimeImpl::GetOrCreateItem(Handle handle, Item** item) { Status FunctionLibraryRuntimeImpl::GetOrCreateItem(LocalHandle local_handle,
LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle); Item** item) {
{ {
tf_shared_lock l(mu_); tf_shared_lock l(mu_);
auto iter = items_.find(local_handle); auto iter = items_.find(local_handle);
if (iter == items_.end()) { if (iter == items_.end()) {
return errors::NotFound("Function handle ", handle, return errors::Internal("Local function handle ", local_handle,
" is not valid. Likely an internal error."); " is not valid. Likely an internal error.");
} }
*item = iter->second.get(); *item = iter->second.get();
@ -859,7 +858,7 @@ Status FunctionLibraryRuntimeImpl::GetOrCreateItem(Handle handle, Item** item) {
} }
// NOTE: We need to call CreateItem out of mu_ because creating an // NOTE: We need to call CreateItem out of mu_ because creating an
// executor needs to call CreateKernel. // executor needs to call CreateKernel.
return CreateItem(handle, item); return CreateItem(item);
} }
void FunctionLibraryRuntimeImpl::ExecutorArgsFromOptions( void FunctionLibraryRuntimeImpl::ExecutorArgsFromOptions(
@ -994,7 +993,8 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
}; };
} }
if (!parent_->IsInstantiatedOnDevice(device_name_, handle)) { LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle);
if (local_handle == kInvalidLocalHandle) {
parent_->Run(run_opts, handle, args, rets, done); parent_->Run(run_opts, handle, args, rets, done);
return; return;
} }
@ -1005,7 +1005,7 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
DCHECK(run_opts.runner != nullptr); DCHECK(run_opts.runner != nullptr);
Item* item = nullptr; Item* item = nullptr;
Status s = GetOrCreateItem(handle, &item); Status s = GetOrCreateItem(local_handle, &item);
if (!s.ok()) { if (!s.ok()) {
done(s); done(s);
return; return;
@ -1052,8 +1052,8 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
done(errors::Cancelled("")); done(errors::Cancelled(""));
return; return;
} }
if (!parent_->IsInstantiatedOnDevice(device_name_, handle) || LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle);
opts.remote_execution) { if (local_handle == kInvalidLocalHandle || opts.remote_execution) {
done(errors::Unimplemented("Remote calling with CallFrameInterface")); done(errors::Unimplemented("Remote calling with CallFrameInterface"));
return; return;
} }
@ -1074,7 +1074,7 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
} }
Item* item = nullptr; Item* item = nullptr;
Status s = GetOrCreateItem(handle, &item); Status s = GetOrCreateItem(local_handle, &item);
if (!s.ok()) { if (!s.ok()) {
done(s); done(s);
return; return;
@ -1097,7 +1097,8 @@ bool FunctionLibraryRuntimeImpl::IsStateful(const string& func) {
string FunctionLibraryRuntimeImpl::DebugString(Handle handle) { string FunctionLibraryRuntimeImpl::DebugString(Handle handle) {
Item* item = nullptr; Item* item = nullptr;
Status s = GetOrCreateItem(handle, &item); LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle);
Status s = GetOrCreateItem(local_handle, &item);
if (s.ok()) { if (s.ok()) {
return tensorflow::DebugString(item->graph); return tensorflow::DebugString(item->graph);
} else { } else {