diff --git a/tensorflow/core/kernels/functional_ops.cc b/tensorflow/core/kernels/functional_ops.cc
index f6e6d8b1330..da8e92b5e74 100644
--- a/tensorflow/core/kernels/functional_ops.cc
+++ b/tensorflow/core/kernels/functional_ops.cc
@@ -26,6 +26,7 @@ limitations under the License.
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/threadpool.h"
 #include "tensorflow/core/platform/casts.h"
+#include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/profiler/lib/traceme.h"
 
 namespace tensorflow {
@@ -357,14 +358,30 @@ class WhileOp : public AsyncOpKernel {
   ~WhileOp() override {}
 
   void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
-    auto lib = ctx->function_library();
-    OP_REQUIRES_ASYNC(ctx, lib != nullptr,
-                      errors::Internal("No function library"), done);
-    FHandle cond_handle;
-    FHandle body_handle;
-    OP_REQUIRES_OK_ASYNC(ctx, GetHandles(ctx, &cond_handle, &body_handle),
-                         done);
-    (new State(this, ctx, cond_handle, body_handle, done))->Start();
+    if (ctx->run_all_kernels_inline()) {
+      // Use the non-callback-based implementation when kernels (and function
+      // callbacks) execute inline to avoid stack overflow.
+      OP_REQUIRES_OK_ASYNC(ctx, DoComputeSync(ctx), done);
+    } else {
+      FHandle cond_handle;
+      FHandle body_handle;
+      OP_REQUIRES_OK_ASYNC(ctx, GetHandles(ctx, &cond_handle, &body_handle),
+                           done);
+      (new State(this, ctx, cond_handle, body_handle, done))->Start();
+    }
+  }
+
+  void Compute(OpKernelContext* ctx) override {
+    // Use the non-callback-based implementation when the synchronous Compute()
+    // method is invoked, because the caller is explicitly donating a thread.
+    Status s = DoComputeSync(ctx);
+    // NOTE: Unfortunately, we cannot use OP_REQUIRES_OK here, because this is
+    // still an AsyncOpKernel, and there is a run-time check to avoid calling
+    // OP_REQUIRES_OK in AsyncOpKernel::ComputeAsync() (which would deadlock in
+    // the event of an error).
+    if (TF_PREDICT_FALSE(!s.ok())) {
+      ctx->SetStatus(s);
+    }
   }
 
  private:
@@ -375,6 +392,41 @@ class WhileOp : public AsyncOpKernel {
   std::unordered_map<string, std::pair<FHandle, FHandle>> handles_
       GUARDED_BY(mu_);
 
+  static string EvalCondTraceString(
+      OpKernelContext* ctx, const FunctionLibraryRuntime::Options& opts) {
+    return absl::StrCat("WhileOp-EvalCond #parent_step_id=", ctx->step_id(),
+                        ",function_step_id=", opts.step_id, "#");
+  }
+
+  static string StartBodyTraceString(
+      OpKernelContext* ctx, const FunctionLibraryRuntime::Options& opts) {
+    return absl::StrCat("WhileOp-StartBody #parent_step_id=", ctx->step_id(),
+                        ",function_step_id=", opts.step_id, "#");
+  }
+
+  static Status CondResultToBool(OpKernelContext* ctx,
+                                 const FunctionLibraryRuntime::Options& opts,
+                                 const Tensor& cond_t, bool* out_result) {
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+    const DeviceBase::GpuDeviceInfo* gpu_device_info =
+        ctx->device()->tensorflow_gpu_device_info();
+    const bool is_hostmem_dtype =
+        cond_t.dtype() == DT_INT32 || cond_t.dtype() == DT_INT64;
+    if (!is_hostmem_dtype && gpu_device_info &&
+        (opts.rets_alloc_attrs.empty() ||
+         !opts.rets_alloc_attrs[0].on_host())) {
+      // Copy the ret value to host if it's allocated on device.
+      Device* device = down_cast<Device*>(ctx->device());
+      DeviceContext* device_ctx = ctx->op_device_context();
+      Tensor host_cond_t = Tensor(cond_t.dtype(), cond_t.shape());
+      TF_RETURN_IF_ERROR(device_ctx->CopyDeviceTensorToCPUSync(
+          &cond_t, /*tensor_name=*/"", device, &host_cond_t));
+      return ToBool({host_cond_t}, out_result);
+    }
+#endif
+    return ToBool({cond_t}, out_result);
+  }
+
   class State {
    public:
     State(WhileOp* kernel, OpKernelContext* ctx, FHandle cond_handle,
@@ -408,11 +460,7 @@ class WhileOp : public AsyncOpKernel {
 
     void EvalCond() {
       profiler::TraceMe trace_me(
-          [&] {
-            return absl::StrCat(
-                "WhileOp-EvalCond #parent_step_id=", ctx_->step_id(),
-                ",function_step_id=", opts_.step_id, "#");
-          },
+          [&] { return EvalCondTraceString(ctx_, opts_); },
           /*level=*/2);
       lib_->Run(
           // Evaluate the condition.
@@ -434,46 +482,22 @@ class WhileOp : public AsyncOpKernel {
             rets_.size(), " tensors.");
         return Finish(s);
       }
-      Tensor cond_t;
-#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-      const DeviceBase::GpuDeviceInfo* gpu_device_info =
-          ctx_->device()->tensorflow_gpu_device_info();
-      const bool is_hostmem_dtype =
-          rets_[0].dtype() == DT_INT32 || rets_[0].dtype() == DT_INT64;
-      if (!is_hostmem_dtype && gpu_device_info &&
-          (opts_.rets_alloc_attrs.empty() ||
-           !opts_.rets_alloc_attrs[0].on_host())) {
-        // Copy the ret value to host if it's allocated on device.
-        Device* device = down_cast<Device*>(ctx_->device());
-        DeviceContext* device_ctx = ctx_->op_device_context();
-        cond_t = Tensor(rets_[0].dtype(), rets_[0].shape());
-        s = device_ctx->CopyDeviceTensorToCPUSync(&rets_[0], /*tensor_name=*/"",
-                                                  device, &cond_t);
-        if (!s.ok()) {
-          return Finish(s);
-        }
-      } else {
-        cond_t = rets_[0];
-      }
-#else
-      cond_t = rets_[0];
-#endif
-      bool cond;
-      s = ToBool({cond_t}, &cond);
 
       if (!s.ok()) {
         return Finish(s);
       }
+      bool cond;
+      s = CondResultToBool(ctx_, opts_, rets_[0], &cond);
+      if (!s.ok()) {
+        return Finish(s);
+      }
+
       if (!cond) {
         return Finish(Status::OK());
       }
       rets_.clear();
       profiler::TraceMe trace_me(
-          [&] {
-            return absl::StrCat(
-                "WhileOp-StartBody #parent_step_id=", ctx_->step_id(),
-                ",function_step_id=", opts_.step_id, "#");
-          },
+          [&] { return StartBodyTraceString(ctx_, opts_); },
           /*level=*/2);
       lib_->Run(
           // Evaluate the body.
@@ -505,6 +529,68 @@ class WhileOp : public AsyncOpKernel {
     }
   };
 
+  Status DoComputeSync(OpKernelContext* ctx) {
+    FHandle cond_handle;
+    FHandle body_handle;
+    TF_RETURN_IF_ERROR(GetHandles(ctx, &cond_handle, &body_handle));
+    auto lib = ctx->function_library();
+    FunctionLibraryRuntime::Options opts;
+    SetRunOptions(ctx, &opts, false /* always_collect_stats */);
+
+    // Pre-allocate argument and return value vectors for the cond and body
+    // functions.
+    std::vector<Tensor> args;
+    const int num_loop_vars = ctx->num_inputs();
+    args.reserve(num_loop_vars);
+    std::vector<Tensor> cond_rets;
+    cond_rets.reserve(1);
+    std::vector<Tensor> body_rets;
+    body_rets.reserve(num_loop_vars);
+
+    // The initial loop variable args are the inputs to the kernel.
+    for (int i = 0; i < num_loop_vars; ++i) {
+      args.push_back(ctx->input(i));
+    }
+
+    // Implement the logic of the while loop as a single C++ do-while loop that
+    // executes the cond and body functions synchronously.
+    do {
+      // Evaluate the cond function on the current loop variables.
+      {
+        profiler::TraceMe trace_me(
+            [&] { return EvalCondTraceString(ctx, opts); },
+            /*level=*/2);
+        TF_RETURN_IF_ERROR(lib->RunSync(opts, cond_handle, args, &cond_rets));
+      }
+      if (cond_rets.size() != 1) {
+        return errors::InvalidArgument(
+            "Expected a single scalar return value from WhileOp cond, got ",
+            cond_rets.size(), " tensors.");
+      }
+
+      // If the cond function evaluates to false, we are done: output the
+      // current loop variables.
+      bool cond_result;
+      TF_RETURN_IF_ERROR(
+          CondResultToBool(ctx, opts, cond_rets[0], &cond_result));
+      if (!cond_result) {
+        return SetOutputs(this, ctx, args);
+      }
+
+      // Evaluate the body function on the current loop variables, to get an
+      // updated vector of loop variables.
+      {
+        profiler::TraceMe trace_me(
+            [&] { return StartBodyTraceString(ctx, opts); },
+            /*level=*/2);
+
+        TF_RETURN_IF_ERROR(lib->RunSync(opts, body_handle, args, &body_rets));
+      }
+      std::swap(body_rets, args);
+      body_rets.clear();
+    } while (true);
+  }
+
   Status GetHandles(OpKernelContext* ctx, FHandle* cond_handle,
                     FHandle* body_handle) {
     // TODO(b/37549631): Because this op has `SetIsStateful()` in its
diff --git a/tensorflow/python/kernel_tests/while_v2_test.py b/tensorflow/python/kernel_tests/while_v2_test.py
index b1e5957599e..95bbea156f2 100644
--- a/tensorflow/python/kernel_tests/while_v2_test.py
+++ b/tensorflow/python/kernel_tests/while_v2_test.py
@@ -49,6 +49,7 @@ from tensorflow.python.ops import while_v2
 from tensorflow.python.ops.while_v2 import while_loop as while_loop_v2
 from tensorflow.python.platform import test
 
+
 def random_gamma(shape):  # pylint: disable=invalid-name
   return random_ops.random_gamma(shape, 1.0)
 
@@ -1222,6 +1223,25 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
       self.assertNotIn("switch", ns.node_name)
     control_flow_util_v2._DISABLE_LOWER_USING_SWITCH_MERGE = old
 
+  def _runBasicWithConfig(self, config):
+    with ops.device("/cpu:0"):
+      x = constant_op.constant(0)
+      ret, = while_loop_v2(lambda x: x < 1000, lambda x: x + 1, [x])
+    with self.cached_session(config=config):
+      self.assertEqual(1000, self.evaluate(ret))
+
+  @test_util.run_deprecated_v1
+  def testRunKernelsInline(self):
+    config = config_pb2.ConfigProto()
+    config.inter_op_parallelism_threads = -1
+    self._runBasicWithConfig(config)
+
+  @test_util.run_deprecated_v1
+  def testSingleThreadedExecution(self):
+    config = config_pb2.ConfigProto()
+    config.experimental.executor_type = "SINGLE_THREADED_EXECUTOR"
+    self._runBasicWithConfig(config)
+
 
 def ScalarShape():
   return ops.convert_to_tensor([], dtype=dtypes.int32)
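For context, a minimal standalone sketch (not part of the change) of how a client can opt into the two session configurations exercised by the new tests. The tf.compat.v1 entry points used here are standard TensorFlow APIs, the loop bound is illustrative, and whether the functional `While` kernel (and therefore the new synchronous path) is actually executed also depends on control-flow-v2 lowering, so treat this as an illustration rather than a guaranteed reproduction:

    import tensorflow.compat.v1 as tf

    tf.disable_eager_execution()
    tf.enable_control_flow_v2()  # Build tf.while_loop as a functional "While" op.

    def run_loop(config):
      i = tf.constant(0)
      result = tf.while_loop(lambda i: i < 1000, lambda i: i + 1, [i])
      with tf.Session(config=config) as sess:
        return sess.run(result)

    # A negative inter_op_parallelism_threads value runs all kernels inline on
    # the calling thread (the run_all_kernels_inline() case in the kernel).
    inline_config = tf.ConfigProto(inter_op_parallelism_threads=-1)
    print(run_loop(inline_config))  # 1000

    # The single-threaded executor likewise donates the calling thread.
    st_config = tf.ConfigProto()
    st_config.experimental.executor_type = "SINGLE_THREADED_EXECUTOR"
    print(run_loop(st_config))  # 1000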