[WhileV2] Fix potential stack overflow when kernels execute inline.
The previous implementation of `WhileOp` used callbacks to handle the results of the cond and body functions. While this is reasonable (and necessary) when the kernels execute on the inter-op threadpool, it can lead to stack overflow when inline execution is specified and the number of iterations is large. This change adds a non-callback-based version of `WhileOp` that is dispatched to if either (i) the `OpKernelContext::run_all_kernels_inline()` method returns true, or (ii) the synchronous `WhileOp::Compute()` method is invoked.

PiperOrigin-RevId: 309264480
Change-Id: I4697d6b080caca75a5c3f555063bea32a0de9987
parent f81dadf228
commit 8f82c20bb1
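To illustrate the failure mode this change avoids, here is a minimal, self-contained sketch (hypothetical names; not TensorFlow code). When a runtime invokes completion callbacks inline on the caller's stack, a callback-per-iteration loop nests one stack frame per iteration, while a synchronous do-while loop keeps the stack depth constant:

#include <functional>
#include <iostream>

// Stand-in for a runtime that runs some work and then invokes the completion
// callback *inline* on the caller's stack (cf. run_all_kernels_inline()).
void RunInline(const std::function<void()>& done) { done(); }

// Callback-based loop: each iteration adds another StartIteration frame to
// the stack, so a large iteration count can overflow it.
void StartIteration(int remaining) {
  if (remaining == 0) return;
  RunInline([remaining] { StartIteration(remaining - 1); });
}

// Non-callback-based loop: all iterations execute in a single frame, so the
// stack depth is constant regardless of the iteration count.
void RunLoopSync(int remaining) {
  do {
    if (remaining == 0) return;
    --remaining;
  } while (true);
}

int main() {
  RunLoopSync(10000000);        // Safe at any iteration count.
  // StartIteration(10000000);  // Would likely overflow the stack.
  std::cout << "ok" << std::endl;
}

The diff below applies the same idea: `ComputeAsync()` keeps the callback-driven `State` path for threadpool execution but uses the new synchronous `DoComputeSync()` loop when `run_all_kernels_inline()` is set, and `Compute()` always uses the synchronous loop because its caller explicitly donates a thread.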
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/profiler/lib/traceme.h"

namespace tensorflow {
@@ -357,14 +358,30 @@ class WhileOp : public AsyncOpKernel {
  ~WhileOp() override {}

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    auto lib = ctx->function_library();
    OP_REQUIRES_ASYNC(ctx, lib != nullptr,
                      errors::Internal("No function library"), done);
    FHandle cond_handle;
    FHandle body_handle;
    OP_REQUIRES_OK_ASYNC(ctx, GetHandles(ctx, &cond_handle, &body_handle),
                         done);
    (new State(this, ctx, cond_handle, body_handle, done))->Start();
    if (ctx->run_all_kernels_inline()) {
      // Use the non-callback-based implementation when kernels (and function
      // callbacks) execute inline to avoid stack overflow.
      OP_REQUIRES_OK_ASYNC(ctx, DoComputeSync(ctx), done);
    } else {
      FHandle cond_handle;
      FHandle body_handle;
      OP_REQUIRES_OK_ASYNC(ctx, GetHandles(ctx, &cond_handle, &body_handle),
                           done);
      (new State(this, ctx, cond_handle, body_handle, done))->Start();
    }
  }

  void Compute(OpKernelContext* ctx) override {
    // Use the non-callback-based implementation when the synchronous Compute()
    // method is invoked, because the caller is explicitly donating a thread.
    Status s = DoComputeSync(ctx);
    // NOTE: Unfortunately, we cannot use OP_REQUIRES_OK here, because this is
    // still an AsyncOpKernel, and there is a run-time check to avoid calling
    // OP_REQUIRES_OK in AsyncOpKernel::ComputeAsync() (which would deadlock in
    // the event of an error).
    if (TF_PREDICT_FALSE(!s.ok())) {
      ctx->SetStatus(s);
    }
  }

 private:
@@ -375,6 +392,41 @@ class WhileOp : public AsyncOpKernel {
  std::unordered_map<FunctionLibraryRuntime*, std::pair<FHandle, FHandle>>
      handles_ GUARDED_BY(mu_);

  static string EvalCondTraceString(
      OpKernelContext* ctx, const FunctionLibraryRuntime::Options& opts) {
    return absl::StrCat("WhileOp-EvalCond #parent_step_id=", ctx->step_id(),
                        ",function_step_id=", opts.step_id, "#");
  }

  static string StartBodyTraceString(
      OpKernelContext* ctx, const FunctionLibraryRuntime::Options& opts) {
    return absl::StrCat("WhileOp-StartBody #parent_step_id=", ctx->step_id(),
                        ",function_step_id=", opts.step_id, "#");
  }

  static Status CondResultToBool(OpKernelContext* ctx,
                                 const FunctionLibraryRuntime::Options& opts,
                                 const Tensor& cond_t, bool* out_result) {
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
    const DeviceBase::GpuDeviceInfo* gpu_device_info =
        ctx->device()->tensorflow_gpu_device_info();
    const bool is_hostmem_dtype =
        cond_t.dtype() == DT_INT32 || cond_t.dtype() == DT_INT64;
    if (!is_hostmem_dtype && gpu_device_info &&
        (opts.rets_alloc_attrs.empty() ||
         !opts.rets_alloc_attrs[0].on_host())) {
      // Copy the ret value to host if it's allocated on device.
      Device* device = down_cast<Device*>(ctx->device());
      DeviceContext* device_ctx = ctx->op_device_context();
      Tensor host_cond_t = Tensor(cond_t.dtype(), cond_t.shape());
      TF_RETURN_IF_ERROR(device_ctx->CopyDeviceTensorToCPUSync(
          &cond_t, /*tensor_name=*/"", device, &host_cond_t));
      return ToBool({host_cond_t}, out_result);
    }
#endif
    return ToBool({cond_t}, out_result);
  }

  class State {
   public:
    State(WhileOp* kernel, OpKernelContext* ctx, FHandle cond_handle,
@@ -408,11 +460,7 @@ class WhileOp : public AsyncOpKernel {

    void EvalCond() {
      profiler::TraceMe trace_me(
          [&] {
            return absl::StrCat(
                "WhileOp-EvalCond #parent_step_id=", ctx_->step_id(),
                ",function_step_id=", opts_.step_id, "#");
          },
          [&] { return EvalCondTraceString(ctx_, opts_); },
          /*level=*/2);
      lib_->Run(
          // Evaluate the condition.
@@ -434,46 +482,22 @@ class WhileOp : public AsyncOpKernel {
            rets_.size(), " tensors.");
        return Finish(s);
      }
      Tensor cond_t;
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
      const DeviceBase::GpuDeviceInfo* gpu_device_info =
          ctx_->device()->tensorflow_gpu_device_info();
      const bool is_hostmem_dtype =
          rets_[0].dtype() == DT_INT32 || rets_[0].dtype() == DT_INT64;
      if (!is_hostmem_dtype && gpu_device_info &&
          (opts_.rets_alloc_attrs.empty() ||
           !opts_.rets_alloc_attrs[0].on_host())) {
        // Copy the ret value to host if it's allocated on device.
        Device* device = down_cast<Device*>(ctx_->device());
        DeviceContext* device_ctx = ctx_->op_device_context();
        cond_t = Tensor(rets_[0].dtype(), rets_[0].shape());
        s = device_ctx->CopyDeviceTensorToCPUSync(&rets_[0], /*tensor_name=*/"",
                                                  device, &cond_t);
        if (!s.ok()) {
          return Finish(s);
        }
      } else {
        cond_t = rets_[0];
      }
#else
      cond_t = rets_[0];
#endif
      bool cond;
      s = ToBool({cond_t}, &cond);

      if (!s.ok()) {
        return Finish(s);
      }
      bool cond;
      s = CondResultToBool(ctx_, opts_, rets_[0], &cond);
      if (!s.ok()) {
        return Finish(s);
      }

      if (!cond) {
        return Finish(Status::OK());
      }
      rets_.clear();
      profiler::TraceMe trace_me(
          [&] {
            return absl::StrCat(
                "WhileOp-StartBody #parent_step_id=", ctx_->step_id(),
                ",function_step_id=", opts_.step_id, "#");
          },
          [&] { return StartBodyTraceString(ctx_, opts_); },
          /*level=*/2);
      lib_->Run(
          // Evaluate the body.
@@ -505,6 +529,68 @@ class WhileOp : public AsyncOpKernel {
    }
  };

  Status DoComputeSync(OpKernelContext* ctx) {
    FHandle cond_handle;
    FHandle body_handle;
    TF_RETURN_IF_ERROR(GetHandles(ctx, &cond_handle, &body_handle));
    auto lib = ctx->function_library();
    FunctionLibraryRuntime::Options opts;
    SetRunOptions(ctx, &opts, false /* always_collect_stats */);

    // Pre-allocate argument and return value vectors for the cond and body
    // functions.
    std::vector<Tensor> args;
    const int num_loop_vars = ctx->num_inputs();
    args.reserve(num_loop_vars);
    std::vector<Tensor> cond_rets;
    cond_rets.reserve(1);
    std::vector<Tensor> body_rets;
    body_rets.reserve(num_loop_vars);

    // The initial loop variable args are the inputs to the kernel.
    for (int i = 0; i < num_loop_vars; ++i) {
      args.push_back(ctx->input(i));
    }

    // Implement the logic of the while loop as a single C++ do-while loop that
    // executes the cond and body functions synchronously.
    do {
      // Evaluate the cond function on the current loop variables.
      {
        profiler::TraceMe trace_me(
            [&] { return EvalCondTraceString(ctx, opts); },
            /*level=*/2);
        TF_RETURN_IF_ERROR(lib->RunSync(opts, cond_handle, args, &cond_rets));
      }
      if (cond_rets.size() != 1) {
        return errors::InvalidArgument(
            "Expected a single scalar return value from WhileOp cond, got ",
            cond_rets.size(), " tensors.");
      }

      // If the cond function evaluates to false, we are done: output the
      // current loop variables.
      bool cond_result;
      TF_RETURN_IF_ERROR(
          CondResultToBool(ctx, opts, cond_rets[0], &cond_result));
      if (!cond_result) {
        return SetOutputs(this, ctx, args);
      }

      // Evaluate the body function on the current loop variables, to get an
      // updated vector of loop variables.
      {
        profiler::TraceMe trace_me(
            [&] { return StartBodyTraceString(ctx, opts); },
            /*level=*/2);

        TF_RETURN_IF_ERROR(lib->RunSync(opts, body_handle, args, &body_rets));
      }
      std::swap(body_rets, args);
      body_rets.clear();
    } while (true);
  }

  Status GetHandles(OpKernelContext* ctx, FHandle* cond_handle,
                    FHandle* body_handle) {
    // TODO(b/37549631): Because this op has `SetIsStateful()` in its

@@ -49,6 +49,7 @@ from tensorflow.python.ops import while_v2
from tensorflow.python.ops.while_v2 import while_loop as while_loop_v2
from tensorflow.python.platform import test


def random_gamma(shape):  # pylint: disable=invalid-name
  return random_ops.random_gamma(shape, 1.0)
@@ -1222,6 +1223,25 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
      self.assertNotIn("switch", ns.node_name)
    control_flow_util_v2._DISABLE_LOWER_USING_SWITCH_MERGE = old

  def _runBasicWithConfig(self, config):
    with ops.device("/cpu:0"):
      x = constant_op.constant(0)
      ret, = while_loop_v2(lambda x: x < 1000, lambda x: x + 1, [x])
    with self.cached_session(config=config):
      self.assertEqual(1000, self.evaluate(ret))

  @test_util.run_deprecated_v1
  def testRunKernelsInline(self):
    config = config_pb2.ConfigProto()
    config.inter_op_parallelism_threads = -1
    self._runBasicWithConfig(config)

  @test_util.run_deprecated_v1
  def testSingleThreadedExecution(self):
    config = config_pb2.ConfigProto()
    config.experimental.executor_type = "SINGLE_THREADED_EXECUTOR"
    self._runBasicWithConfig(config)


def ScalarShape():
  return ops.convert_to_tensor([], dtype=dtypes.int32)