diff --git a/tensorflow/core/kernels/functional_ops.cc b/tensorflow/core/kernels/functional_ops.cc index eb6b5cdce3a..c8cebd0ff4d 100644 --- a/tensorflow/core/kernels/functional_ops.cc +++ b/tensorflow/core/kernels/functional_ops.cc @@ -123,22 +123,10 @@ class IfOp : public AsyncOpKernel { ~IfOp() override {} void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { - auto lib = ctx->function_library(); - OP_REQUIRES_ASYNC(ctx, lib != nullptr, - errors::Internal("No function library"), done); - - // TODO(b/37549631): Because this op has `SetIsStateful()` in its op - // registration, this kernel may be shared by multiple subgraphs, which have - // different associated `FunctionLibraryRuntime` objects and hence different - // `FHandle` namespaces. So we must call Instantiate() to make sure we get - // the correct function handles with respect to `lib`. Note the underlying - // `lib->Instantiate()` caches the created function handles, so calling - // `Instantiate()` repeatedly on the same `lib` and function is cheap. 
FHandle then_handle; FHandle else_handle; - OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, then_func_, &then_handle), done); - OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, else_func_, &else_handle), done); - + OP_REQUIRES_OK_ASYNC(ctx, GetHandles(ctx, &then_handle, &else_handle), + done); bool cond; OP_REQUIRES_OK(ctx, ToBool({ctx->input(0)}, &cond)); (new State(this, ctx, cond, then_handle, else_handle, done))->Start(); @@ -148,6 +136,10 @@ class IfOp : public AsyncOpKernel { NameAttrList then_func_; NameAttrList else_func_; + mutex mu_; + std::unordered_map<FunctionLibraryRuntime*, std::pair<FHandle, FHandle>> + handles_ GUARDED_BY(mu_); + class State { public: State(IfOp* kernel, OpKernelContext* ctx, bool cond, FHandle then_handle, @@ -203,6 +195,42 @@ class IfOp : public AsyncOpKernel { TensorVec args_; TensorVec rets_; }; + + Status GetHandles(OpKernelContext* ctx, FHandle* then_handle, + FHandle* else_handle) { + // TODO(b/37549631): Because this op has `SetIsStateful()` in its + // op registration, this kernel may be shared by multiple + // subgraphs, which have different associated + // `FunctionLibraryRuntime` objects and hence different `FHandle` + // namespaces. We currently work around this by caching the map + // from `FunctionLibraryRuntime*` to `FHandle` pairs for the two + // functions this op uses. 
+ auto lib = ctx->function_library(); + if (lib == nullptr) return errors::Internal("No function library"); + *then_handle = kInvalidHandle; + *else_handle = kInvalidHandle; + { + tf_shared_lock l(mu_); + const auto iter = handles_.find(lib); + if (TF_PREDICT_TRUE(iter != handles_.end())) { + *then_handle = iter->second.first; + *else_handle = iter->second.second; + } + } + if (TF_PREDICT_FALSE(*then_handle == kInvalidHandle)) { + mutex_lock l(mu_); + const auto iter = handles_.find(lib); + if (TF_PREDICT_TRUE(iter != handles_.end())) { + *then_handle = iter->second.first; + *else_handle = iter->second.second; + } else { + TF_RETURN_IF_ERROR(Instantiate(lib, then_func_, then_handle)); + TF_RETURN_IF_ERROR(Instantiate(lib, else_func_, else_handle)); + handles_[lib] = {*then_handle, *else_handle}; + } + } + return Status::OK(); + } }; class CaseOp : public AsyncOpKernel { @@ -332,18 +360,10 @@ class WhileOp : public AsyncOpKernel { auto lib = ctx->function_library(); OP_REQUIRES_ASYNC(ctx, lib != nullptr, errors::Internal("No function library"), done); - - // TODO(b/37549631): Because this op has `SetIsStateful()` in its op - // registration, this kernel may be shared by multiple subgraphs, which have - // different associated `FunctionLibraryRuntime` objects and hence different - // `FHandle` namespaces. So we must call Instantiate() to make sure we get - // the correct function handles with respect to `lib`. Note the underlying - // `lib->Instantiate()` caches the created function handles, so calling - // `Instantiate()` repeatedly on the same `lib` and function is cheap. 
FHandle cond_handle; FHandle body_handle; - OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, cond_func_, &cond_handle), done); - OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, body_func_, &body_handle), done); + OP_REQUIRES_OK_ASYNC(ctx, GetHandles(ctx, &cond_handle, &body_handle), + done); (new State(this, ctx, cond_handle, body_handle, done))->Start(); } @@ -351,6 +371,10 @@ class WhileOp : public AsyncOpKernel { NameAttrList cond_func_; NameAttrList body_func_; + mutex mu_; + std::unordered_map<FunctionLibraryRuntime*, std::pair<FHandle, FHandle>> + handles_ GUARDED_BY(mu_); + class State { public: State(WhileOp* kernel, OpKernelContext* ctx, FHandle cond_handle, @@ -486,6 +510,42 @@ class WhileOp : public AsyncOpKernel { delete this; } }; + + Status GetHandles(OpKernelContext* ctx, FHandle* cond_handle, + FHandle* body_handle) { + // TODO(b/37549631): Because this op has `SetIsStateful()` in its + // op registration, this kernel may be shared by multiple + // subgraphs, which have different associated + // `FunctionLibraryRuntime` objects and hence different `FHandle` + // namespaces. We currently work around this by caching the map + // from `FunctionLibraryRuntime*` to `FHandle` pairs for the two + // functions this op uses. 
+ auto lib = ctx->function_library(); + if (lib == nullptr) return errors::Internal("No function library"); + *cond_handle = kInvalidHandle; + *body_handle = kInvalidHandle; + { + tf_shared_lock l(mu_); + const auto iter = handles_.find(lib); + if (TF_PREDICT_TRUE(iter != handles_.end())) { + *cond_handle = iter->second.first; + *body_handle = iter->second.second; + } + } + if (TF_PREDICT_FALSE(*cond_handle == kInvalidHandle)) { + mutex_lock l(mu_); + const auto iter = handles_.find(lib); + if (TF_PREDICT_TRUE(iter != handles_.end())) { + *cond_handle = iter->second.first; + *body_handle = iter->second.second; + } else { + TF_RETURN_IF_ERROR(Instantiate(lib, cond_func_, cond_handle)); + TF_RETURN_IF_ERROR(Instantiate(lib, body_func_, body_handle)); + handles_[lib] = {*cond_handle, *body_handle}; + } + } + return Status::OK(); + } }; // TODO(drpng): remove these. REGISTER_KERNEL_BUILDER(Name("_While").Device(DEVICE_CPU), WhileOp);