Automated g4 rollback of changelist 281562211.
PiperOrigin-RevId: 281607574
Change-Id: I7d06715ba20f527db4c2f5bea61a51e04b7816a1

Parent: 7286a69c3c
Commit: 98ab062a8b
@@ -497,7 +497,6 @@ Status DirectSession::RunInternal(
   const uint64 start_time_usecs = options_.env->NowMicros();
   const int64 executor_step_count = executors_and_keys->step_count.fetch_add(1);
   RunState run_state(step_id, &devices_);
-  const size_t num_executors = executors_and_keys->items.size();

   profiler::TraceMe activity(
       [&] {
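
The surviving lines above are the per-step bookkeeping: a wall-clock start time and a step count taken from an atomic counter with fetch_add, so concurrent Run() calls each get a distinct id. A minimal sketch of that counter pattern in standard C++ (the struct and function names here are illustrative, not TensorFlow's):

#include <atomic>
#include <cstdint>

// Hypothetical stand-in for the step_count field on executors_and_keys.
struct StepState {
  std::atomic<std::int64_t> step_count{0};
};

// fetch_add returns the value *before* the increment, so the first call
// observes 0 and concurrent callers never see the same step count.
std::int64_t NextExecutorStepCount(StepState* state) {
  return state->step_count.fetch_add(1);
}
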
@@ -554,64 +553,21 @@ Status DirectSession::RunInternal(
   }
 #endif

-  // Use std::unique_ptr to ensure garbage collection
-  std::unique_ptr<thread::ThreadPool> threadpool_wrapper;
-  thread::ThreadPool* pool = nullptr;
-
-  if (run_in_caller_thread_) {
-    pool = nullptr;
-  } else if (threadpool_options.inter_op_threadpool != nullptr) {
-    threadpool_wrapper = absl::make_unique<thread::ThreadPool>(
-        threadpool_options.inter_op_threadpool);
-    pool = threadpool_wrapper.get();
-  } else if (run_options.inter_op_thread_pool() >= 0) {
-    pool = thread_pools_[run_options.inter_op_thread_pool()].first;
-  }
-
-  if (pool == nullptr) {
-    // We allow using the caller thread only when having a single executor
-    // specified.
-    if (executors_and_keys->items.size() > 1) {
-      pool = thread_pools_[0].first;
-    } else {
-      VLOG(1) << "Executing Session::Run() synchronously!";
-    }
-  }
-
-  if (run_options.inter_op_thread_pool() < -1 ||
-      run_options.inter_op_thread_pool() >=
-          static_cast<int32>(thread_pools_.size())) {
-    return errors::InvalidArgument("Invalid inter_op_thread_pool: ",
-                                   run_options.inter_op_thread_pool());
-  }
-
-  std::unique_ptr<RunHandler> handler;
-  if (ShouldUseRunHandlerPool(run_options) &&
-      run_options.experimental().use_run_handler_pool()) {
-    VLOG(1) << "Using RunHandler to schedule inter-op closures.";
-    handler = GetOrCreateRunHandlerPool(options_)->Get(step_id);
-  }
-  auto* handler_ptr = handler.get();
-
-  Executor::Args::Runner default_runner = nullptr;
-
-  if (pool == nullptr) {
-    default_runner = [](Executor::Args::Closure c) { c(); };
-  } else if (handler_ptr != nullptr) {
-    default_runner = [handler_ptr](Executor::Args::Closure c) {
-      handler_ptr->ScheduleInterOpClosure(std::move(c));
-    };
-  } else {
-    default_runner = [pool](Executor::Args::Closure c) {
-      pool->Schedule(std::move(c));
-    };
-  }
-
   // Start parallel Executors.
-  const int64 call_timeout = run_options.timeout_in_ms() > 0
-                                 ? run_options.timeout_in_ms()
-                                 : operation_timeout_in_ms_;
-  const bool can_execute_synchronously = pool == nullptr && call_timeout == 0;
+  const size_t num_executors = executors_and_keys->items.size();
+  Notification executors_done;
+
+  // TODO(mrry): Switch the RunInternal() synchronous use of ExecutorBarrier
+  // to use a stack-allocated barrier.
+  ExecutorBarrier* barrier =
+      new ExecutorBarrier(num_executors, run_state.rendez.get(),
+                          [&run_state, &executors_done](const Status& ret) {
+                            {
+                              mutex_lock l(run_state.mu);
+                              run_state.status.Update(ret);
+                            }
+                            executors_done.Notify();
+                          });

   Executor::Args args;
   args.step_id = step_id;
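
The restored lines recreate the heap-allocated ExecutorBarrier: every executor reports into one callback, which merges the status under run_state.mu and, on the last arrival, fires the Notification the caller blocks on. A self-contained sketch of that fan-in shape in standard C++ (the simplified Status and CompletionBarrier are stand-ins, not TensorFlow's classes):

#include <functional>
#include <mutex>
#include <string>
#include <utility>

struct Status {  // simplified stand-in for tensorflow::Status
  std::string msg;
  bool ok() const { return msg.empty(); }
  void Update(const Status& s) {
    if (ok() && !s.ok()) msg = s.msg;  // keep the first error
  }
};

// Counts down once per executor and invokes `done` exactly once, from the
// thread of the last executor to finish, with the merged status.
class CompletionBarrier {
 public:
  CompletionBarrier(int pending, std::function<void(const Status&)> done)
      : pending_(pending), done_(std::move(done)) {}

  void ExecutorDone(const Status& s) {
    bool last = false;
    Status merged;
    {
      std::lock_guard<std::mutex> l(mu_);
      status_.Update(s);
      last = (--pending_ == 0);
      if (last) merged = status_;
    }
    if (last) done_(merged);  // run the callback without holding the lock
  }

 private:
  std::mutex mu_;
  int pending_;
  Status status_;
  std::function<void(const Status&)> done_;
};

The TODO(mrry) in the hunk points at the remaining wart: the real barrier is allocated with new and deletes itself after the last executor reports, which is why the error paths below have to remember to delete barrier by hand.
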
@@ -655,6 +611,14 @@ Status DirectSession::RunInternal(
     profiler_session = ProfilerSession::Create();
   }

+  if (run_options.inter_op_thread_pool() < -1 ||
+      run_options.inter_op_thread_pool() >=
+          static_cast<int32>(thread_pools_.size())) {
+    delete barrier;
+    return errors::InvalidArgument("Invalid inter_op_thread_pool: ",
+                                   run_options.inter_op_thread_pool());
+  }
+
   // Register this step with session's cancellation manager, so that
   // `Session::Close()` will cancel the step.
   const CancellationToken cancellation_token =
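
The restored range check accepts -1 (meaning "run in the caller's thread") or any valid index into thread_pools_, and must delete barrier manually because the barrier is already allocated by this point. The bounds test in isolation, with illustrative names (this helper does not exist in TensorFlow):

#include <cstddef>
#include <cstdint>

// -1 selects the caller's thread; 0..num_pools-1 select a configured pool.
bool IsValidInterOpPoolIndex(std::int32_t requested, std::size_t num_pools) {
  return requested >= -1 &&
         requested < static_cast<std::int32_t>(num_pools);
}
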
@@ -664,76 +628,98 @@ Status DirectSession::RunInternal(
         step_cancellation_manager.StartCancel();
       });
   if (already_cancelled) {
+    delete barrier;
     return errors::Cancelled("Run call was cancelled");
   }

-  Status run_status;
+  // Use std::unique_ptr to ensure garbage collection
+  std::unique_ptr<thread::ThreadPool> threadpool_wrapper;
+  thread::ThreadPool* pool = nullptr;

-  auto set_threadpool_args_for_item =
-      [&default_runner, &handler](const PerPartitionExecutorsAndLib& item,
-                                  Executor::Args* args) {
-        // TODO(azaks): support partial run.
-        // TODO(azaks): if the device picks its own threadpool, we need to
-        // assign
-        // less threads to the main compute pool by default.
-        thread::ThreadPool* device_thread_pool =
-            item.device->tensorflow_device_thread_pool();
-        // TODO(crk): Investigate usage of RunHandlerPool when using device
-        // specific thread pool(s).
-        if (!device_thread_pool) {
-          args->runner = default_runner;
-        } else {
-          args->runner = [device_thread_pool](Executor::Args::Closure c) {
-            device_thread_pool->Schedule(std::move(c));
-          };
-        }
-        if (handler != nullptr) {
-          args->user_intra_op_threadpool =
-              handler->AsIntraThreadPoolInterface();
-        }
-      };
+  if (run_in_caller_thread_) {
+    pool = nullptr;
+  } else if (threadpool_options.inter_op_threadpool != nullptr) {
+    threadpool_wrapper = absl::make_unique<thread::ThreadPool>(
+        threadpool_options.inter_op_threadpool);
+    pool = threadpool_wrapper.get();
+  } else if (run_options.inter_op_thread_pool() >= 0) {
+    pool = thread_pools_[run_options.inter_op_thread_pool()].first;
+  }

-  if (can_execute_synchronously) {
-    const auto& item = executors_and_keys->items[0];
-    set_threadpool_args_for_item(item, &args);
-    run_status = item.executor->Run(args);
-  } else {
-    // `barrier` will delete itself after the final executor finishes.
-    Notification executors_done;
-    ExecutorBarrier* barrier =
-        new ExecutorBarrier(num_executors, run_state.rendez.get(),
-                            [&run_state, &executors_done](const Status& ret) {
-                              {
-                                mutex_lock l(run_state.mu);
-                                run_state.status.Update(ret);
-                              }
-                              executors_done.Notify();
-                            });
-
-    for (const auto& item : executors_and_keys->items) {
-      set_threadpool_args_for_item(item, &args);
-      item.executor->RunAsync(args, barrier->Get());
-    }
-
-    WaitForNotification(&executors_done, &run_state, &step_cancellation_manager,
-                        call_timeout);
-    {
-      tf_shared_lock l(run_state.mu);
-      run_status = run_state.status;
+  if (pool == nullptr) {
+    // We allow using the caller thread only when having a single executor
+    // specified.
+    if (executors_and_keys->items.size() > 1) {
+      pool = thread_pools_[0].first;
+    } else {
+      VLOG(1) << "Executing Session::Run() synchronously!";
     }
   }

+  std::unique_ptr<RunHandler> handler;
+  if (ShouldUseRunHandlerPool(run_options) &&
+      run_options.experimental().use_run_handler_pool()) {
+    VLOG(1) << "Using RunHandler to schedule inter-op closures.";
+    handler = GetOrCreateRunHandlerPool(options_)->Get(step_id);
+  }
+  auto* handler_ptr = handler.get();
+
+  Executor::Args::Runner default_runner = nullptr;
+
+  if (pool == nullptr) {
+    default_runner = [](Executor::Args::Closure c) { c(); };
+  } else if (handler_ptr != nullptr) {
+    default_runner = [handler_ptr](Executor::Args::Closure c) {
+      handler_ptr->ScheduleInterOpClosure(std::move(c));
+    };
+  } else {
+    default_runner = [this, pool](Executor::Args::Closure c) {
+      pool->Schedule(std::move(c));
+    };
+  }
+
+  for (const auto& item : executors_and_keys->items) {
+    // TODO(azaks): support partial run.
+    // TODO(azaks): if the device picks its own threadpool, we need to assign
+    // less threads to the main compute pool by default.
+    thread::ThreadPool* device_thread_pool =
+        item.device->tensorflow_device_thread_pool();
+    // TODO(crk): Investigate usage of RunHandlerPool when using device specific
+    // thread pool(s).
+    if (!device_thread_pool) {
+      args.runner = default_runner;
+    } else {
+      args.runner = [this, device_thread_pool](Executor::Args::Closure c) {
+        device_thread_pool->Schedule(std::move(c));
+      };
+    }
+    if (handler != nullptr) {
+      args.user_intra_op_threadpool = handler->AsIntraThreadPoolInterface();
+    }
+
+    item.executor->RunAsync(args, barrier->Get());
+  }
+
+  WaitForNotification(&executors_done, &run_state, &step_cancellation_manager,
+                      run_options.timeout_in_ms() > 0
+                          ? run_options.timeout_in_ms()
+                          : operation_timeout_in_ms_);
+
   if (!cancellation_manager_->DeregisterCallback(cancellation_token)) {
     // The step has been cancelled: make sure we don't attempt to receive the
     // outputs as this would make it block forever.
-    run_status.Update(errors::Cancelled("Run call was cancelled"));
+    mutex_lock l(run_state.mu);
+    run_state.status.Update(errors::Cancelled("Run call was cancelled"));
   }

   if (profiler_session) {
     TF_RETURN_IF_ERROR(profiler_session->CollectData(run_metadata));
   }

-  TF_RETURN_IF_ERROR(run_status);
+  {
+    mutex_lock l(run_state.mu);
+    TF_RETURN_IF_ERROR(run_state.status);
+  }

   // Save the output tensors of this run we choose to keep.
   if (!run_state.tensor_store.empty()) {
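
The loop this hunk restores decides, per executor, where closures run: inline when no pool was selected, through the RunHandler when one was acquired, otherwise on the chosen inter-op ThreadPool, with a device-owned pool overriding the default for that item. A condensed sketch of that precedence using std::function in place of Executor::Args::Runner (ThreadPool and RunHandler below are minimal stand-ins, not TensorFlow's classes):

#include <functional>
#include <utility>

using Closure = std::function<void()>;
using Runner = std::function<void(Closure)>;

struct ThreadPool {                  // stand-in: a real pool enqueues work
  void Schedule(Closure c) { c(); }
};
struct RunHandler {                  // stand-in for TF's RunHandler
  void ScheduleInterOpClosure(Closure c) { c(); }
};

// Mirrors the precedence above: no pool -> inline, handler -> handler,
// otherwise the selected thread pool.
Runner MakeDefaultRunner(ThreadPool* pool, RunHandler* handler) {
  if (pool == nullptr) {
    return [](Closure c) { c(); };  // synchronous, on the caller's thread
  }
  if (handler != nullptr) {
    return [handler](Closure c) {
      handler->ScheduleInterOpClosure(std::move(c));
    };
  }
  return [pool](Closure c) { pool->Schedule(std::move(c)); };
}

// Per-item override: a device that owns its own thread pool wins.
Runner RunnerForItem(ThreadPool* device_pool, Runner default_runner) {
  if (device_pool == nullptr) return default_runner;
  return [device_pool](Closure c) { device_pool->Schedule(std::move(c)); };
}
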
@@ -111,7 +111,7 @@ class Executor {
   virtual void RunAsync(const Args& args, DoneCallback done) = 0;

   // Synchronous wrapper for RunAsync().
-  virtual Status Run(const Args& args) {
+  Status Run(const Args& args) {
     Status ret;
     Notification n;
     RunAsync(args, [&ret, &n](const Status& s) {
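
The method whose virtual qualifier is removed here is a textbook sync-over-async adapter: call the async API, then block on a Notification until the callback delivers the status. The same pattern with std::promise/std::future standing in for TF's Notification (the types below are stand-ins, and the RunAsync declaration is assumed to be implemented elsewhere):

#include <functional>
#include <future>
#include <string>

struct Status { std::string msg; };                       // stand-in
using DoneCallback = std::function<void(const Status&)>;  // stand-in

void RunAsync(DoneCallback done);  // assume an async implementation exists

// Blocks until RunAsync invokes the callback; the promise/future pair
// makes the handoff safe even if `done` fires on another thread.
Status Run() {
  std::promise<Status> result;
  std::future<Status> f = result.get_future();
  RunAsync([&result](const Status& s) { result.set_value(s); });
  return f.get();
}
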
@@ -195,7 +195,10 @@ class SingleThreadedExecutorImpl : public Executor {
     return Status::OK();
   }

-  Status Run(const Args& args) override {
+  // TODO(mrry): Consider specializing the implementation of Executor::Run()
+  // instead, to avoid unnecessary atomic operations in the callback when
+  // running synchronously.
+  void RunAsync(const Args& args, DoneCallback done) override {
     // The inputs to each kernel are stored contiguously in `inputs`.
     //
     // We use `kernels_[i].input_start_index` and `kernels_[i].num_inputs` to
@@ -272,9 +275,9 @@ class SingleThreadedExecutorImpl : public Executor {
     const size_t received_args =
         args.call_frame ? args.call_frame->num_args() : 0;
     if (arg_output_locations_.size() > received_args) {
-      return errors::InvalidArgument("Expected ", arg_output_locations_.size(),
-                                     " arguments, but only received ",
-                                     received_args, ".");
+      done(errors::InvalidArgument("Expected ", arg_output_locations_.size(),
+                                   " arguments, but only received ",
+                                   received_args, "."));
     }

     // ArgOp is a relatively expensive OpKernel due to the Tensor
@@ -348,7 +351,8 @@ class SingleThreadedExecutorImpl : public Executor {
         }
       }
     }
-    return ctx.status();
+    done(ctx.status());
+    return;
   }

   // Free the inputs to the current kernel.
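
This hunk and the one before it are the mechanical half of turning a Status-returning Run into a callback-based RunAsync: each return status; becomes done(status);, plus a bare return; when the function has more code below. Dropping that bare return would fall through and invoke done a second time. A minimal illustration with the same stand-in types:

#include <functional>
#include <string>

struct Status { std::string msg; };                       // stand-in
using DoneCallback = std::function<void(const Status&)>;  // stand-in

void RunAsyncSketch(bool bad_input, const DoneCallback& done) {
  if (bad_input) {
    done(Status{"invalid argument"});
    return;  // without this, done would also fire below
  }
  // ... the actual work would go here ...
  done(Status{});  // empty msg means OK in this sketch
}
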
@@ -375,11 +379,7 @@ class SingleThreadedExecutorImpl : public Executor {
         delete val.tensor;
       }
     }
-    return Status::OK();
-  }
-
-  void RunAsync(const Args& args, DoneCallback done) override {
-    done(Run(args));
+    done(Status::OK());
   }

  private:
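
The deleted done(Run(args)); was the inverse adapter: a synchronous implementation satisfying the async interface by running inline and firing the callback before returning. Sketched with the same stand-ins (RunSync is hypothetical):

#include <functional>
#include <string>

struct Status { std::string msg; };                       // stand-in
using DoneCallback = std::function<void(const Status&)>;  // stand-in

Status RunSync();  // assume a synchronous implementation exists

// No thread hop: the callback runs on the caller's thread before
// RunAsync returns, which is what the removed line did.
void RunAsync(const DoneCallback& done) { done(RunSync()); }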