Function calls inherit run_all_kernels_inline from their parent context.

The `Executor::Args::run_all_kernels_inline` flag optimizes the execution of graphs with many small kernels, and avoids potentially unbounded stack growth. This change enables functions called by a kernel to inherit this flag, which extends support for the option to larger and more complicated graphs containing function calls. It adds the flag to `OpKernelContext::Params` and `FunctionLibraryRuntime::Options`, and updates function-calling kernels to propagate it.

PiperOrigin-RevId: 298636976
Change-Id: I28263aa5a17ce7d94b84f6bb42657ce3f4b88cfa
Author: Derek Murray (committed by TensorFlower Gardener)
Date: 2020-03-03 10:29:03 -08:00
Parent: 7072568ed6
Commit: 9a924476f3
12 changed files with 83 additions and 0 deletions
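For orientation, the caller-facing knob here is `FunctionLibraryRuntime::Options::run_all_kernels_inline`. Below is a minimal sketch of how a client might opt in; it is not part of this commit, and the helper name `RunFunctionInline`, the `flr`/`handle`/`args`/`rets` parameters, and the omitted Options fields (rendezvous, runner, step container, ...) are assumptions standing in for a caller's real setup:

#include <vector>

#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

// Sketch: run an already-instantiated function with the inline hint set.
// Other Options fields that a real caller would populate are elided.
Status RunFunctionInline(FunctionLibraryRuntime* flr,
                         FunctionLibraryRuntime::Handle handle,
                         const std::vector<Tensor>& args,
                         std::vector<Tensor>* rets) {
  FunctionLibraryRuntime::Options opts;
  // Treat every kernel as "inexpensive" and execute it on the scheduling
  // thread instead of dispatching it to a thread pool.
  opts.run_all_kernels_inline = true;

  Status status;
  Notification done;
  flr->Run(opts, handle, args, rets, [&](const Status& s) {
    status = s;
    done.Notify();
  });
  done.WaitForNotification();
  return status;
}

}  // namespace tensorflow

With the changes below, the hint set here reaches kernels inside any functions that the called function itself invokes, rather than applying only one level deep.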


@@ -1737,6 +1737,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
params.inputs = &inputs;
params.input_alloc_attrs = &input_alloc_attrs;
params.runner = &runner_;
params.run_all_kernels_inline = run_all_kernels_inline_;
params.stats_collector = stats_collector_;
params.inc_num_deferred_ops_function = [this]() {
  mutex_lock lock(num_deferred_ops_mu_);
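(Aside, not part of the diff.) The value copied into `OpKernelContext::Params` above originates on `Executor::Args` when a client drives an executor directly. A rough sketch under stated assumptions: `RunGraphInline` is a hypothetical helper, and the executor and rendezvous are presumed to be constructed elsewhere (e.g. via `NewLocalExecutor`):

#include <functional>

#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

// Sketch: run a pre-built Executor with every kernel treated as
// "inexpensive", i.e. executed on the scheduling thread.
Status RunGraphInline(Executor* executor, Rendezvous* rendezvous) {
  Executor::Args args;
  args.rendezvous = rendezvous;
  // Run dispatched closures inline as well, so nothing leaves this thread.
  args.runner = [](std::function<void()> fn) { fn(); };
  args.run_all_kernels_inline = true;
  // Executor::Run is the synchronous wrapper around RunAsync.
  return executor->Run(args);
}

}  // namespace tensorflow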


@@ -532,6 +532,7 @@ class CallOp : public AsyncOpKernel {
opts.step_container = ctx->step_container();
opts.stats_collector = ctx->stats_collector();
opts.runner = ctx->runner();
opts.run_all_kernels_inline = ctx->run_all_kernels_inline();
opts.collective_executor = ctx->collective_executor();
std::vector<Tensor> args;
args.reserve(ctx->num_inputs());
@@ -1021,6 +1022,7 @@ void FunctionLibraryRuntimeImpl::ExecutorArgsFromOptions(
}
exec_args->collective_executor = run_opts.collective_executor;
exec_args->call_frame = frame;
exec_args->run_all_kernels_inline = run_opts.run_all_kernels_inline;
}

void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle,


@@ -1872,6 +1872,67 @@ TEST_F(FunctionLibraryRuntimeTest, CrossDevice) {
TensorShape({})));
}

class AreAllKernelsInlineOp : public OpKernel {
 public:
  using OpKernel::OpKernel;

  void Compute(OpKernelContext* ctx) override {
    Tensor* output;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {}, &output));
    output->scalar<bool>()() = ctx->run_all_kernels_inline();
  }
};

REGISTER_OP("AreAllKernelsInline").Output("result: bool").SetIsStateful();
REGISTER_KERNEL_BUILDER(Name("AreAllKernelsInline").Device(DEVICE_CPU),
                        AreAllKernelsInlineOp);

TEST_F(FunctionLibraryRuntimeTest, RunAllKernelsInline) {
  // Create a function "F" that includes an AreAllKernelsInline op, and a
  // function "G" that calls "F".
  auto f = FDH::Create(
      // Name
      "F",
      // Args
      {},
      // Return values
      {"ret: bool"},
      // Attrs
      {},
      // Nodes
      {// y = AreAllKernelsInline()
       {{"y"}, "AreAllKernelsInline", {}, {}}},
      {{"ret", "y:result:0"}});
  auto g = FDH::Create(
      // Name
      "G",
      // Args
      {},
      // Return values
      {"ret: bool"},
      // Attrs
      {},
      // Nodes
      {// y = F()
       {{"y"}, "F", {}, {}}},
      {{"ret", "y:ret:0"}});
  Init({f, g});

  FunctionLibraryRuntime::Handle handle;
  TF_CHECK_OK(Instantiate(flr0_, "G", {}, &handle));

  // Test that the `run_all_kernels_inline` flag is inherited by the kernel
  // running inside the called function.
  for (bool inline_option : {false, true}) {
    FunctionLibraryRuntime::Options opts;
    opts.run_all_kernels_inline = inline_option;
    Tensor result;
    TF_CHECK_OK(Run(flr0_, handle, opts, {}, {&result}, true));
    EXPECT_EQ(result.scalar<bool>()(), inline_option);
  }
}

namespace {

bool DoNothing(Graph* g) { return false; }


@@ -712,6 +712,10 @@ class FunctionLibraryRuntime {
// If True, allow returning dead tensors.
bool allow_dead_tensors = false;

// If True, hint that all kernels should be treated as "inexpensive", and
// hence executed on the scheduling thread.
bool run_all_kernels_inline = false;

// Returns a human readable representation of this.
string DebugString() const;
};


@@ -732,6 +732,7 @@ class OpKernelContext {
std::function<void(std::function<void()>)>* runner = nullptr;
StepStatsCollectorInterface* stats_collector = nullptr;
GraphCollector* graph_collector = nullptr;
bool run_all_kernels_inline = false;

// TensorSliceReaderCache support.
checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache = nullptr;
@@ -867,6 +868,12 @@ class OpKernelContext {
// If non-null, kernels should populate with any partition subgraphs created.
GraphCollector* graph_collector() { return params_->graph_collector; }

// If True, hint that all kernels in functions called by this kernel should
// be treated as "inexpensive", and hence executed on the scheduling thread.
bool run_all_kernels_inline() const {
  return params_->run_all_kernels_inline;
}

// Input to output forwarding.
// Set the output Ref Tensor at output_index to be an alias of the


@@ -515,6 +515,7 @@ class BatchResource : public ResourceBase {
opts.stats_collector = last_task_context->stats_collector();
opts.rendezvous = last_task_context->rendezvous();
opts.runner = last_task_context->runner();
opts.run_all_kernels_inline = last_task_context->run_all_kernels_inline();
auto* flib = last_task_context->function_library();
std::vector<Tensor> combined_outputs;


@@ -835,6 +835,7 @@ class OneShotIteratorOp : public AsyncOpKernel {
});
opts.step_container = &step_container;
opts.runner = ctx->runner();
opts.run_all_kernels_inline = ctx->run_all_kernels_inline();
Notification n;
Status factory_status;
std::vector<Tensor> return_values;


@@ -239,6 +239,7 @@ void MapDefunOp::SetRunOptions(OpKernelContext* ctx,
} else {
opts->runner = ctx->runner();
}
opts->run_all_kernels_inline = ctx->run_all_kernels_inline();
}

Status MapDefunOp::SetupArgs(OpKernelContext* ctx,


@@ -259,6 +259,7 @@ class SingleThreadedExecutorImpl : public Executor {
Args::Runner runner_copy = args.runner;
params.runner = &runner_copy;
params.run_all_kernels_inline = args.run_all_kernels_inline;
params.stats_collector = args.stats_collector;
// NOTE(mrry): We are assuming that the graph is loopless and condless.


@@ -253,6 +253,7 @@ class SymbolicGradientOp : public AsyncOpKernel {
opts.rendezvous = ctx->rendezvous();
opts.cancellation_manager = ctx->cancellation_manager();
opts.runner = ctx->runner();
opts.run_all_kernels_inline = ctx->run_all_kernels_inline();
opts.stats_collector = ctx->stats_collector();
opts.step_container = ctx->step_container();
opts.collective_executor = ctx->collective_executor();
@@ -365,6 +366,7 @@ void RemoteCallOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
FunctionLibraryRuntime::Options opts;
opts.runner = ctx->runner();
opts.run_all_kernels_inline = ctx->run_all_kernels_inline();
opts.source_device = source_device;
if (opts.source_device != target_device) {
opts.remote_execution = true;


@@ -107,6 +107,7 @@ void SetRunOptions(OpKernelContext* ctx, FunctionLibraryRuntime::Options* opts,
opts->stats_collector = ctx->stats_collector();
}
opts->runner = ctx->runner();
opts->run_all_kernels_inline = ctx->run_all_kernels_inline();
opts->step_container = ctx->step_container();
}


@@ -241,6 +241,7 @@ void PartitionedCallOp::RunFunction(FunctionLibraryRuntime::Handle handle,
// TODO(akshayka): Consider selecting a runner on a per-device basis,
// i.e., using device-specific threadpools when available.
run_opts.runner = ctx->runner();
run_opts.run_all_kernels_inline = ctx->run_all_kernels_inline();
run_opts.source_device =
lib->device() == nullptr ? "" : lib->device()->name();
run_opts.allow_dead_tensors = true;