From ac191b55913c3958ada1f64785772692b33c3cf4 Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Wed, 25 Mar 2020 10:48:30 -0700 Subject: [PATCH] [Functional ops] Move handle cache back into `IfOp` and `WhileOp`. Previously, we were relying on the idempotency of `FunctionLibraryRuntime::Instantiate()` to cache the mapping from {FLR, Function} to FHandle. However, `Instantiate()` is a heavyweight method that acquires multiple exclusive locks, which causes contention when the kernel is invoked concurrently by many threads (e.g. during inference). By moving the cache (back, in the case of WhileOp) to the kernels, we can use a more appropriate `tf_shared_lock`, which reduces contention. PiperOrigin-RevId: 302923478 Change-Id: I9372c72fe34d98d390c970dbc5e033d204292e21 --- tensorflow/core/kernels/functional_ops.cc | 108 +++++++++++++++++----- 1 file changed, 84 insertions(+), 24 deletions(-) diff --git a/tensorflow/core/kernels/functional_ops.cc b/tensorflow/core/kernels/functional_ops.cc index eb6b5cdce3a..c8cebd0ff4d 100644 --- a/tensorflow/core/kernels/functional_ops.cc +++ b/tensorflow/core/kernels/functional_ops.cc @@ -123,22 +123,10 @@ class IfOp : public AsyncOpKernel { ~IfOp() override {} void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { - auto lib = ctx->function_library(); - OP_REQUIRES_ASYNC(ctx, lib != nullptr, - errors::Internal("No function library"), done); - - // TODO(b/37549631): Because this op has `SetIsStateful()` in its op - // registration, this kernel may be shared by multiple subgraphs, which have - // different associated `FunctionLibraryRuntime` objects and hence different - // `FHandle` namespaces. So we must call Instantiate() to make sure we get - // the correct function handles with respect to `lib`. Note the underlying - // `lib->Instantiate()` caches the created function handles, so calling - // `Instantiate()` repeatedly on the same `lib` and function is cheap. 
FHandle then_handle; FHandle else_handle; - OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, then_func_, &then_handle), done); - OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, else_func_, &else_handle), done); - + OP_REQUIRES_OK_ASYNC(ctx, GetHandles(ctx, &then_handle, &else_handle), + done); bool cond; OP_REQUIRES_OK(ctx, ToBool({ctx->input(0)}, &cond)); (new State(this, ctx, cond, then_handle, else_handle, done))->Start(); @@ -148,6 +136,10 @@ class IfOp : public AsyncOpKernel { NameAttrList then_func_; NameAttrList else_func_; + mutex mu_; + std::unordered_map<FunctionLibraryRuntime*, std::pair<FHandle, FHandle>> + handles_ GUARDED_BY(mu_); + class State { public: State(IfOp* kernel, OpKernelContext* ctx, bool cond, FHandle then_handle, @@ -203,6 +195,42 @@ class IfOp : public AsyncOpKernel { TensorVec args_; TensorVec rets_; }; + + Status GetHandles(OpKernelContext* ctx, FHandle* then_handle, + FHandle* else_handle) { + // TODO(b/37549631): Because this op has `SetIsStateful()` in its + // op registration, this kernel may be shared by multiple + // subgraphs, which have different associated + // `FunctionLibraryRuntime` objects and hence different `FHandle` + // namespaces. We currently work around this by caching the map + // from `FunctionLibraryRuntime*` to `FHandle` pairs for the two + // functions this op uses. 
+ auto lib = ctx->function_library(); + if (lib == nullptr) return errors::Internal("No function library"); + *then_handle = kInvalidHandle; + *else_handle = kInvalidHandle; + { + tf_shared_lock l(mu_); + const auto iter = handles_.find(lib); + if (TF_PREDICT_TRUE(iter != handles_.end())) { + *then_handle = iter->second.first; + *else_handle = iter->second.second; + } + } + if (TF_PREDICT_FALSE(*then_handle == kInvalidHandle)) { + mutex_lock l(mu_); + const auto iter = handles_.find(lib); + if (TF_PREDICT_TRUE(iter != handles_.end())) { + *then_handle = iter->second.first; + *else_handle = iter->second.second; + } else { + TF_RETURN_IF_ERROR(Instantiate(lib, then_func_, then_handle)); + TF_RETURN_IF_ERROR(Instantiate(lib, else_func_, else_handle)); + handles_[lib] = {*then_handle, *else_handle}; + } + } + return Status::OK(); + } }; class CaseOp : public AsyncOpKernel { @@ -332,18 +360,10 @@ class WhileOp : public AsyncOpKernel { auto lib = ctx->function_library(); OP_REQUIRES_ASYNC(ctx, lib != nullptr, errors::Internal("No function library"), done); - - // TODO(b/37549631): Because this op has `SetIsStateful()` in its op - // registration, this kernel may be shared by multiple subgraphs, which have - // different associated `FunctionLibraryRuntime` objects and hence different - // `FHandle` namespaces. So we must call Instantiate() to make sure we get - // the correct function handles with respect to `lib`. Note the underlying - // `lib->Instantiate()` caches the created function handles, so calling - // `Instantiate()` repeatedly on the same `lib` and function is cheap. 
FHandle cond_handle; FHandle body_handle; - OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, cond_func_, &cond_handle), done); - OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, body_func_, &body_handle), done); + OP_REQUIRES_OK_ASYNC(ctx, GetHandles(ctx, &cond_handle, &body_handle), + done); (new State(this, ctx, cond_handle, body_handle, done))->Start(); } @@ -351,6 +371,10 @@ class WhileOp : public AsyncOpKernel { NameAttrList cond_func_; NameAttrList body_func_; + mutex mu_; + std::unordered_map<FunctionLibraryRuntime*, std::pair<FHandle, FHandle>> + handles_ GUARDED_BY(mu_); + class State { public: State(WhileOp* kernel, OpKernelContext* ctx, FHandle cond_handle, @@ -486,6 +510,42 @@ class WhileOp : public AsyncOpKernel { delete this; } }; + + Status GetHandles(OpKernelContext* ctx, FHandle* cond_handle, + FHandle* body_handle) { + // TODO(b/37549631): Because this op has `SetIsStateful()` in its + // op registration, this kernel may be shared by multiple + // subgraphs, which have different associated + // `FunctionLibraryRuntime` objects and hence different `FHandle` + // namespaces. We currently work around this by caching the map + // from `FunctionLibraryRuntime*` to `FHandle` pairs for the two + // functions this op uses. 
+ auto lib = ctx->function_library(); + if (lib == nullptr) return errors::Internal("No function library"); + *cond_handle = kInvalidHandle; + *body_handle = kInvalidHandle; + { + tf_shared_lock l(mu_); + const auto iter = handles_.find(lib); + if (TF_PREDICT_TRUE(iter != handles_.end())) { + *cond_handle = iter->second.first; + *body_handle = iter->second.second; + } + } + if (TF_PREDICT_FALSE(*cond_handle == kInvalidHandle)) { + mutex_lock l(mu_); + const auto iter = handles_.find(lib); + if (TF_PREDICT_TRUE(iter != handles_.end())) { + *cond_handle = iter->second.first; + *body_handle = iter->second.second; + } else { + TF_RETURN_IF_ERROR(Instantiate(lib, cond_func_, cond_handle)); + TF_RETURN_IF_ERROR(Instantiate(lib, body_func_, body_handle)); + handles_[lib] = {*cond_handle, *body_handle}; + } + } + return Status::OK(); + } }; // TODO(drpng): remove these. REGISTER_KERNEL_BUILDER(Name("_While").Device(DEVICE_CPU), WhileOp);