Automated g4 rollback of changelist 281562211.
PiperOrigin-RevId: 281607574
Change-Id: I7d06715ba20f527db4c2f5bea61a51e04b7816a1
commit 98ab062a8b
parent 7286a69c3c
@@ -497,7 +497,6 @@ Status DirectSession::RunInternal(
   const uint64 start_time_usecs = options_.env->NowMicros();
   const int64 executor_step_count = executors_and_keys->step_count.fetch_add(1);
   RunState run_state(step_id, &devices_);
-  const size_t num_executors = executors_and_keys->items.size();

   profiler::TraceMe activity(
       [&] {
@@ -554,64 +553,21 @@ Status DirectSession::RunInternal(
   }
 #endif

-  // Use std::unique_ptr to ensure garbage collection
-  std::unique_ptr<thread::ThreadPool> threadpool_wrapper;
-  thread::ThreadPool* pool = nullptr;
-
-  if (run_in_caller_thread_) {
-    pool = nullptr;
-  } else if (threadpool_options.inter_op_threadpool != nullptr) {
-    threadpool_wrapper = absl::make_unique<thread::ThreadPool>(
-        threadpool_options.inter_op_threadpool);
-    pool = threadpool_wrapper.get();
-  } else if (run_options.inter_op_thread_pool() >= 0) {
-    pool = thread_pools_[run_options.inter_op_thread_pool()].first;
-  }
-
-  if (pool == nullptr) {
-    // We allow using the caller thread only when having a single executor
-    // specified.
-    if (executors_and_keys->items.size() > 1) {
-      pool = thread_pools_[0].first;
-    } else {
-      VLOG(1) << "Executing Session::Run() synchronously!";
-    }
-  }
-
-  if (run_options.inter_op_thread_pool() < -1 ||
-      run_options.inter_op_thread_pool() >=
-          static_cast<int32>(thread_pools_.size())) {
-    return errors::InvalidArgument("Invalid inter_op_thread_pool: ",
-                                   run_options.inter_op_thread_pool());
-  }
-
-  std::unique_ptr<RunHandler> handler;
-  if (ShouldUseRunHandlerPool(run_options) &&
-      run_options.experimental().use_run_handler_pool()) {
-    VLOG(1) << "Using RunHandler to scheduler inter-op closures.";
-    handler = GetOrCreateRunHandlerPool(options_)->Get(step_id);
-  }
-  auto* handler_ptr = handler.get();
-
-  Executor::Args::Runner default_runner = nullptr;
-
-  if (pool == nullptr) {
-    default_runner = [](Executor::Args::Closure c) { c(); };
-  } else if (handler_ptr != nullptr) {
-    default_runner = [handler_ptr](Executor::Args::Closure c) {
-      handler_ptr->ScheduleInterOpClosure(std::move(c));
-    };
-  } else {
-    default_runner = [pool](Executor::Args::Closure c) {
-      pool->Schedule(std::move(c));
-    };
-  }
-
   // Start parallel Executors.
-  const int64 call_timeout = run_options.timeout_in_ms() > 0
-                                 ? run_options.timeout_in_ms()
-                                 : operation_timeout_in_ms_;
-  const bool can_execute_synchronously = pool == nullptr && call_timeout == 0;
+  const size_t num_executors = executors_and_keys->items.size();
+  Notification executors_done;
+
+  // TODO(mrry): Switch the RunInternal() synchronous use of ExecutorBarrier
+  // to use a stack-allocated barrier.
+  ExecutorBarrier* barrier =
+      new ExecutorBarrier(num_executors, run_state.rendez.get(),
+                          [&run_state, &executors_done](const Status& ret) {
+                            {
+                              mutex_lock l(run_state.mu);
+                              run_state.status.Update(ret);
+                            }
+                            executors_done.Notify();
+                          });

   Executor::Args args;
   args.step_id = step_id;
@@ -655,6 +611,14 @@ Status DirectSession::RunInternal(
     profiler_session = ProfilerSession::Create();
   }

+  if (run_options.inter_op_thread_pool() < -1 ||
+      run_options.inter_op_thread_pool() >=
+          static_cast<int32>(thread_pools_.size())) {
+    delete barrier;
+    return errors::InvalidArgument("Invalid inter_op_thread_pool: ",
+                                   run_options.inter_op_thread_pool());
+  }
+
   // Register this step with session's cancellation manager, so that
   // `Session::Close()` will cancel the step.
   const CancellationToken cancellation_token =
@@ -664,76 +628,98 @@ Status DirectSession::RunInternal(
         step_cancellation_manager.StartCancel();
       });
   if (already_cancelled) {
+    delete barrier;
     return errors::Cancelled("Run call was cancelled");
   }

-  Status run_status;
-
-  auto set_threadpool_args_for_item =
-      [&default_runner, &handler](const PerPartitionExecutorsAndLib& item,
-                                  Executor::Args* args) {
-        // TODO(azaks): support partial run.
-        // TODO(azaks): if the device picks its own threadpool, we need to
-        // assign
-        // less threads to the main compute pool by default.
-        thread::ThreadPool* device_thread_pool =
-            item.device->tensorflow_device_thread_pool();
-        // TODO(crk): Investigate usage of RunHandlerPool when using device
-        // specific thread pool(s).
-        if (!device_thread_pool) {
-          args->runner = default_runner;
-        } else {
-          args->runner = [device_thread_pool](Executor::Args::Closure c) {
-            device_thread_pool->Schedule(std::move(c));
-          };
-        }
-        if (handler != nullptr) {
-          args->user_intra_op_threadpool =
-              handler->AsIntraThreadPoolInterface();
-        }
-      };
-
-  if (can_execute_synchronously) {
-    const auto& item = executors_and_keys->items[0];
-    set_threadpool_args_for_item(item, &args);
-    run_status = item.executor->Run(args);
-  } else {
-    // `barrier` will delete itself after the final executor finishes.
-    Notification executors_done;
-    ExecutorBarrier* barrier =
-        new ExecutorBarrier(num_executors, run_state.rendez.get(),
-                            [&run_state, &executors_done](const Status& ret) {
-                              {
-                                mutex_lock l(run_state.mu);
-                                run_state.status.Update(ret);
-                              }
-                              executors_done.Notify();
-                            });
-
-    for (const auto& item : executors_and_keys->items) {
-      set_threadpool_args_for_item(item, &args);
-      item.executor->RunAsync(args, barrier->Get());
-    }
-
-    WaitForNotification(&executors_done, &run_state, &step_cancellation_manager,
-                        call_timeout);
-    {
-      tf_shared_lock l(run_state.mu);
-      run_status = run_state.status;
+  // Use std::unique_ptr to ensure garbage collection
+  std::unique_ptr<thread::ThreadPool> threadpool_wrapper;
+  thread::ThreadPool* pool = nullptr;
+
+  if (run_in_caller_thread_) {
+    pool = nullptr;
+  } else if (threadpool_options.inter_op_threadpool != nullptr) {
+    threadpool_wrapper = absl::make_unique<thread::ThreadPool>(
+        threadpool_options.inter_op_threadpool);
+    pool = threadpool_wrapper.get();
+  } else if (run_options.inter_op_thread_pool() >= 0) {
+    pool = thread_pools_[run_options.inter_op_thread_pool()].first;
+  }
+
+  if (pool == nullptr) {
+    // We allow using the caller thread only when having a single executor
+    // specified.
+    if (executors_and_keys->items.size() > 1) {
+      pool = thread_pools_[0].first;
+    } else {
+      VLOG(1) << "Executing Session::Run() synchronously!";
     }
   }

+  std::unique_ptr<RunHandler> handler;
+  if (ShouldUseRunHandlerPool(run_options) &&
+      run_options.experimental().use_run_handler_pool()) {
+    VLOG(1) << "Using RunHandler to scheduler inter-op closures.";
+    handler = GetOrCreateRunHandlerPool(options_)->Get(step_id);
+  }
+  auto* handler_ptr = handler.get();
+
+  Executor::Args::Runner default_runner = nullptr;
+
+  if (pool == nullptr) {
+    default_runner = [](Executor::Args::Closure c) { c(); };
+  } else if (handler_ptr != nullptr) {
+    default_runner = [handler_ptr](Executor::Args::Closure c) {
+      handler_ptr->ScheduleInterOpClosure(std::move(c));
+    };
+  } else {
+    default_runner = [this, pool](Executor::Args::Closure c) {
+      pool->Schedule(std::move(c));
+    };
+  }
+
+  for (const auto& item : executors_and_keys->items) {
+    // TODO(azaks): support partial run.
+    // TODO(azaks): if the device picks its own threadpool, we need to assign
+    // less threads to the main compute pool by default.
+    thread::ThreadPool* device_thread_pool =
+        item.device->tensorflow_device_thread_pool();
+    // TODO(crk): Investigate usage of RunHandlerPool when using device specific
+    // thread pool(s).
+    if (!device_thread_pool) {
+      args.runner = default_runner;
+    } else {
+      args.runner = [this, device_thread_pool](Executor::Args::Closure c) {
+        device_thread_pool->Schedule(std::move(c));
+      };
+    }
+    if (handler != nullptr) {
+      args.user_intra_op_threadpool = handler->AsIntraThreadPoolInterface();
+    }
+
+    item.executor->RunAsync(args, barrier->Get());
+  }
+
+  WaitForNotification(&executors_done, &run_state, &step_cancellation_manager,
+                      run_options.timeout_in_ms() > 0
+                          ? run_options.timeout_in_ms()
+                          : operation_timeout_in_ms_);
+
   if (!cancellation_manager_->DeregisterCallback(cancellation_token)) {
     // The step has been cancelled: make sure we don't attempt to receive the
     // outputs as this would make it block forever.
-    run_status.Update(errors::Cancelled("Run call was cancelled"));
+    mutex_lock l(run_state.mu);
+    run_state.status.Update(errors::Cancelled("Run call was cancelled"));
   }

   if (profiler_session) {
     TF_RETURN_IF_ERROR(profiler_session->CollectData(run_metadata));
   }

-  TF_RETURN_IF_ERROR(run_status);
+  {
+    mutex_lock l(run_state.mu);
+    TF_RETURN_IF_ERROR(run_state.status);
+  }

   // Save the output tensors of this run we choose to keep.
   if (!run_state.tensor_store.empty()) {
@@ -111,7 +111,7 @@ class Executor {
   virtual void RunAsync(const Args& args, DoneCallback done) = 0;

   // Synchronous wrapper for RunAsync().
-  virtual Status Run(const Args& args) {
+  Status Run(const Args& args) {
     Status ret;
     Notification n;
     RunAsync(args, [&ret, &n](const Status& s) {
@@ -195,7 +195,10 @@ class SingleThreadedExecutorImpl : public Executor {
     return Status::OK();
   }

-  Status Run(const Args& args) override {
+  // TODO(mrry): Consider specializing the implementation of Executor::Run()
+  // instead, to avoid unnecessary atomic operations in the callback when
+  // running synchronously.
+  void RunAsync(const Args& args, DoneCallback done) override {
     // The inputs to each kernel are stored contiguously in `inputs`.
     //
     // We use `kernels_[i].input_start_index` and `kernels_[i].num_inputs` to
@@ -272,9 +275,9 @@ class SingleThreadedExecutorImpl : public Executor {
     const size_t received_args =
         args.call_frame ? args.call_frame->num_args() : 0;
     if (arg_output_locations_.size() > received_args) {
-      return errors::InvalidArgument("Expected ", arg_output_locations_.size(),
-                                     " arguments, but only received ",
-                                     received_args, ".");
+      done(errors::InvalidArgument("Expected ", arg_output_locations_.size(),
+                                   " arguments, but only received ",
+                                   received_args, "."));
     }

     // ArgOp is a relatively expensive OpKernel due to the Tensor
@@ -348,7 +351,8 @@ class SingleThreadedExecutorImpl : public Executor {
           }
         }
       }
-      return ctx.status();
+      done(ctx.status());
+      return;
    }

    // Free the inputs to the current kernel.
@@ -375,11 +379,7 @@ class SingleThreadedExecutorImpl : public Executor {
         delete val.tensor;
       }
     }
-    return Status::OK();
-  }
-
-  void RunAsync(const Args& args, DoneCallback done) override {
-    done(Run(args));
+    done(Status::OK());
   }

  private:
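Note on the interface change above: this rollback restores Executor::Run() as a fixed (non-virtual) synchronous wrapper over the virtual RunAsync(), so SingleThreadedExecutorImpl again implements only RunAsync() and reports completion through the DoneCallback. The following is a minimal, self-contained sketch of that wrapper pattern for readers skimming the diff; it is not TensorFlow source, and the simplified Status, Notification, and NoOpExecutor types are illustrative stand-ins only.

// minimal_executor_sketch.cc -- illustrative only, not TensorFlow source.
#include <condition_variable>
#include <functional>
#include <mutex>
#include <string>

// Stand-in for tensorflow::Status.
struct Status {
  bool ok = true;
  std::string message;
};

// Stand-in for tensorflow::Notification: a one-shot latch.
class Notification {
 public:
  void Notify() {
    std::lock_guard<std::mutex> l(mu_);
    notified_ = true;
    cv_.notify_all();
  }
  void WaitForNotification() {
    std::unique_lock<std::mutex> l(mu_);
    cv_.wait(l, [this] { return notified_; });
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  bool notified_ = false;
};

class Executor {
 public:
  struct Args {};
  using DoneCallback = std::function<void(const Status&)>;

  virtual ~Executor() = default;

  // Subclasses implement only the asynchronous entry point and must invoke
  // `done` exactly once with the final status.
  virtual void RunAsync(const Args& args, DoneCallback done) = 0;

  // Synchronous wrapper for RunAsync(): blocks the caller until the
  // done-callback fires, then returns the status it reported.
  Status Run(const Args& args) {
    Status ret;
    Notification n;
    RunAsync(args, [&ret, &n](const Status& s) {
      ret = s;
      n.Notify();
    });
    n.WaitForNotification();
    return ret;
  }
};

// Trivial executor that completes immediately on the calling thread.
class NoOpExecutor : public Executor {
 public:
  void RunAsync(const Args& args, DoneCallback done) override {
    (void)args;        // No inputs needed for this example.
    done(Status());    // Report success.
  }
};

int main() {
  NoOpExecutor executor;
  Status s = executor.Run(Executor::Args());
  return s.ok ? 0 : 1;
}

The design keeps a single asynchronous entry point: every executor signals completion through the DoneCallback, and the blocking behaviour lives in one shared wrapper rather than in each implementation.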