[Functional ops] Move handle cache back into IfOp and WhileOp.

Previously, we were relying on the idempotency of `FunctionLibraryRuntime::Instantiate()` to cache the mapping from {FLR, Function} to FHandle. However, `Instantiate()` is a heavyweight method that acquires multiple exclusive locks, which causes contention when the kernel is invoked concurrently by many threads (e.g. during inference). By moving the cache (back, in the case of WhileOp) to the kernels, we can use a more appropriate `tf_shared_lock`, which reduces contention.

PiperOrigin-RevId: 302923478
Change-Id: I9372c72fe34d98d390c970dbc5e033d204292e21
Derek Murray authored 2020-03-25 10:48:30 -07:00; committed by TensorFlower Gardener
parent f1e0098f2a
commit ac191b5591
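
For readers unfamiliar with the pattern this change adopts, the sketch below illustrates the locking scheme: look the handle pair up under a shared (reader) lock on the fast path, and take the exclusive lock only on a cache miss to instantiate and insert. It is a standalone approximation, not TensorFlow code: std::shared_mutex stands in for TensorFlow's mutex/tf_shared_lock, and Runtime, Handle, and HandleCache are hypothetical placeholders for FunctionLibraryRuntime*, FHandle, and the kernels' handles_ map.

#include <mutex>
#include <shared_mutex>
#include <unordered_map>
#include <utility>

struct Runtime {};       // placeholder for FunctionLibraryRuntime
using Handle = int;      // placeholder for FHandle

class HandleCache {
 public:
  // Returns the cached handle pair for `rt`, instantiating it on first use.
  std::pair<Handle, Handle> GetHandles(Runtime* rt) {
    {
      // Fast path: most lookups hit the cache and only need a shared
      // (reader) lock, so concurrent invocations do not serialize here.
      std::shared_lock<std::shared_mutex> l(mu_);
      auto it = handles_.find(rt);
      if (it != handles_.end()) return it->second;
    }
    // Slow path: take the exclusive lock, re-check in case another thread
    // filled the entry between the two lock scopes, then instantiate and cache.
    std::unique_lock<std::shared_mutex> l(mu_);
    auto it = handles_.find(rt);
    if (it == handles_.end()) {
      it = handles_
               .emplace(rt, std::make_pair(Instantiate(rt), Instantiate(rt)))
               .first;
    }
    return it->second;
  }

 private:
  // Stand-in for the heavyweight FunctionLibraryRuntime::Instantiate() call.
  Handle Instantiate(Runtime*) { return next_handle_++; }

  std::shared_mutex mu_;
  std::unordered_map<Runtime*, std::pair<Handle, Handle>> handles_;
  Handle next_handle_ = 0;
};

In the actual kernels Instantiate() can fail, so the GetHandles() helpers in the diff below return a Status, propagate errors with TF_RETURN_IF_ERROR, and signal a miss with a kInvalidHandle sentinel; the two lock scopes are otherwise structured the same way.
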


@@ -123,22 +123,10 @@ class IfOp : public AsyncOpKernel {
~IfOp() override {}
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
auto lib = ctx->function_library();
OP_REQUIRES_ASYNC(ctx, lib != nullptr,
errors::Internal("No function library"), done);
// TODO(b/37549631): Because this op has `SetIsStateful()` in its op
// registration, this kernel may be shared by multiple subgraphs, which have
// different associated `FunctionLibraryRuntime` objects and hence different
// `FHandle` namespaces. So we must call Instantiate() to make sure we get
// the correct function handles with respect to `lib`. Note the underlying
// `lib->Instantiate()` caches the created function handles, so calling
// `Instantiate()` repeatedly on the same `lib` and function is cheap.
FHandle then_handle;
FHandle else_handle;
OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, then_func_, &then_handle), done);
OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, else_func_, &else_handle), done);
OP_REQUIRES_OK_ASYNC(ctx, GetHandles(ctx, &then_handle, &else_handle),
done);
bool cond;
OP_REQUIRES_OK(ctx, ToBool({ctx->input(0)}, &cond));
(new State(this, ctx, cond, then_handle, else_handle, done))->Start();
@@ -148,6 +136,10 @@ class IfOp : public AsyncOpKernel {
NameAttrList then_func_;
NameAttrList else_func_;
mutex mu_;
std::unordered_map<FunctionLibraryRuntime*, std::pair<FHandle, FHandle>>
handles_ GUARDED_BY(mu_);
class State {
public:
State(IfOp* kernel, OpKernelContext* ctx, bool cond, FHandle then_handle,
@@ -203,6 +195,42 @@ class IfOp : public AsyncOpKernel {
TensorVec args_;
TensorVec rets_;
};
Status GetHandles(OpKernelContext* ctx, FHandle* then_handle,
FHandle* else_handle) {
// TODO(b/37549631): Because this op has `SetIsStateful()` in its
// op registration, this kernel may be shared by multiple
// subgraphs, which have different associated
// `FunctionLibraryRuntime` objects and hence different `FHandle`
// namespaces. We currently work around this by caching the map
// from `FunctionLibraryRuntime*` to `FHandle` pairs for the two
// functions this op uses.
auto lib = ctx->function_library();
if (lib == nullptr) return errors::Internal("No function library");
*then_handle = kInvalidHandle;
*else_handle = kInvalidHandle;
{
tf_shared_lock l(mu_);
const auto iter = handles_.find(lib);
if (TF_PREDICT_TRUE(iter != handles_.end())) {
*then_handle = iter->second.first;
*else_handle = iter->second.second;
}
}
if (TF_PREDICT_FALSE(*then_handle == kInvalidHandle)) {
mutex_lock l(mu_);
const auto iter = handles_.find(lib);
if (TF_PREDICT_TRUE(iter != handles_.end())) {
*then_handle = iter->second.first;
*else_handle = iter->second.second;
} else {
TF_RETURN_IF_ERROR(Instantiate(lib, then_func_, then_handle));
TF_RETURN_IF_ERROR(Instantiate(lib, else_func_, else_handle));
handles_[lib] = {*then_handle, *else_handle};
}
}
return Status::OK();
}
};
class CaseOp : public AsyncOpKernel {
@@ -332,18 +360,10 @@ class WhileOp : public AsyncOpKernel {
auto lib = ctx->function_library();
OP_REQUIRES_ASYNC(ctx, lib != nullptr,
errors::Internal("No function library"), done);
// TODO(b/37549631): Because this op has `SetIsStateful()` in its op
// registration, this kernel may be shared by multiple subgraphs, which have
// different associated `FunctionLibraryRuntime` objects and hence different
// `FHandle` namespaces. So we must call Instantiate() to make sure we get
// the correct function handles with respect to `lib`. Note the underlying
// `lib->Instantiate()` caches the created function handles, so calling
// `Instantiate()` repeatedly on the same `lib` and function is cheap.
FHandle cond_handle;
FHandle body_handle;
OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, cond_func_, &cond_handle), done);
OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, body_func_, &body_handle), done);
OP_REQUIRES_OK_ASYNC(ctx, GetHandles(ctx, &cond_handle, &body_handle),
done);
(new State(this, ctx, cond_handle, body_handle, done))->Start();
}
@@ -351,6 +371,10 @@ class WhileOp : public AsyncOpKernel {
NameAttrList cond_func_;
NameAttrList body_func_;
mutex mu_;
std::unordered_map<FunctionLibraryRuntime*, std::pair<FHandle, FHandle>>
handles_ GUARDED_BY(mu_);
class State {
public:
State(WhileOp* kernel, OpKernelContext* ctx, FHandle cond_handle,
@@ -486,6 +510,42 @@ class WhileOp : public AsyncOpKernel {
delete this;
}
};
Status GetHandles(OpKernelContext* ctx, FHandle* cond_handle,
FHandle* body_handle) {
// TODO(b/37549631): Because this op has `SetIsStateful()` in its
// op registration, this kernel may be shared by multiple
// subgraphs, which have different associated
// `FunctionLibraryRuntime` objects and hence different `FHandle`
// namespaces. We currently work around this by caching the map
// from `FunctionLibraryRuntime*` to `FHandle` pairs for the two
// functions this op uses.
auto lib = ctx->function_library();
if (lib == nullptr) return errors::Internal("No function library");
*cond_handle = kInvalidHandle;
*body_handle = kInvalidHandle;
{
tf_shared_lock l(mu_);
const auto iter = handles_.find(lib);
if (TF_PREDICT_TRUE(iter != handles_.end())) {
*cond_handle = iter->second.first;
*body_handle = iter->second.second;
}
}
if (TF_PREDICT_FALSE(*cond_handle == kInvalidHandle)) {
mutex_lock l(mu_);
const auto iter = handles_.find(lib);
if (TF_PREDICT_TRUE(iter != handles_.end())) {
*cond_handle = iter->second.first;
*body_handle = iter->second.second;
} else {
TF_RETURN_IF_ERROR(Instantiate(lib, cond_func_, cond_handle));
TF_RETURN_IF_ERROR(Instantiate(lib, body_func_, body_handle));
handles_[lib] = {*cond_handle, *body_handle};
}
}
return Status::OK();
}
};
// TODO(drpng): remove these.
REGISTER_KERNEL_BUILDER(Name("_While").Device(DEVICE_CPU), WhileOp);