[WhileV2] Fix potential stack overflow when kernels execute inline.

The previous implementation of `WhileOp` used callbacks to handle the
result of the cond and body functions. While this is reasonable (and necessary)
when the kernels execute on the inter-op threadpool, it can lead to a stack
overflow when inline execution is specified and the number of iterations is
large.

This change adds a non-callback-based version of `WhileOp` that will be
dispatched to if either (i) the `OpKernelContext::run_all_kernels_inline()`
method returns true, or (ii) the `WhileOp::Compute()` method is invoked.

PiperOrigin-RevId: 309264480
Change-Id: I4697d6b080caca75a5c3f555063bea32a0de9987
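
For intuition, here is a minimal standalone sketch of the failure mode the commit message describes. This is not the TensorFlow implementation; every name in it is made up for illustration. It contrasts a callback-chained loop, which re-enters itself from inside each completion callback and therefore nests one set of stack frames per iteration when callbacks run inline, with a plain do-while loop whose stack depth stays constant.

// Hypothetical illustration only -- not TensorFlow code.
#include <functional>
#include <iostream>

// Stand-in for a runtime that executes a function and then invokes its
// completion callback on the *caller's* stack, which is effectively what
// happens when all kernels (and function callbacks) execute inline.
void RunInline(const std::function<void()>& fn,
               const std::function<void()>& done) {
  fn();
  done();
}

// Callback-based loop: iteration i+1 starts from inside iteration i's `done`
// callback, so N iterations nest N levels of stack frames when run inline.
void IterateWithCallbacks(int i, int n, const std::function<void()>& done) {
  if (i >= n) {
    done();
    return;
  }
  RunInline([] { /* loop body for one iteration */ },
            [i, n, done] { IterateWithCallbacks(i + 1, n, done); });
}

// Loop-based equivalent: stack depth is constant regardless of N.
void IterateWithLoop(int n) {
  int i = 0;
  do {
    /* loop body for one iteration */
    ++i;
  } while (i < n);
}

int main() {
  IterateWithLoop(1000000);  // Constant stack usage: fine for any count.
  // Each of these 1000 iterations adds nested frames; scaling the count into
  // the millions would eventually overflow the thread's stack, which is the
  // failure mode the loop-based path avoids.
  IterateWithCallbacks(0, 1000, [] { std::cout << "done\n"; });
  return 0;
}

The `DoComputeSync()` method added in the diff below follows the second pattern: it calls `FunctionLibraryRuntime::RunSync()` from a single do-while loop instead of chaining `lib_->Run()` callbacks.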
Author: Derek Murray, 2020-04-30 11:24:04 -07:00; committed by TensorFlower Gardener
Commit: 8f82c20bb1 (parent: f81dadf228)
2 changed files with 150 additions and 44 deletions

Changed file 1 of 2:

@@ -26,6 +26,7 @@ limitations under the License.
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/threadpool.h"
 #include "tensorflow/core/platform/casts.h"
+#include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/profiler/lib/traceme.h"

 namespace tensorflow {
@@ -357,14 +358,30 @@ class WhileOp : public AsyncOpKernel {
   ~WhileOp() override {}

   void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
-    auto lib = ctx->function_library();
-    OP_REQUIRES_ASYNC(ctx, lib != nullptr,
-                      errors::Internal("No function library"), done);
-    FHandle cond_handle;
-    FHandle body_handle;
-    OP_REQUIRES_OK_ASYNC(ctx, GetHandles(ctx, &cond_handle, &body_handle),
-                         done);
-    (new State(this, ctx, cond_handle, body_handle, done))->Start();
+    if (ctx->run_all_kernels_inline()) {
+      // Use the non-callback-based implementation when kernels (and function
+      // callbacks) execute inline to avoid stack overflow.
+      OP_REQUIRES_OK_ASYNC(ctx, DoComputeSync(ctx), done);
+    } else {
+      FHandle cond_handle;
+      FHandle body_handle;
+      OP_REQUIRES_OK_ASYNC(ctx, GetHandles(ctx, &cond_handle, &body_handle),
+                           done);
+      (new State(this, ctx, cond_handle, body_handle, done))->Start();
+    }
+  }
+
+  void Compute(OpKernelContext* ctx) override {
+    // Use the non-callback-based implementation when the synchronous Compute()
+    // method is invoked, because the caller is explicitly donating a thread.
+    Status s = DoComputeSync(ctx);
+    // NOTE: Unfortunately, we cannot use OP_REQUIRES_OK here, because this is
+    // still an AsyncOpKernel, and there is a run-time check to avoid calling
+    // OP_REQUIRES_OK in AsyncOpKernel::ComputeAsync() (which would deadlock in
+    // the event of an error).
+    if (TF_PREDICT_FALSE(!s.ok())) {
+      ctx->SetStatus(s);
+    }
   }

  private:
@@ -375,6 +392,41 @@ class WhileOp : public AsyncOpKernel {
   std::unordered_map<FunctionLibraryRuntime*, std::pair<FHandle, FHandle>>
       handles_ GUARDED_BY(mu_);

+  static string EvalCondTraceString(
+      OpKernelContext* ctx, const FunctionLibraryRuntime::Options& opts) {
+    return absl::StrCat("WhileOp-EvalCond #parent_step_id=", ctx->step_id(),
+                        ",function_step_id=", opts.step_id, "#");
+  }
+
+  static string StartBodyTraceString(
+      OpKernelContext* ctx, const FunctionLibraryRuntime::Options& opts) {
+    return absl::StrCat("WhileOp-StartBody #parent_step_id=", ctx->step_id(),
+                        ",function_step_id=", opts.step_id, "#");
+  }
+
+  static Status CondResultToBool(OpKernelContext* ctx,
+                                 const FunctionLibraryRuntime::Options& opts,
+                                 const Tensor& cond_t, bool* out_result) {
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+    const DeviceBase::GpuDeviceInfo* gpu_device_info =
+        ctx->device()->tensorflow_gpu_device_info();
+    const bool is_hostmem_dtype =
+        cond_t.dtype() == DT_INT32 || cond_t.dtype() == DT_INT64;
+    if (!is_hostmem_dtype && gpu_device_info &&
+        (opts.rets_alloc_attrs.empty() ||
+         !opts.rets_alloc_attrs[0].on_host())) {
+      // Copy the ret value to host if it's allocated on device.
+      Device* device = down_cast<Device*>(ctx->device());
+      DeviceContext* device_ctx = ctx->op_device_context();
+      Tensor host_cond_t = Tensor(cond_t.dtype(), cond_t.shape());
+      TF_RETURN_IF_ERROR(device_ctx->CopyDeviceTensorToCPUSync(
+          &cond_t, /*tensor_name=*/"", device, &host_cond_t));
+      return ToBool({host_cond_t}, out_result);
+    }
+#endif
+    return ToBool({cond_t}, out_result);
+  }
+
   class State {
    public:
     State(WhileOp* kernel, OpKernelContext* ctx, FHandle cond_handle,
@@ -408,11 +460,7 @@ class WhileOp : public AsyncOpKernel {

     void EvalCond() {
       profiler::TraceMe trace_me(
-          [&] {
-            return absl::StrCat(
-                "WhileOp-EvalCond #parent_step_id=", ctx_->step_id(),
-                ",function_step_id=", opts_.step_id, "#");
-          },
+          [&] { return EvalCondTraceString(ctx_, opts_); },
           /*level=*/2);
       lib_->Run(
           // Evaluate the condition.
@@ -434,46 +482,22 @@ class WhileOp : public AsyncOpKernel {
             rets_.size(), " tensors.");
         return Finish(s);
       }
-      Tensor cond_t;
-#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-      const DeviceBase::GpuDeviceInfo* gpu_device_info =
-          ctx_->device()->tensorflow_gpu_device_info();
-      const bool is_hostmem_dtype =
-          rets_[0].dtype() == DT_INT32 || rets_[0].dtype() == DT_INT64;
-      if (!is_hostmem_dtype && gpu_device_info &&
-          (opts_.rets_alloc_attrs.empty() ||
-           !opts_.rets_alloc_attrs[0].on_host())) {
-        // Copy the ret value to host if it's allocated on device.
-        Device* device = down_cast<Device*>(ctx_->device());
-        DeviceContext* device_ctx = ctx_->op_device_context();
-        cond_t = Tensor(rets_[0].dtype(), rets_[0].shape());
-        s = device_ctx->CopyDeviceTensorToCPUSync(&rets_[0], /*tensor_name=*/"",
-                                                  device, &cond_t);
-        if (!s.ok()) {
-          return Finish(s);
-        }
-      } else {
-        cond_t = rets_[0];
-      }
-#else
-      cond_t = rets_[0];
-#endif
-      bool cond;
-      s = ToBool({cond_t}, &cond);
       if (!s.ok()) {
         return Finish(s);
       }
+      bool cond;
+      s = CondResultToBool(ctx_, opts_, rets_[0], &cond);
+      if (!s.ok()) {
+        return Finish(s);
+      }
       if (!cond) {
         return Finish(Status::OK());
       }
       rets_.clear();
       profiler::TraceMe trace_me(
-          [&] {
-            return absl::StrCat(
-                "WhileOp-StartBody #parent_step_id=", ctx_->step_id(),
-                ",function_step_id=", opts_.step_id, "#");
-          },
+          [&] { return StartBodyTraceString(ctx_, opts_); },
           /*level=*/2);
       lib_->Run(
           // Evaluate the body.
@@ -505,6 +529,68 @@ class WhileOp : public AsyncOpKernel {
     }
   };

+  Status DoComputeSync(OpKernelContext* ctx) {
+    FHandle cond_handle;
+    FHandle body_handle;
+    TF_RETURN_IF_ERROR(GetHandles(ctx, &cond_handle, &body_handle));
+
+    auto lib = ctx->function_library();
+    FunctionLibraryRuntime::Options opts;
+    SetRunOptions(ctx, &opts, false /* always_collect_stats */);
+
+    // Pre-allocate argument and return value vectors for the cond and body
+    // functions.
+    std::vector<Tensor> args;
+    const int num_loop_vars = ctx->num_inputs();
+    args.reserve(num_loop_vars);
+    std::vector<Tensor> cond_rets;
+    cond_rets.reserve(1);
+    std::vector<Tensor> body_rets;
+    body_rets.reserve(num_loop_vars);
+
+    // The initial loop variable args are the inputs to the kernel.
+    for (int i = 0; i < num_loop_vars; ++i) {
+      args.push_back(ctx->input(i));
+    }
+
+    // Implement the logic of the while loop as a single C++ do-while loop that
+    // executes the cond and body functions synchronously.
+    do {
+      // Evaluate the cond function on the current loop variables.
+      {
+        profiler::TraceMe trace_me(
+            [&] { return EvalCondTraceString(ctx, opts); },
+            /*level=*/2);
+        TF_RETURN_IF_ERROR(lib->RunSync(opts, cond_handle, args, &cond_rets));
+      }
+      if (cond_rets.size() != 1) {
+        return errors::InvalidArgument(
+            "Expected a single scalar return value from WhileOp cond, got ",
+            cond_rets.size(), " tensors.");
+      }
+
+      // If the cond function evaluates to false, we are done: output the
+      // current loop variables.
+      bool cond_result;
+      TF_RETURN_IF_ERROR(
+          CondResultToBool(ctx, opts, cond_rets[0], &cond_result));
+      if (!cond_result) {
+        return SetOutputs(this, ctx, args);
+      }
+
+      // Evaluate the body function on the current loop variables, to get an
+      // updated vector of loop variables.
+      {
+        profiler::TraceMe trace_me(
+            [&] { return StartBodyTraceString(ctx, opts); },
+            /*level=*/2);
+        TF_RETURN_IF_ERROR(lib->RunSync(opts, body_handle, args, &body_rets));
+      }
+
+      std::swap(body_rets, args);
+      body_rets.clear();
+    } while (true);
+  }
+
   Status GetHandles(OpKernelContext* ctx, FHandle* cond_handle,
                     FHandle* body_handle) {
     // TODO(b/37549631): Because this op has `SetIsStateful()` in its

Changed file 2 of 2:

@@ -49,6 +49,7 @@ from tensorflow.python.ops import while_v2
 from tensorflow.python.ops.while_v2 import while_loop as while_loop_v2
 from tensorflow.python.platform import test


 def random_gamma(shape):  # pylint: disable=invalid-name
   return random_ops.random_gamma(shape, 1.0)
@@ -1222,6 +1223,25 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
         self.assertNotIn("switch", ns.node_name)
     control_flow_util_v2._DISABLE_LOWER_USING_SWITCH_MERGE = old

+  def _runBasicWithConfig(self, config):
+    with ops.device("/cpu:0"):
+      x = constant_op.constant(0)
+      ret, = while_loop_v2(lambda x: x < 1000, lambda x: x + 1, [x])
+    with self.cached_session(config=config):
+      self.assertEqual(1000, self.evaluate(ret))
+
+  @test_util.run_deprecated_v1
+  def testRunKernelsInline(self):
+    config = config_pb2.ConfigProto()
+    config.inter_op_parallelism_threads = -1
+    self._runBasicWithConfig(config)
+
+  @test_util.run_deprecated_v1
+  def testSingleThreadedExecution(self):
+    config = config_pb2.ConfigProto()
+    config.experimental.executor_type = "SINGLE_THREADED_EXECUTOR"
+    self._runBasicWithConfig(config)
+

 def ScalarShape():
   return ops.convert_to_tensor([], dtype=dtypes.int32)