[WhileV2] Fix potential stack overflow when kernels execute inline.

The previous implementation of the `WhileOp` used callbacks to handle the
results of the cond and body functions. While this is reasonable (and
necessary) when the kernels execute on the inter-op threadpool, it can lead to
stack overflow when inline execution is specified and the number of iterations
is large. This change adds a non-callback-based version of `WhileOp` that will
be dispatched to if either (i) the `OpKernelContext::run_all_kernels_inline()`
method returns true, or (ii) the `WhileOp::Compute()` method is invoked.

PiperOrigin-RevId: 309264480
Change-Id: I4697d6b080caca75a5c3f555063bea32a0de9987
parent f81dadf228
commit 8f82c20bb1
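The failure mode being fixed is easiest to see in isolation. The following
standalone sketch (illustrative stand-ins only, none of these names are
TensorFlow types) shows why a callback-per-iteration loop is safe when
callbacks are queued to a threadpool-style dispatcher but grows the stack when
the same callbacks run inline, and why a plain loop avoids the problem:

// Minimal model of callback-driven iteration; illustrative only.
#include <cstdio>
#include <functional>
#include <queue>
#include <utility>

struct Executor {
  bool run_inline = false;
  std::queue<std::function<void()>> ready;

  void Schedule(std::function<void()> fn) {
    if (run_inline) {
      fn();  // Inline: the callback runs inside Schedule() -- recursion.
    } else {
      ready.push(std::move(fn));  // Queued: unwound by Drain() -- flat stack.
    }
  }

  void Drain() {
    while (!ready.empty()) {
      auto fn = std::move(ready.front());
      ready.pop();
      fn();
    }
  }
};

// Callback-style while loop: one Schedule() per iteration. With run_inline
// set, the call chain is CallbackWhile -> Schedule -> CallbackWhile -> ...,
// so stack depth grows linearly with the iteration count.
void CallbackWhile(Executor* ex, int i, int n, std::function<void(int)> done) {
  if (i >= n) {
    done(i);  // "cond" is false: finish with the final loop variable.
    return;
  }
  ex->Schedule([=] { CallbackWhile(ex, i + 1, n, done); });  // "body"
}

// Non-callback alternative: a plain loop uses constant stack depth no matter
// how the kernel is dispatched.
int SyncWhile(int i, int n) {
  while (i < n) ++i;
  return i;
}

int main() {
  Executor ex;  // run_inline == false: safe for any iteration count.
  CallbackWhile(&ex, 0, 1000000, [](int v) { std::printf("async: %d\n", v); });
  ex.Drain();
  std::printf("sync: %d\n", SyncWhile(0, 1000000));
  // With ex.run_inline = true, the CallbackWhile call above would recurse a
  // million frames deep and likely crash; SyncWhile would not.
  return 0;
}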
--- a/tensorflow/core/kernels/functional_ops.cc
+++ b/tensorflow/core/kernels/functional_ops.cc
@@ -26,6 +26,7 @@ limitations under the License.
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/threadpool.h"
 #include "tensorflow/core/platform/casts.h"
+#include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/profiler/lib/traceme.h"
 
 namespace tensorflow {
@@ -357,14 +358,30 @@ class WhileOp : public AsyncOpKernel {
   ~WhileOp() override {}
 
   void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
-    auto lib = ctx->function_library();
-    OP_REQUIRES_ASYNC(ctx, lib != nullptr,
-                      errors::Internal("No function library"), done);
-    FHandle cond_handle;
-    FHandle body_handle;
-    OP_REQUIRES_OK_ASYNC(ctx, GetHandles(ctx, &cond_handle, &body_handle),
-                         done);
-    (new State(this, ctx, cond_handle, body_handle, done))->Start();
+    if (ctx->run_all_kernels_inline()) {
+      // Use the non-callback-based implementation when kernels (and function
+      // callbacks) execute inline to avoid stack overflow.
+      OP_REQUIRES_OK_ASYNC(ctx, DoComputeSync(ctx), done);
+    } else {
+      FHandle cond_handle;
+      FHandle body_handle;
+      OP_REQUIRES_OK_ASYNC(ctx, GetHandles(ctx, &cond_handle, &body_handle),
+                           done);
+      (new State(this, ctx, cond_handle, body_handle, done))->Start();
+    }
+  }
+
+  void Compute(OpKernelContext* ctx) override {
+    // Use the non-callback-based implementation when the synchronous Compute()
+    // method is invoked, because the caller is explicitly donating a thread.
+    Status s = DoComputeSync(ctx);
+    // NOTE: Unfortunately, we cannot use OP_REQUIRES_OK here, because this is
+    // still an AsyncOpKernel, and there is a run-time check to avoid calling
+    // OP_REQUIRES_OK in AsyncOpKernel::ComputeAsync() (which would deadlock in
+    // the event of an error).
+    if (TF_PREDICT_FALSE(!s.ok())) {
+      ctx->SetStatus(s);
+    }
   }
 
  private:
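For orientation, the dispatch introduced in this hunk reduces to the following
shape, sketched with stand-in types rather than the real kernel classes: both
the inline-async path and the synchronous path funnel into one non-recursive
implementation, while the callback-driven state machine survives only for
threadpool execution.

// Sketch of the dual dispatch above; FakeContext and SketchWhileOp are
// stand-ins, not TensorFlow classes.
#include <cstdio>
#include <functional>

struct FakeContext {
  bool run_all_kernels_inline = false;
  bool failed = false;
  void SetStatus(bool ok) { failed = !ok; }
};

class SketchWhileOp {
 public:
  // Async entry point: prefer the synchronous loop when callbacks would run
  // inline on this thread.
  void ComputeAsync(FakeContext* ctx, std::function<void()> done) {
    if (ctx->run_all_kernels_inline) {
      ctx->SetStatus(DoComputeSync(ctx));
    } else {
      StartCallbackStateMachine(ctx, done);  // threadpool-driven, elided
      return;  // `done` is invoked by the state machine when it finishes.
    }
    done();
  }

  // Sync entry point: the executor donated a thread, so just run the loop.
  void Compute(FakeContext* ctx) { ctx->SetStatus(DoComputeSync(ctx)); }

 private:
  bool DoComputeSync(FakeContext* ctx) {
    int x = 0;
    while (x < 1000) ++x;  // cond/body iterations, no recursion
    return true;
  }
  void StartCallbackStateMachine(FakeContext* ctx,
                                 std::function<void()> done) {
    done();  // placeholder for the real callback-driven State machine
  }
};

int main() {
  FakeContext ctx{/*run_all_kernels_inline=*/true};
  SketchWhileOp op;
  op.ComputeAsync(&ctx, [] { std::puts("done"); });
  return 0;
}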
@@ -375,6 +392,41 @@ class WhileOp : public AsyncOpKernel {
   std::unordered_map<FunctionLibraryRuntime*, std::pair<FHandle, FHandle>>
       handles_ GUARDED_BY(mu_);
 
+  static string EvalCondTraceString(
+      OpKernelContext* ctx, const FunctionLibraryRuntime::Options& opts) {
+    return absl::StrCat("WhileOp-EvalCond #parent_step_id=", ctx->step_id(),
+                        ",function_step_id=", opts.step_id, "#");
+  }
+
+  static string StartBodyTraceString(
+      OpKernelContext* ctx, const FunctionLibraryRuntime::Options& opts) {
+    return absl::StrCat("WhileOp-StartBody #parent_step_id=", ctx->step_id(),
+                        ",function_step_id=", opts.step_id, "#");
+  }
+
+  static Status CondResultToBool(OpKernelContext* ctx,
+                                 const FunctionLibraryRuntime::Options& opts,
+                                 const Tensor& cond_t, bool* out_result) {
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+    const DeviceBase::GpuDeviceInfo* gpu_device_info =
+        ctx->device()->tensorflow_gpu_device_info();
+    const bool is_hostmem_dtype =
+        cond_t.dtype() == DT_INT32 || cond_t.dtype() == DT_INT64;
+    if (!is_hostmem_dtype && gpu_device_info &&
+        (opts.rets_alloc_attrs.empty() ||
+         !opts.rets_alloc_attrs[0].on_host())) {
+      // Copy the ret value to host if it's allocated on device.
+      Device* device = down_cast<Device*>(ctx->device());
+      DeviceContext* device_ctx = ctx->op_device_context();
+      Tensor host_cond_t = Tensor(cond_t.dtype(), cond_t.shape());
+      TF_RETURN_IF_ERROR(device_ctx->CopyDeviceTensorToCPUSync(
+          &cond_t, /*tensor_name=*/"", device, &host_cond_t));
+      return ToBool({host_cond_t}, out_result);
+    }
+#endif
+    return ToBool({cond_t}, out_result);
+  }
+
   class State {
    public:
     State(WhileOp* kernel, OpKernelContext* ctx, FHandle cond_handle,
@@ -408,11 +460,7 @@ class WhileOp : public AsyncOpKernel {
 
     void EvalCond() {
       profiler::TraceMe trace_me(
-          [&] {
-            return absl::StrCat(
-                "WhileOp-EvalCond #parent_step_id=", ctx_->step_id(),
-                ",function_step_id=", opts_.step_id, "#");
-          },
+          [&] { return EvalCondTraceString(ctx_, opts_); },
           /*level=*/2);
       lib_->Run(
           // Evaluate the condition.
@@ -434,46 +482,22 @@ class WhileOp : public AsyncOpKernel {
             rets_.size(), " tensors.");
         return Finish(s);
       }
-      Tensor cond_t;
-#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-      const DeviceBase::GpuDeviceInfo* gpu_device_info =
-          ctx_->device()->tensorflow_gpu_device_info();
-      const bool is_hostmem_dtype =
-          rets_[0].dtype() == DT_INT32 || rets_[0].dtype() == DT_INT64;
-      if (!is_hostmem_dtype && gpu_device_info &&
-          (opts_.rets_alloc_attrs.empty() ||
-           !opts_.rets_alloc_attrs[0].on_host())) {
-        // Copy the ret value to host if it's allocated on device.
-        Device* device = down_cast<Device*>(ctx_->device());
-        DeviceContext* device_ctx = ctx_->op_device_context();
-        cond_t = Tensor(rets_[0].dtype(), rets_[0].shape());
-        s = device_ctx->CopyDeviceTensorToCPUSync(&rets_[0], /*tensor_name=*/"",
-                                                  device, &cond_t);
-        if (!s.ok()) {
-          return Finish(s);
-        }
-      } else {
-        cond_t = rets_[0];
-      }
-#else
-      cond_t = rets_[0];
-#endif
-      bool cond;
-      s = ToBool({cond_t}, &cond);
 
       if (!s.ok()) {
         return Finish(s);
       }
+      bool cond;
+      s = CondResultToBool(ctx_, opts_, rets_[0], &cond);
+      if (!s.ok()) {
+        return Finish(s);
+      }
+
       if (!cond) {
         return Finish(Status::OK());
       }
       rets_.clear();
       profiler::TraceMe trace_me(
-          [&] {
-            return absl::StrCat(
-                "WhileOp-StartBody #parent_step_id=", ctx_->step_id(),
-                ",function_step_id=", opts_.step_id, "#");
-          },
+          [&] { return StartBodyTraceString(ctx_, opts_); },
           /*level=*/2);
       lib_->Run(
           // Evaluate the body.
@@ -505,6 +529,68 @@ class WhileOp : public AsyncOpKernel {
     }
   };
 
+  Status DoComputeSync(OpKernelContext* ctx) {
+    FHandle cond_handle;
+    FHandle body_handle;
+    TF_RETURN_IF_ERROR(GetHandles(ctx, &cond_handle, &body_handle));
+    auto lib = ctx->function_library();
+    FunctionLibraryRuntime::Options opts;
+    SetRunOptions(ctx, &opts, false /* always_collect_stats */);
+
+    // Pre-allocate argument and return value vectors for the cond and body
+    // functions.
+    std::vector<Tensor> args;
+    const int num_loop_vars = ctx->num_inputs();
+    args.reserve(num_loop_vars);
+    std::vector<Tensor> cond_rets;
+    cond_rets.reserve(1);
+    std::vector<Tensor> body_rets;
+    body_rets.reserve(num_loop_vars);
+
+    // The initial loop variable args are the inputs to the kernel.
+    for (int i = 0; i < num_loop_vars; ++i) {
+      args.push_back(ctx->input(i));
+    }
+
+    // Implement the logic of the while loop as a single C++ do-while loop that
+    // executes the cond and body functions synchronously.
+    do {
+      // Evaluate the cond function on the current loop variables.
+      {
+        profiler::TraceMe trace_me(
+            [&] { return EvalCondTraceString(ctx, opts); },
+            /*level=*/2);
+        TF_RETURN_IF_ERROR(lib->RunSync(opts, cond_handle, args, &cond_rets));
+      }
+      if (cond_rets.size() != 1) {
+        return errors::InvalidArgument(
+            "Expected a single scalar return value from WhileOp cond, got ",
+            cond_rets.size(), " tensors.");
+      }
+
+      // If the cond function evaluates to false, we are done: output the
+      // current loop variables.
+      bool cond_result;
+      TF_RETURN_IF_ERROR(
+          CondResultToBool(ctx, opts, cond_rets[0], &cond_result));
+      if (!cond_result) {
+        return SetOutputs(this, ctx, args);
+      }
+
+      // Evaluate the body function on the current loop variables, to get an
+      // updated vector of loop variables.
+      {
+        profiler::TraceMe trace_me(
+            [&] { return StartBodyTraceString(ctx, opts); },
+            /*level=*/2);
+
+        TF_RETURN_IF_ERROR(lib->RunSync(opts, body_handle, args, &body_rets));
+      }
+      std::swap(body_rets, args);
+      body_rets.clear();
+    } while (true);
+  }
+
   Status GetHandles(OpKernelContext* ctx, FHandle* cond_handle,
                     FHandle* body_handle) {
     // TODO(b/37549631): Because this op has `SetIsStateful()` in its
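One detail of `DoComputeSync` worth noting is the `std::swap(body_rets, args)`
and `body_rets.clear()` pair at the bottom of the loop: the body's outputs
become the next iteration's arguments without copying element storage, and the
recycled vector keeps its capacity, so steady-state iterations allocate
nothing. A standalone illustration, with `int` standing in for `Tensor` and
the comparisons standing in for the cond/body functions:

// Standalone illustration of the swap-and-clear idiom; not TensorFlow code.
#include <cstdio>
#include <utility>
#include <vector>

int main() {
  std::vector<int> args = {0};  // current loop variables
  std::vector<int> body_rets;   // body outputs, reused every iteration
  body_rets.reserve(args.size());

  while (args[0] < 5) {                // "cond"
    body_rets.push_back(args[0] + 1);  // "body" writes its outputs
    std::swap(body_rets, args);        // outputs become the next args...
    body_rets.clear();  // ...and the old args buffer is recycled, keeping its
                        // capacity so no reallocation happens per iteration.
  }
  std::printf("final loop variable: %d\n", args[0]);  // prints 5
  return 0;
}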
--- a/tensorflow/python/kernel_tests/while_v2_test.py
+++ b/tensorflow/python/kernel_tests/while_v2_test.py
@@ -49,6 +49,7 @@ from tensorflow.python.ops import while_v2
 from tensorflow.python.ops.while_v2 import while_loop as while_loop_v2
 from tensorflow.python.platform import test
 
+
 
 def random_gamma(shape):  # pylint: disable=invalid-name
   return random_ops.random_gamma(shape, 1.0)
@@ -1222,6 +1223,25 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
       self.assertNotIn("switch", ns.node_name)
     control_flow_util_v2._DISABLE_LOWER_USING_SWITCH_MERGE = old
 
+  def _runBasicWithConfig(self, config):
+    with ops.device("/cpu:0"):
+      x = constant_op.constant(0)
+      ret, = while_loop_v2(lambda x: x < 1000, lambda x: x + 1, [x])
+    with self.cached_session(config=config):
+      self.assertEqual(1000, self.evaluate(ret))
+
+  @test_util.run_deprecated_v1
+  def testRunKernelsInline(self):
+    config = config_pb2.ConfigProto()
+    config.inter_op_parallelism_threads = -1
+    self._runBasicWithConfig(config)
+
+  @test_util.run_deprecated_v1
+  def testSingleThreadedExecution(self):
+    config = config_pb2.ConfigProto()
+    config.experimental.executor_type = "SINGLE_THREADED_EXECUTOR"
+    self._runBasicWithConfig(config)
+
 
 def ScalarShape():
   return ops.convert_to_tensor([], dtype=dtypes.int32)
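The two tests correspond to the two dispatch conditions named in the commit
message: a negative inter-op thread count makes the runtime run kernels inline
(condition (i)), and the single-threaded executor drives kernels through the
synchronous `Compute()` entry point (condition (ii)). As a rough C++
counterpart, here is a sketch of the same configuration via `SessionOptions`,
assuming the standard `ConfigProto` fields; the tests above do the equivalent
from Python:

// Sketch: the two session configurations exercised by the tests above,
// expressed through the C++ API (requires linking against TensorFlow).
#include "tensorflow/core/public/session_options.h"

int main() {
  // (i) A negative inter-op thread count asks the runtime to run kernels
  // inline on the caller's thread, so the kernel sees
  // OpKernelContext::run_all_kernels_inline() == true.
  tensorflow::SessionOptions inline_opts;
  inline_opts.config.set_inter_op_parallelism_threads(-1);

  // (ii) The single-threaded executor invokes kernels synchronously, which
  // reaches WhileOp::Compute().
  tensorflow::SessionOptions single_threaded_opts;
  single_threaded_opts.config.mutable_experimental()->set_executor_type(
      "SINGLE_THREADED_EXECUTOR");
  return 0;
}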