[Functional ops] Move handle cache back into IfOp and WhileOp.
Previously, we were relying on the idempotency of `FunctionLibraryRuntime::Instantiate()` to cache the mapping from {FLR, Function} to FHandle. However, `Instantiate()` is a heavyweight method that acquires multiple exclusive locks, which causes contention when the kernel is invoked concurrently by many threads (e.g. during inference). By moving the cache into the kernels (moving it back, in the case of WhileOp), we can use a more appropriate `tf_shared_lock`, which reduces contention.

PiperOrigin-RevId: 302923478
Change-Id: I9372c72fe34d98d390c970dbc5e033d204292e21
This commit is contained in:
parent
f1e0098f2a
commit
ac191b5591
@ -123,22 +123,10 @@ class IfOp : public AsyncOpKernel {
|
|||||||
~IfOp() override {}
|
~IfOp() override {}
|
||||||
|
|
||||||
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
|
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
|
||||||
auto lib = ctx->function_library();
|
|
||||||
OP_REQUIRES_ASYNC(ctx, lib != nullptr,
|
|
||||||
errors::Internal("No function library"), done);
|
|
||||||
|
|
||||||
// TODO(b/37549631): Because this op has `SetIsStateful()` in its op
|
|
||||||
// registration, this kernel may be shared by multiple subgraphs, which have
|
|
||||||
// different associated `FunctionLibraryRuntime` objects and hence different
|
|
||||||
// `FHandle` namespaces. So we must call Instantiate() to make sure we get
|
|
||||||
// the correct function handles with respect to `lib`. Note the underlying
|
|
||||||
// `lib->Instantiate()` caches the created function handles, so calling
|
|
||||||
// `Instantiate()` repeatedly on the same `lib` and function is cheap.
|
|
||||||
FHandle then_handle;
|
FHandle then_handle;
|
||||||
FHandle else_handle;
|
FHandle else_handle;
|
||||||
OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, then_func_, &then_handle), done);
|
OP_REQUIRES_OK_ASYNC(ctx, GetHandles(ctx, &then_handle, &else_handle),
|
||||||
OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, else_func_, &else_handle), done);
|
done);
|
||||||
|
|
||||||
bool cond;
|
bool cond;
|
||||||
OP_REQUIRES_OK(ctx, ToBool({ctx->input(0)}, &cond));
|
OP_REQUIRES_OK(ctx, ToBool({ctx->input(0)}, &cond));
|
||||||
(new State(this, ctx, cond, then_handle, else_handle, done))->Start();
|
(new State(this, ctx, cond, then_handle, else_handle, done))->Start();
|
||||||
@ -148,6 +136,10 @@ class IfOp : public AsyncOpKernel {
|
|||||||
NameAttrList then_func_;
|
NameAttrList then_func_;
|
||||||
NameAttrList else_func_;
|
NameAttrList else_func_;
|
||||||
|
|
||||||
|
mutex mu_;
|
||||||
|
std::unordered_map<FunctionLibraryRuntime*, std::pair<FHandle, FHandle>>
|
||||||
|
handles_ GUARDED_BY(mu_);
|
||||||
|
|
||||||
class State {
|
class State {
|
||||||
public:
|
public:
|
||||||
State(IfOp* kernel, OpKernelContext* ctx, bool cond, FHandle then_handle,
|
State(IfOp* kernel, OpKernelContext* ctx, bool cond, FHandle then_handle,
|
||||||
@ -203,6 +195,42 @@ class IfOp : public AsyncOpKernel {
|
|||||||
TensorVec args_;
|
TensorVec args_;
|
||||||
TensorVec rets_;
|
TensorVec rets_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Status GetHandles(OpKernelContext* ctx, FHandle* then_handle,
|
||||||
|
FHandle* else_handle) {
|
||||||
|
// TODO(b/37549631): Because this op has `SetIsStateful()` in its
|
||||||
|
// op registration, this kernel may be shared by multiple
|
||||||
|
// subgraphs, which have different associated
|
||||||
|
// `FunctionLibraryRuntime` objects and hence different `FHandle`
|
||||||
|
// namespaces. We currently work around this by caching the map
|
||||||
|
// from `FunctionLibraryRuntime*` to `FHandle` pairs for the two
|
||||||
|
// functions this op uses.
|
||||||
|
auto lib = ctx->function_library();
|
||||||
|
if (lib == nullptr) return errors::Internal("No function library");
|
||||||
|
*then_handle = kInvalidHandle;
|
||||||
|
*else_handle = kInvalidHandle;
|
||||||
|
{
|
||||||
|
tf_shared_lock l(mu_);
|
||||||
|
const auto iter = handles_.find(lib);
|
||||||
|
if (TF_PREDICT_TRUE(iter != handles_.end())) {
|
||||||
|
*then_handle = iter->second.first;
|
||||||
|
*else_handle = iter->second.second;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (TF_PREDICT_FALSE(*then_handle == kInvalidHandle)) {
|
||||||
|
mutex_lock l(mu_);
|
||||||
|
const auto iter = handles_.find(lib);
|
||||||
|
if (TF_PREDICT_TRUE(iter != handles_.end())) {
|
||||||
|
*then_handle = iter->second.first;
|
||||||
|
*else_handle = iter->second.second;
|
||||||
|
} else {
|
||||||
|
TF_RETURN_IF_ERROR(Instantiate(lib, then_func_, then_handle));
|
||||||
|
TF_RETURN_IF_ERROR(Instantiate(lib, else_func_, else_handle));
|
||||||
|
handles_[lib] = {*then_handle, *else_handle};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
class CaseOp : public AsyncOpKernel {
|
class CaseOp : public AsyncOpKernel {
|
||||||
@ -332,18 +360,10 @@ class WhileOp : public AsyncOpKernel {
|
|||||||
auto lib = ctx->function_library();
|
auto lib = ctx->function_library();
|
||||||
OP_REQUIRES_ASYNC(ctx, lib != nullptr,
|
OP_REQUIRES_ASYNC(ctx, lib != nullptr,
|
||||||
errors::Internal("No function library"), done);
|
errors::Internal("No function library"), done);
|
||||||
|
|
||||||
// TODO(b/37549631): Because this op has `SetIsStateful()` in its op
|
|
||||||
// registration, this kernel may be shared by multiple subgraphs, which have
|
|
||||||
// different associated `FunctionLibraryRuntime` objects and hence different
|
|
||||||
// `FHandle` namespaces. So we must call Instantiate() to make sure we get
|
|
||||||
// the correct function handles with respect to `lib`. Note the underlying
|
|
||||||
// `lib->Instantiate()` caches the created function handles, so calling
|
|
||||||
// `Instantiate()` repeatedly on the same `lib` and function is cheap.
|
|
||||||
FHandle cond_handle;
|
FHandle cond_handle;
|
||||||
FHandle body_handle;
|
FHandle body_handle;
|
||||||
OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, cond_func_, &cond_handle), done);
|
OP_REQUIRES_OK_ASYNC(ctx, GetHandles(ctx, &cond_handle, &body_handle),
|
||||||
OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, body_func_, &body_handle), done);
|
done);
|
||||||
(new State(this, ctx, cond_handle, body_handle, done))->Start();
|
(new State(this, ctx, cond_handle, body_handle, done))->Start();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -351,6 +371,10 @@ class WhileOp : public AsyncOpKernel {
|
|||||||
NameAttrList cond_func_;
|
NameAttrList cond_func_;
|
||||||
NameAttrList body_func_;
|
NameAttrList body_func_;
|
||||||
|
|
||||||
|
mutex mu_;
|
||||||
|
std::unordered_map<FunctionLibraryRuntime*, std::pair<FHandle, FHandle>>
|
||||||
|
handles_ GUARDED_BY(mu_);
|
||||||
|
|
||||||
class State {
|
class State {
|
||||||
public:
|
public:
|
||||||
State(WhileOp* kernel, OpKernelContext* ctx, FHandle cond_handle,
|
State(WhileOp* kernel, OpKernelContext* ctx, FHandle cond_handle,
|
||||||
@ -486,6 +510,42 @@ class WhileOp : public AsyncOpKernel {
|
|||||||
delete this;
|
delete this;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Status GetHandles(OpKernelContext* ctx, FHandle* cond_handle,
|
||||||
|
FHandle* body_handle) {
|
||||||
|
// TODO(b/37549631): Because this op has `SetIsStateful()` in its
|
||||||
|
// op registration, this kernel may be shared by multiple
|
||||||
|
// subgraphs, which have different associated
|
||||||
|
// `FunctionLibraryRuntime` objects and hence different `FHandle`
|
||||||
|
// namespaces. We currently work around this by caching the map
|
||||||
|
// from `FunctionLibraryRuntime*` to `FHandle` pairs for the two
|
||||||
|
// functions this op uses.
|
||||||
|
auto lib = ctx->function_library();
|
||||||
|
if (lib == nullptr) return errors::Internal("No function library");
|
||||||
|
*cond_handle = kInvalidHandle;
|
||||||
|
*body_handle = kInvalidHandle;
|
||||||
|
{
|
||||||
|
tf_shared_lock l(mu_);
|
||||||
|
const auto iter = handles_.find(lib);
|
||||||
|
if (TF_PREDICT_TRUE(iter != handles_.end())) {
|
||||||
|
*cond_handle = iter->second.first;
|
||||||
|
*body_handle = iter->second.second;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (TF_PREDICT_FALSE(*cond_handle == kInvalidHandle)) {
|
||||||
|
mutex_lock l(mu_);
|
||||||
|
const auto iter = handles_.find(lib);
|
||||||
|
if (TF_PREDICT_TRUE(iter != handles_.end())) {
|
||||||
|
*cond_handle = iter->second.first;
|
||||||
|
*body_handle = iter->second.second;
|
||||||
|
} else {
|
||||||
|
TF_RETURN_IF_ERROR(Instantiate(lib, cond_func_, cond_handle));
|
||||||
|
TF_RETURN_IF_ERROR(Instantiate(lib, body_func_, body_handle));
|
||||||
|
handles_[lib] = {*cond_handle, *body_handle};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
};
|
};
|
||||||
// TODO(drpng): remove these.
|
// TODO(drpng): remove these.
|
||||||
REGISTER_KERNEL_BUILDER(Name("_While").Device(DEVICE_CPU), WhileOp);
|
REGISTER_KERNEL_BUILDER(Name("_While").Device(DEVICE_CPU), WhileOp);
|
||||||
|
Loading…
x
Reference in New Issue
Block a user