From 98ab062a8b931f226cbaa337809cf5e9fd9e6f20 Mon Sep 17 00:00:00 2001
From: Brian Zhao
Date: Wed, 20 Nov 2019 14:51:25 -0800
Subject: [PATCH] Automated g4 rollback of changelist 281562211.

PiperOrigin-RevId: 281607574
Change-Id: I7d06715ba20f527db4c2f5bea61a51e04b7816a1
---
 .../core/common_runtime/direct_session.cc    | 208 ++++++++----------
 tensorflow/core/common_runtime/executor.h    |   2 +-
 .../kernels/data/single_threaded_executor.cc |  20 +-
 3 files changed, 108 insertions(+), 122 deletions(-)

diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index 30a1e2a9ea4..001ee7d1f19 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -497,7 +497,6 @@ Status DirectSession::RunInternal(
   const uint64 start_time_usecs = options_.env->NowMicros();
   const int64 executor_step_count = executors_and_keys->step_count.fetch_add(1);
   RunState run_state(step_id, &devices_);
-  const size_t num_executors = executors_and_keys->items.size();
 
   profiler::TraceMe activity(
       [&] {
@@ -554,64 +553,21 @@ Status DirectSession::RunInternal(
   }
 #endif
 
-  // Use std::unique_ptr to ensure garbage collection
-  std::unique_ptr<thread::ThreadPool> threadpool_wrapper;
-  thread::ThreadPool* pool = nullptr;
-
-  if (run_in_caller_thread_) {
-    pool = nullptr;
-  } else if (threadpool_options.inter_op_threadpool != nullptr) {
-    threadpool_wrapper = absl::make_unique<thread::ThreadPool>(
-        threadpool_options.inter_op_threadpool);
-    pool = threadpool_wrapper.get();
-  } else if (run_options.inter_op_thread_pool() >= 0) {
-    pool = thread_pools_[run_options.inter_op_thread_pool()].first;
-  }
-
-  if (pool == nullptr) {
-    // We allow using the caller thread only when having a single executor
-    // specified.
-    if (executors_and_keys->items.size() > 1) {
-      pool = thread_pools_[0].first;
-    } else {
-      VLOG(1) << "Executing Session::Run() synchronously!";
-    }
-  }
-
-  if (run_options.inter_op_thread_pool() < -1 ||
-      run_options.inter_op_thread_pool() >=
-          static_cast<int32>(thread_pools_.size())) {
-    return errors::InvalidArgument("Invalid inter_op_thread_pool: ",
-                                   run_options.inter_op_thread_pool());
-  }
-
-  std::unique_ptr<RunHandler> handler;
-  if (ShouldUseRunHandlerPool(run_options) &&
-      run_options.experimental().use_run_handler_pool()) {
-    VLOG(1) << "Using RunHandler to scheduler inter-op closures.";
-    handler = GetOrCreateRunHandlerPool(options_)->Get(step_id);
-  }
-  auto* handler_ptr = handler.get();
-
-  Executor::Args::Runner default_runner = nullptr;
-
-  if (pool == nullptr) {
-    default_runner = [](Executor::Args::Closure c) { c(); };
-  } else if (handler_ptr != nullptr) {
-    default_runner = [handler_ptr](Executor::Args::Closure c) {
-      handler_ptr->ScheduleInterOpClosure(std::move(c));
-    };
-  } else {
-    default_runner = [pool](Executor::Args::Closure c) {
-      pool->Schedule(std::move(c));
-    };
-  }
-
   // Start parallel Executors.
-  const int64 call_timeout = run_options.timeout_in_ms() > 0
-                                 ? run_options.timeout_in_ms()
-                                 : operation_timeout_in_ms_;
-  const bool can_execute_synchronously = pool == nullptr && call_timeout == 0;
+  const size_t num_executors = executors_and_keys->items.size();
+  Notification executors_done;
+
+  // TODO(mrry): Switch the RunInternal() synchronous use of ExecutorBarrier
+  // to use a stack-allocated barrier.
+  ExecutorBarrier* barrier =
+      new ExecutorBarrier(num_executors, run_state.rendez.get(),
+                          [&run_state, &executors_done](const Status& ret) {
+                            {
+                              mutex_lock l(run_state.mu);
+                              run_state.status.Update(ret);
+                            }
+                            executors_done.Notify();
+                          });
 
   Executor::Args args;
   args.step_id = step_id;
@@ -655,6 +611,14 @@ Status DirectSession::RunInternal(
     profiler_session = ProfilerSession::Create();
   }
 
+  if (run_options.inter_op_thread_pool() < -1 ||
+      run_options.inter_op_thread_pool() >=
+          static_cast<int32>(thread_pools_.size())) {
+    delete barrier;
+    return errors::InvalidArgument("Invalid inter_op_thread_pool: ",
+                                   run_options.inter_op_thread_pool());
+  }
+
   // Register this step with session's cancellation manager, so that
   // `Session::Close()` will cancel the step.
   const CancellationToken cancellation_token =
@@ -664,76 +628,98 @@ Status DirectSession::RunInternal(
         step_cancellation_manager.StartCancel();
       });
   if (already_cancelled) {
+    delete barrier;
     return errors::Cancelled("Run call was cancelled");
   }
 
-  Status run_status;
+  // Use std::unique_ptr to ensure garbage collection
+  std::unique_ptr<thread::ThreadPool> threadpool_wrapper;
+  thread::ThreadPool* pool = nullptr;
 
-  auto set_threadpool_args_for_item =
-      [&default_runner, &handler](const PerPartitionExecutorsAndLib& item,
-                                  Executor::Args* args) {
-        // TODO(azaks): support partial run.
-        // TODO(azaks): if the device picks its own threadpool, we need to
-        // assign
-        // less threads to the main compute pool by default.
-        thread::ThreadPool* device_thread_pool =
-            item.device->tensorflow_device_thread_pool();
-        // TODO(crk): Investigate usage of RunHandlerPool when using device
-        // specific thread pool(s).
-        if (!device_thread_pool) {
-          args->runner = default_runner;
-        } else {
-          args->runner = [device_thread_pool](Executor::Args::Closure c) {
-            device_thread_pool->Schedule(std::move(c));
-          };
-        }
-        if (handler != nullptr) {
-          args->user_intra_op_threadpool =
-              handler->AsIntraThreadPoolInterface();
-        }
-      };
+  if (run_in_caller_thread_) {
+    pool = nullptr;
+  } else if (threadpool_options.inter_op_threadpool != nullptr) {
+    threadpool_wrapper = absl::make_unique<thread::ThreadPool>(
+        threadpool_options.inter_op_threadpool);
+    pool = threadpool_wrapper.get();
+  } else if (run_options.inter_op_thread_pool() >= 0) {
+    pool = thread_pools_[run_options.inter_op_thread_pool()].first;
+  }
 
-  if (can_execute_synchronously) {
-    const auto& item = executors_and_keys->items[0];
-    set_threadpool_args_for_item(item, &args);
-    run_status = item.executor->Run(args);
-  } else {
-    // `barrier` will delete itself after the final executor finishes.
-    Notification executors_done;
-    ExecutorBarrier* barrier =
-        new ExecutorBarrier(num_executors, run_state.rendez.get(),
-                            [&run_state, &executors_done](const Status& ret) {
-                              {
-                                mutex_lock l(run_state.mu);
-                                run_state.status.Update(ret);
-                              }
-                              executors_done.Notify();
-                            });
-
-    for (const auto& item : executors_and_keys->items) {
-      set_threadpool_args_for_item(item, &args);
-      item.executor->RunAsync(args, barrier->Get());
-    }
-
-    WaitForNotification(&executors_done, &run_state, &step_cancellation_manager,
-                        call_timeout);
-    {
-      tf_shared_lock l(run_state.mu);
-      run_status = run_state.status;
+  if (pool == nullptr) {
+    // We allow using the caller thread only when having a single executor
+    // specified.
+    if (executors_and_keys->items.size() > 1) {
+      pool = thread_pools_[0].first;
+    } else {
+      VLOG(1) << "Executing Session::Run() synchronously!";
     }
   }
 
+  std::unique_ptr<RunHandler> handler;
+  if (ShouldUseRunHandlerPool(run_options) &&
+      run_options.experimental().use_run_handler_pool()) {
+    VLOG(1) << "Using RunHandler to scheduler inter-op closures.";
+    handler = GetOrCreateRunHandlerPool(options_)->Get(step_id);
+  }
+  auto* handler_ptr = handler.get();
+
+  Executor::Args::Runner default_runner = nullptr;
+
+  if (pool == nullptr) {
+    default_runner = [](Executor::Args::Closure c) { c(); };
+  } else if (handler_ptr != nullptr) {
+    default_runner = [handler_ptr](Executor::Args::Closure c) {
+      handler_ptr->ScheduleInterOpClosure(std::move(c));
+    };
+  } else {
+    default_runner = [this, pool](Executor::Args::Closure c) {
+      pool->Schedule(std::move(c));
+    };
+  }
+
+  for (const auto& item : executors_and_keys->items) {
+    // TODO(azaks): support partial run.
+    // TODO(azaks): if the device picks its own threadpool, we need to assign
+    // less threads to the main compute pool by default.
+    thread::ThreadPool* device_thread_pool =
+        item.device->tensorflow_device_thread_pool();
+    // TODO(crk): Investigate usage of RunHandlerPool when using device specific
+    // thread pool(s).
+    if (!device_thread_pool) {
+      args.runner = default_runner;
+    } else {
+      args.runner = [this, device_thread_pool](Executor::Args::Closure c) {
+        device_thread_pool->Schedule(std::move(c));
+      };
+    }
+    if (handler != nullptr) {
+      args.user_intra_op_threadpool = handler->AsIntraThreadPoolInterface();
+    }
+
+    item.executor->RunAsync(args, barrier->Get());
+  }
+
+  WaitForNotification(&executors_done, &run_state, &step_cancellation_manager,
+                      run_options.timeout_in_ms() > 0
+                          ? run_options.timeout_in_ms()
+                          : operation_timeout_in_ms_);
+
   if (!cancellation_manager_->DeregisterCallback(cancellation_token)) {
     // The step has been cancelled: make sure we don't attempt to receive the
     // outputs as this would make it block forever.
-    run_status.Update(errors::Cancelled("Run call was cancelled"));
+    mutex_lock l(run_state.mu);
+    run_state.status.Update(errors::Cancelled("Run call was cancelled"));
   }
 
   if (profiler_session) {
     TF_RETURN_IF_ERROR(profiler_session->CollectData(run_metadata));
   }
 
-  TF_RETURN_IF_ERROR(run_status);
+  {
+    mutex_lock l(run_state.mu);
+    TF_RETURN_IF_ERROR(run_state.status);
+  }
 
   // Save the output tensors of this run we choose to keep.
   if (!run_state.tensor_store.empty()) {
diff --git a/tensorflow/core/common_runtime/executor.h b/tensorflow/core/common_runtime/executor.h
index 42d5b9eab4f..ad85e712d91 100644
--- a/tensorflow/core/common_runtime/executor.h
+++ b/tensorflow/core/common_runtime/executor.h
@@ -111,7 +111,7 @@ class Executor {
   virtual void RunAsync(const Args& args, DoneCallback done) = 0;
 
   // Synchronous wrapper for RunAsync().
-  virtual Status Run(const Args& args) {
+  Status Run(const Args& args) {
     Status ret;
     Notification n;
     RunAsync(args, [&ret, &n](const Status& s) {
diff --git a/tensorflow/core/kernels/data/single_threaded_executor.cc b/tensorflow/core/kernels/data/single_threaded_executor.cc
index a6b31679fa6..cd47da0e653 100644
--- a/tensorflow/core/kernels/data/single_threaded_executor.cc
+++ b/tensorflow/core/kernels/data/single_threaded_executor.cc
@@ -195,7 +195,10 @@ class SingleThreadedExecutorImpl : public Executor {
     return Status::OK();
   }
 
-  Status Run(const Args& args) override {
+  // TODO(mrry): Consider specializing the implementation of Executor::Run()
+  // instead, to avoid unnecessary atomic operations in the callback when
+  // running synchronously.
+  void RunAsync(const Args& args, DoneCallback done) override {
     // The inputs to each kernel are stored contiguously in `inputs`.
     //
     // We use `kernels_[i].input_start_index` and `kernels_[i].num_inputs` to
@@ -272,9 +275,9 @@ class SingleThreadedExecutorImpl : public Executor {
     const size_t received_args =
         args.call_frame ? args.call_frame->num_args() : 0;
     if (arg_output_locations_.size() > received_args) {
-      return errors::InvalidArgument("Expected ", arg_output_locations_.size(),
-                                     " arguments, but only received ",
-                                     received_args, ".");
+      done(errors::InvalidArgument("Expected ", arg_output_locations_.size(),
+                                   " arguments, but only received ",
+                                   received_args, "."));
     }
 
     // ArgOp is a relatively expensive OpKernel due to the Tensor
@@ -348,7 +351,8 @@ class SingleThreadedExecutorImpl : public Executor {
         }
       }
     }
-    return ctx.status();
+    done(ctx.status());
+    return;
   }
 
   // Free the inputs to the current kernel.
@@ -375,11 +379,7 @@ class SingleThreadedExecutorImpl : public Executor {
         delete val.tensor;
       }
     }
-    return Status::OK();
-  }
-
-  void RunAsync(const Args& args, DoneCallback done) override {
-    done(Run(args));
+    done(Status::OK());
   }
 
  private:
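
For context on the executor.h hunk above: the rollback removes `virtual` from Run(), restoring it as a plain synchronous wrapper that drives the virtual RunAsync() entry point and blocks on a Notification until the done callback fires. Below is a minimal, self-contained sketch of that pattern; the Status and Notification types are simplified stand-ins for TensorFlow's, and SimpleExecutor is a hypothetical implementation used only to exercise the wrapper.

#include <condition_variable>
#include <functional>
#include <iostream>
#include <mutex>
#include <string>

// Simplified stand-in for tensorflow::Status.
struct Status {
  std::string error;
  bool ok() const { return error.empty(); }
};

// Simplified stand-in for tensorflow::Notification: a one-shot event.
class Notification {
 public:
  void Notify() {
    std::lock_guard<std::mutex> l(mu_);
    notified_ = true;
    cv_.notify_all();
  }
  void WaitForNotification() {
    std::unique_lock<std::mutex> l(mu_);
    cv_.wait(l, [this] { return notified_; });
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  bool notified_ = false;
};

class Executor {
 public:
  using DoneCallback = std::function<void(const Status&)>;
  virtual ~Executor() = default;

  // Asynchronous entry point: implementations must eventually invoke `done`
  // exactly once with the final status.
  virtual void RunAsync(DoneCallback done) = 0;

  // Synchronous wrapper for RunAsync(), as reinstated by this patch:
  // non-virtual, blocks until the done callback fires, then returns the
  // status that the callback delivered.
  Status Run() {
    Status ret;
    Notification n;
    RunAsync([&ret, &n](const Status& s) {
      ret = s;
      n.Notify();
    });
    n.WaitForNotification();
    return ret;
  }
};

// Hypothetical executor that completes immediately, for illustration only.
class SimpleExecutor : public Executor {
 public:
  void RunAsync(DoneCallback done) override { done(Status{}); }
};

int main() {
  SimpleExecutor executor;
  std::cout << (executor.Run().ok() ? "ok" : "error") << std::endl;
  return 0;
}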
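
The direct_session.cc hunks reinstate the unconditional heap-allocated ExecutorBarrier: every executor's RunAsync() is handed barrier->Get(), and the barrier merges their statuses and deletes itself after the last one completes. This is also why the restored early-exit paths (the inter_op_thread_pool validation and the already_cancelled check) must `delete barrier` explicitly: if no executor ever runs, nothing else will free it. The following is a simplified sketch of the self-deleting barrier idiom, reusing the simplified Status above; the rendezvous argument and status-grouping details of the real class are omitted.

#include <atomic>
#include <functional>
#include <mutex>
#include <string>
#include <utility>

// Simplified stand-in for tensorflow::Status, as in the previous sketch.
struct Status {
  std::string error;
  bool ok() const { return error.empty(); }
};

using StatusCallback = std::function<void(const Status&)>;

// Simplified take on ExecutorBarrier: heap-allocated with the number of
// pending executors and a done callback. Get() hands out the completion
// callback passed to each RunAsync(); the last completion reports the
// merged status and frees the barrier.
class ExecutorBarrier {
 public:
  ExecutorBarrier(int num_executors, StatusCallback done)
      : pending_(num_executors), done_(std::move(done)) {}

  StatusCallback Get() {
    return [this](const Status& s) { WhenDone(s); };
  }

 private:
  void WhenDone(const Status& s) {
    {
      std::lock_guard<std::mutex> l(mu_);
      // Keep the first error seen, mirroring Status::Update().
      if (status_.ok() && !s.ok()) status_ = s;
    }
    // fetch_sub returns the previous value, so 1 means we are the last.
    if (pending_.fetch_sub(1) == 1) {
      StatusCallback done = std::move(done_);
      Status status = status_;
      delete this;  // Safe: no other executor can touch the barrier now.
      done(status);
    }
  }

  std::atomic<int> pending_;
  std::mutex mu_;
  Status status_;
  StatusCallback done_;
};

In RunInternal() above, the done callback records the merged status into run_state.status under run_state.mu and then notifies executors_done, which is the Notification that WaitForNotification() blocks on.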
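
Finally, the restored scheduling logic in direct_session.cc reduces to picking a default_runner: run closures inline when no inter-op pool applies (run_in_caller_thread_, or a single executor with no explicit pool), hand them to a RunHandler when the run-handler pool is enabled, and otherwise schedule them on the selected thread pool. Below is a minimal sketch of the inline-versus-pool dispatch only, with std::function in place of Executor::Args::Runner and a hypothetical WorkerPool in place of thread::ThreadPool; the RunHandler branch is omitted.

#include <functional>
#include <iostream>
#include <thread>
#include <utility>

using Closure = std::function<void()>;
using Runner = std::function<void(Closure)>;

// Hypothetical stand-in for thread::ThreadPool: a toy pool that spawns a
// detached thread per closure instead of reusing worker threads.
class WorkerPool {
 public:
  void Schedule(Closure c) { std::thread(std::move(c)).detach(); }
};

// Mirrors the default_runner selection in RunInternal(): with no pool the
// closure executes inline on the caller's thread (the VLOG'd "Executing
// Session::Run() synchronously!" case); otherwise it is scheduled on the
// chosen inter-op pool.
Runner MakeDefaultRunner(WorkerPool* pool) {
  if (pool == nullptr) {
    return [](Closure c) { c(); };
  }
  return [pool](Closure c) { pool->Schedule(std::move(c)); };
}

int main() {
  Runner inline_runner = MakeDefaultRunner(nullptr);
  inline_runner([] { std::cout << "ran inline\n"; });
  return 0;
}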