diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index 9affc9fb188..313cdfa2200 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -382,8 +382,8 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime { Status FunctionDefToBody(const FunctionDef& fdef, AttrSlice attrs, const FunctionLibraryDefinition* lib_def, FunctionBody** fbody); - Status CreateItem(Handle handle, Item** item); - Status GetOrCreateItem(Handle handle, Item** item); + Status CreateItem(Item** item); + Status GetOrCreateItem(LocalHandle local_handle, Item** item); Status InstantiateSymbolicGradient(const NameAttrList& func, const FunctionLibraryDefinition* lib_def, FunctionBody** g_body); @@ -691,13 +691,14 @@ Status FunctionLibraryRuntimeImpl::Instantiate( TF_RETURN_IF_ERROR(FunctionDefToBody(*fdef, attrs, lib_def, &fbody)); } + LocalHandle local_handle; { mutex_lock l(mu_); *handle = parent_->GetHandle(key); if (*handle != kInvalidHandle) { delete fbody; - ++items_[parent_->GetHandleOnDevice(device_name_, *handle)] - ->instantiation_counter; + local_handle = parent_->GetHandleOnDevice(device_name_, *handle); + ++items_[local_handle]->instantiation_counter; } else { *handle = parent_->AddHandle(key, device_name_, next_handle_); Item* item = new Item; @@ -709,26 +710,24 @@ Status FunctionLibraryRuntimeImpl::Instantiate( item->overlay_flr = new FunctionLibraryRuntimeOverlay(this, options.overlay_lib); } - items_.emplace(next_handle_, std::unique_ptr(item)); - next_handle_++; + local_handle = next_handle_++; + items_.emplace(local_handle, std::unique_ptr(item)); } } if (options.create_kernels_eagerly) { Item* item; - TF_RETURN_IF_ERROR(GetOrCreateItem(*handle, &item)); + TF_RETURN_IF_ERROR(GetOrCreateItem(local_handle, &item)); } return Status::OK(); } Status FunctionLibraryRuntimeImpl::ReleaseHandle(Handle handle) { - if (!parent_->IsInstantiatedOnDevice(device_name_, handle)) { + LocalHandle h = parent_->GetHandleOnDevice(device_name_, handle); + if (h == kInvalidLocalHandle) { return parent_->ReleaseHandle(handle); } - - LocalHandle h = parent_->GetHandleOnDevice(device_name_, handle); - CHECK_NE(h, kInvalidLocalHandle); mutex_lock l(mu_); CHECK_EQ(1, items_.count(h)); std::unique_ptr& item = items_[h]; @@ -789,7 +788,7 @@ void PruneFunctionBody(Graph* g) { } } // namespace -Status FunctionLibraryRuntimeImpl::CreateItem(Handle handle, Item** item) { +Status FunctionLibraryRuntimeImpl::CreateItem(Item** item) { const FunctionBody* fbody; const FunctionLibraryDefinition* lib_def; string executor_type; @@ -843,13 +842,13 @@ Status FunctionLibraryRuntimeImpl::CreateItem(Handle handle, Item** item) { return Status::OK(); } -Status FunctionLibraryRuntimeImpl::GetOrCreateItem(Handle handle, Item** item) { - LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle); +Status FunctionLibraryRuntimeImpl::GetOrCreateItem(LocalHandle local_handle, + Item** item) { { tf_shared_lock l(mu_); auto iter = items_.find(local_handle); if (iter == items_.end()) { - return errors::NotFound("Function handle ", handle, + return errors::Internal("Local function handle ", local_handle, " is not valid. Likely an internal error."); } *item = iter->second.get(); @@ -859,7 +858,7 @@ Status FunctionLibraryRuntimeImpl::GetOrCreateItem(Handle handle, Item** item) { } // NOTE: We need to call CreateItem out of mu_ because creating an // executor needs to call CreateKernel. - return CreateItem(handle, item); + return CreateItem(item); } void FunctionLibraryRuntimeImpl::ExecutorArgsFromOptions( @@ -994,7 +993,8 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle, }; } - if (!parent_->IsInstantiatedOnDevice(device_name_, handle)) { + LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle); + if (local_handle == kInvalidLocalHandle) { parent_->Run(run_opts, handle, args, rets, done); return; } @@ -1005,7 +1005,7 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle, DCHECK(run_opts.runner != nullptr); Item* item = nullptr; - Status s = GetOrCreateItem(handle, &item); + Status s = GetOrCreateItem(local_handle, &item); if (!s.ok()) { done(s); return; @@ -1052,8 +1052,8 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle, done(errors::Cancelled("")); return; } - if (!parent_->IsInstantiatedOnDevice(device_name_, handle) || - opts.remote_execution) { + LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle); + if (local_handle == kInvalidLocalHandle || opts.remote_execution) { done(errors::Unimplemented("Remote calling with CallFrameInterface")); return; } @@ -1074,7 +1074,7 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle, } Item* item = nullptr; - Status s = GetOrCreateItem(handle, &item); + Status s = GetOrCreateItem(local_handle, &item); if (!s.ok()) { done(s); return; @@ -1097,7 +1097,8 @@ bool FunctionLibraryRuntimeImpl::IsStateful(const string& func) { string FunctionLibraryRuntimeImpl::DebugString(Handle handle) { Item* item = nullptr; - Status s = GetOrCreateItem(handle, &item); + LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle); + Status s = GetOrCreateItem(local_handle, &item); if (s.ok()) { return tensorflow::DebugString(item->graph); } else {