[FLR] Switch to using LocalHandle in FLRImpl::GetOrCreateItem().

This enables us to avoid two lock-protected lookups in the parent ProcFLR's handle map (the first for `parent_->IsInstantiatedOnDevice()`, and the second in `GetOrCreateItem()`) every time we invoke a function.

PiperOrigin-RevId: 220209369
This commit is contained in:
Derek Murray 2018-11-05 18:25:19 -08:00 committed by TensorFlower Gardener
parent a7faefeb5c
commit be4dc89e8c

View File

@ -382,8 +382,8 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
Status FunctionDefToBody(const FunctionDef& fdef, AttrSlice attrs,
const FunctionLibraryDefinition* lib_def,
FunctionBody** fbody);
Status CreateItem(Handle handle, Item** item);
Status GetOrCreateItem(Handle handle, Item** item);
Status CreateItem(Item** item);
Status GetOrCreateItem(LocalHandle local_handle, Item** item);
Status InstantiateSymbolicGradient(const NameAttrList& func,
const FunctionLibraryDefinition* lib_def,
FunctionBody** g_body);
@ -691,13 +691,14 @@ Status FunctionLibraryRuntimeImpl::Instantiate(
TF_RETURN_IF_ERROR(FunctionDefToBody(*fdef, attrs, lib_def, &fbody));
}
LocalHandle local_handle;
{
mutex_lock l(mu_);
*handle = parent_->GetHandle(key);
if (*handle != kInvalidHandle) {
delete fbody;
++items_[parent_->GetHandleOnDevice(device_name_, *handle)]
->instantiation_counter;
local_handle = parent_->GetHandleOnDevice(device_name_, *handle);
++items_[local_handle]->instantiation_counter;
} else {
*handle = parent_->AddHandle(key, device_name_, next_handle_);
Item* item = new Item;
@ -709,26 +710,24 @@ Status FunctionLibraryRuntimeImpl::Instantiate(
item->overlay_flr =
new FunctionLibraryRuntimeOverlay(this, options.overlay_lib);
}
items_.emplace(next_handle_, std::unique_ptr<Item>(item));
next_handle_++;
local_handle = next_handle_++;
items_.emplace(local_handle, std::unique_ptr<Item>(item));
}
}
if (options.create_kernels_eagerly) {
Item* item;
TF_RETURN_IF_ERROR(GetOrCreateItem(*handle, &item));
TF_RETURN_IF_ERROR(GetOrCreateItem(local_handle, &item));
}
return Status::OK();
}
Status FunctionLibraryRuntimeImpl::ReleaseHandle(Handle handle) {
if (!parent_->IsInstantiatedOnDevice(device_name_, handle)) {
LocalHandle h = parent_->GetHandleOnDevice(device_name_, handle);
if (h == kInvalidLocalHandle) {
return parent_->ReleaseHandle(handle);
}
LocalHandle h = parent_->GetHandleOnDevice(device_name_, handle);
CHECK_NE(h, kInvalidLocalHandle);
mutex_lock l(mu_);
CHECK_EQ(1, items_.count(h));
std::unique_ptr<Item>& item = items_[h];
@ -789,7 +788,7 @@ void PruneFunctionBody(Graph* g) {
}
} // namespace
Status FunctionLibraryRuntimeImpl::CreateItem(Handle handle, Item** item) {
Status FunctionLibraryRuntimeImpl::CreateItem(Item** item) {
const FunctionBody* fbody;
const FunctionLibraryDefinition* lib_def;
string executor_type;
@ -843,13 +842,13 @@ Status FunctionLibraryRuntimeImpl::CreateItem(Handle handle, Item** item) {
return Status::OK();
}
Status FunctionLibraryRuntimeImpl::GetOrCreateItem(Handle handle, Item** item) {
LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle);
Status FunctionLibraryRuntimeImpl::GetOrCreateItem(LocalHandle local_handle,
Item** item) {
{
tf_shared_lock l(mu_);
auto iter = items_.find(local_handle);
if (iter == items_.end()) {
return errors::NotFound("Function handle ", handle,
return errors::Internal("Local function handle ", local_handle,
" is not valid. Likely an internal error.");
}
*item = iter->second.get();
@ -859,7 +858,7 @@ Status FunctionLibraryRuntimeImpl::GetOrCreateItem(Handle handle, Item** item) {
}
// NOTE: We need to call CreateItem out of mu_ because creating an
// executor needs to call CreateKernel.
return CreateItem(handle, item);
return CreateItem(item);
}
void FunctionLibraryRuntimeImpl::ExecutorArgsFromOptions(
@ -994,7 +993,8 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
};
}
if (!parent_->IsInstantiatedOnDevice(device_name_, handle)) {
LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle);
if (local_handle == kInvalidLocalHandle) {
parent_->Run(run_opts, handle, args, rets, done);
return;
}
@ -1005,7 +1005,7 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
DCHECK(run_opts.runner != nullptr);
Item* item = nullptr;
Status s = GetOrCreateItem(handle, &item);
Status s = GetOrCreateItem(local_handle, &item);
if (!s.ok()) {
done(s);
return;
@ -1052,8 +1052,8 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
done(errors::Cancelled(""));
return;
}
if (!parent_->IsInstantiatedOnDevice(device_name_, handle) ||
opts.remote_execution) {
LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle);
if (local_handle == kInvalidLocalHandle || opts.remote_execution) {
done(errors::Unimplemented("Remote calling with CallFrameInterface"));
return;
}
@ -1074,7 +1074,7 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
}
Item* item = nullptr;
Status s = GetOrCreateItem(handle, &item);
Status s = GetOrCreateItem(local_handle, &item);
if (!s.ok()) {
done(s);
return;
@ -1097,7 +1097,8 @@ bool FunctionLibraryRuntimeImpl::IsStateful(const string& func) {
string FunctionLibraryRuntimeImpl::DebugString(Handle handle) {
Item* item = nullptr;
Status s = GetOrCreateItem(handle, &item);
LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle);
Status s = GetOrCreateItem(local_handle, &item);
if (s.ok()) {
return tensorflow::DebugString(item->graph);
} else {