Function calls inherit run_all_kernels_inline from their parent context.
The `Executor::Args::run_all_kernels_inline` flag optimizes the execution of graphs with many small kernels and avoids potentially unbounded stack growth. This change enables functions called by a kernel to inherit the flag, extending support for the option to larger and more complicated graphs that contain function calls. It adds the flag to `OpKernelContext::Params` and `FunctionLibraryRuntime::Options`, and updates function-calling kernels to propagate it.

PiperOrigin-RevId: 298636976
Change-Id: I28263aa5a17ce7d94b84f6bb42657ce3f4b88cfa
commit 9a924476f3
parent 7072568ed6
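The propagation pattern is the same at every call site touched below: a function-calling kernel copies the flag from its OpKernelContext into the FunctionLibraryRuntime::Options it passes to Run(). The following is a minimal sketch of that pattern in a hypothetical async kernel; InlineAwareCallOp, its pre-instantiated handle_, and the surrounding input/output plumbing are illustrative and not part of this change, while the run_all_kernels_inline accessors are the ones this commit adds.

// Sketch only: "InlineAwareCallOp" and "handle_" are hypothetical; the
// run_all_kernels_inline plumbing mirrors the call sites in the diff below.
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

class InlineAwareCallOp : public AsyncOpKernel {
 public:
  using AsyncOpKernel::AsyncOpKernel;

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    FunctionLibraryRuntime* lib = ctx->function_library();
    OP_REQUIRES_ASYNC(ctx, lib != nullptr,
                      errors::Internal("No function library available."), done);

    FunctionLibraryRuntime::Options opts;
    opts.step_container = ctx->step_container();
    opts.cancellation_manager = ctx->cancellation_manager();
    opts.runner = ctx->runner();
    // The key line: the callee inherits the caller's inline-execution hint, so
    // kernels inside the called function also run on the scheduling thread.
    opts.run_all_kernels_inline = ctx->run_all_kernels_inline();

    std::vector<Tensor> args;
    args.reserve(ctx->num_inputs());
    for (int i = 0; i < ctx->num_inputs(); ++i) args.push_back(ctx->input(i));

    auto* rets = new std::vector<Tensor>;
    lib->Run(opts, handle_, args, rets,
             [ctx, rets, done](const Status& status) {
               if (status.ok()) {
                 for (int i = 0; i < static_cast<int>(rets->size()); ++i) {
                   ctx->set_output(i, (*rets)[i]);
                 }
               } else {
                 ctx->SetStatus(status);
               }
               delete rets;
               done();
             });
  }

 private:
  FunctionLibraryRuntime::Handle handle_;  // Assumed instantiated elsewhere.
};

}  // namespace tensorflow

Inside the runtime, FunctionLibraryRuntimeImpl::ExecutorArgsFromOptions copies the same field into Executor::Args, which is the other half of the plumbing shown in the hunks below.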
@@ -1737,6 +1737,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
   params.inputs = &inputs;
   params.input_alloc_attrs = &input_alloc_attrs;
   params.runner = &runner_;
+  params.run_all_kernels_inline = run_all_kernels_inline_;
   params.stats_collector = stats_collector_;
   params.inc_num_deferred_ops_function = [this]() {
     mutex_lock lock(num_deferred_ops_mu_);
@@ -532,6 +532,7 @@ class CallOp : public AsyncOpKernel {
     opts.step_container = ctx->step_container();
     opts.stats_collector = ctx->stats_collector();
     opts.runner = ctx->runner();
+    opts.run_all_kernels_inline = ctx->run_all_kernels_inline();
     opts.collective_executor = ctx->collective_executor();
     std::vector<Tensor> args;
     args.reserve(ctx->num_inputs());
@@ -1021,6 +1022,7 @@ void FunctionLibraryRuntimeImpl::ExecutorArgsFromOptions(
   }
   exec_args->collective_executor = run_opts.collective_executor;
   exec_args->call_frame = frame;
+  exec_args->run_all_kernels_inline = run_opts.run_all_kernels_inline;
 }

 void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle,
@@ -1872,6 +1872,67 @@ TEST_F(FunctionLibraryRuntimeTest, CrossDevice) {
                           TensorShape({})));
 }

+class AreAllKernelsInlineOp : public OpKernel {
+ public:
+  using OpKernel::OpKernel;
+
+  void Compute(OpKernelContext* ctx) override {
+    Tensor* output;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {}, &output));
+    output->scalar<bool>()() = ctx->run_all_kernels_inline();
+  }
+};
+
+REGISTER_OP("AreAllKernelsInline").Output("result : bool").SetIsStateful();
+REGISTER_KERNEL_BUILDER(Name("AreAllKernelsInline").Device(DEVICE_CPU),
+                        AreAllKernelsInlineOp);
+
+TEST_F(FunctionLibraryRuntimeTest, RunAllKernelsInline) {
+  // Create a function "F" that includes an AreAllKernelsInline op, and a
+  // function "G" that calls "F".
+  auto f = FDH::Create(
+      // Name
+      "F",
+      // Args
+      {},
+      // Return values
+      {"ret: bool"},
+      // Attrs
+      {},
+      // Nodes
+      {// y = AreAllKernelsInline()
+       {{"y"}, "AreAllKernelsInline", {}, {}}},
+      {{"ret", "y:result:0"}});
+
+  auto g = FDH::Create(
+      // Name
+      "G",
+      // Args
+      {},
+      // Return values
+      {"ret: bool"},
+      // Attrs
+      {},
+      // Nodes
+      {// y = F()
+       {{"y"}, "F", {}, {}}},
+      {{"ret", "y:ret:0"}});
+
+  Init({f, g});
+  FunctionLibraryRuntime::Handle handle;
+  TF_CHECK_OK(Instantiate(flr0_, "G", {}, &handle));
+
+  // Test that the `run_all_kernels_inline` flag is inherited by the kernel
+  // running inside the called function.
+  for (bool inline_option : {false, true}) {
+    FunctionLibraryRuntime::Options opts;
+    opts.run_all_kernels_inline = inline_option;
+    Tensor result;
+    TF_CHECK_OK(Run(flr0_, handle, opts, {}, {&result}, true));
+    EXPECT_EQ(result.scalar<bool>()(), inline_option);
+  }
+}
+
 namespace {

 bool DoNothing(Graph* g) { return false; }
@@ -712,6 +712,10 @@ class FunctionLibraryRuntime {
     // If True, allow returning dead tensors.
     bool allow_dead_tensors = false;

+    // If True, hint that all kernels should be treated as "inexpensive", and
+    // hence executed on the scheduling thread.
+    bool run_all_kernels_inline = false;
+
     // Returns a human readable representation of this.
     string DebugString() const;
   };
@@ -732,6 +732,7 @@ class OpKernelContext {
     std::function<void(std::function<void()>)>* runner = nullptr;
     StepStatsCollectorInterface* stats_collector = nullptr;
     GraphCollector* graph_collector = nullptr;
+    bool run_all_kernels_inline = false;

     // TensorSliceReaderCache support.
     checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache = nullptr;
@@ -867,6 +868,12 @@ class OpKernelContext {
   // If non-null, kernels should populate with any partition subgraphs created.
   GraphCollector* graph_collector() { return params_->graph_collector; }

+  // If True, hint that all kernels in functions called by this kernel, should
+  // be treated as "inexpensive", and hence executed on the scheduling thread.
+  bool run_all_kernels_inline() const {
+    return params_->run_all_kernels_inline;
+  }
+
   // Input to output forwarding.

   // Set the output Ref Tensor at output_index to be an alias of the
@@ -515,6 +515,7 @@ class BatchResource : public ResourceBase {
     opts.stats_collector = last_task_context->stats_collector();
     opts.rendezvous = last_task_context->rendezvous();
     opts.runner = last_task_context->runner();
+    opts.run_all_kernels_inline = last_task_context->run_all_kernels_inline();

     auto* flib = last_task_context->function_library();
     std::vector<Tensor> combined_outputs;
@@ -835,6 +835,7 @@ class OneShotIteratorOp : public AsyncOpKernel {
     });
     opts.step_container = &step_container;
     opts.runner = ctx->runner();
+    opts.run_all_kernels_inline = ctx->run_all_kernels_inline();
     Notification n;
     Status factory_status;
     std::vector<Tensor> return_values;
@@ -239,6 +239,7 @@ void MapDefunOp::SetRunOptions(OpKernelContext* ctx,
   } else {
     opts->runner = ctx->runner();
   }
+  opts->run_all_kernels_inline = ctx->run_all_kernels_inline();
 }

 Status MapDefunOp::SetupArgs(OpKernelContext* ctx,
@@ -259,6 +259,7 @@ class SingleThreadedExecutorImpl : public Executor {

     Args::Runner runner_copy = args.runner;
     params.runner = &runner_copy;
+    params.run_all_kernels_inline = args.run_all_kernels_inline;
     params.stats_collector = args.stats_collector;

     // NOTE(mrry): We are assuming that the graph is loopless and condless.
@@ -253,6 +253,7 @@ class SymbolicGradientOp : public AsyncOpKernel {
     opts.rendezvous = ctx->rendezvous();
     opts.cancellation_manager = ctx->cancellation_manager();
     opts.runner = ctx->runner();
+    opts.run_all_kernels_inline = ctx->run_all_kernels_inline();
     opts.stats_collector = ctx->stats_collector();
     opts.step_container = ctx->step_container();
     opts.collective_executor = ctx->collective_executor();
@@ -365,6 +366,7 @@ void RemoteCallOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {

   FunctionLibraryRuntime::Options opts;
   opts.runner = ctx->runner();
+  opts.run_all_kernels_inline = ctx->run_all_kernels_inline();
   opts.source_device = source_device;
   if (opts.source_device != target_device) {
     opts.remote_execution = true;
@@ -107,6 +107,7 @@ void SetRunOptions(OpKernelContext* ctx, FunctionLibraryRuntime::Options* opts,
     opts->stats_collector = ctx->stats_collector();
   }
   opts->runner = ctx->runner();
+  opts->run_all_kernels_inline = ctx->run_all_kernels_inline();
   opts->step_container = ctx->step_container();
 }

@@ -241,6 +241,7 @@ void PartitionedCallOp::RunFunction(FunctionLibraryRuntime::Handle handle,
   // TODO(akshayka): Consider selecting a runner on a per-device basis,
   // i.e., using device-specific threadpools when available.
   run_opts.runner = ctx->runner();
+  run_opts.run_all_kernels_inline = ctx->run_all_kernels_inline();
   run_opts.source_device =
       lib->device() == nullptr ? "" : lib->device()->name();
   run_opts.allow_dead_tensors = true;
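Taken together, a caller only needs to set the flag once on the outermost FunctionLibraryRuntime::Options; the change carries it through nested function calls (as the new RunAllKernelsInline test above verifies via "G" calling "F"). A hedged usage sketch follows; RunMyFnInline is an illustrative helper, and the library pointer and instantiated handle are assumed to exist already.

// Usage sketch under assumptions: "lib" and "handle" come from an existing
// FunctionLibraryRuntime and a prior Instantiate() call.
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/lib/core/notification.h"

namespace tensorflow {

Status RunMyFnInline(FunctionLibraryRuntime* lib,
                     FunctionLibraryRuntime::Handle handle,
                     const std::vector<Tensor>& args,
                     std::vector<Tensor>* rets) {
  FunctionLibraryRuntime::Options opts;
  // With this change, the hint reaches kernels inside any functions the
  // called function itself invokes, not just its top-level body.
  opts.run_all_kernels_inline = true;

  Status status;
  Notification done;
  lib->Run(opts, handle, args, rets, [&status, &done](const Status& s) {
    status = s;
    done.Notify();
  });
  done.WaitForNotification();
  return status;
}

}  // namespace tensorflow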