Add experimental Session::MakeCallable() API and implement it for DirectSession.
The intent of this new API matches the Python `tf.Session.make_callable()` method: it splits the two roles of the `Session::Run()` method into separate methods: 1. `Session::MakeCallable()` takes information about a subgraph (such as the names of nodes to feed and fetch), and prunes and optimizes that graph, returning a simple handle. 2. `Session::RunCallable()` takes that handle, plus any values to be fed, and executes the graph, returning whatever outputs are produced. This split moves string processing off the critical path of running a step. We also add a new method `Session::ReleaseCallable()` that makes it possible to free the resources associated with a cached subgraph, and could be useful for seldom-executed graphs such as initializers. PiperOrigin-RevId: 188566635
This commit is contained in:
parent
05aa4e58c8
commit
2426308fa5
@ -318,6 +318,7 @@ DirectSession::~DirectSession() {
|
||||
for (auto& it : executors_) {
|
||||
it.second.reset();
|
||||
}
|
||||
callables_.clear();
|
||||
for (auto d : device_mgr_->ListDevices()) {
|
||||
d->op_segment()->RemoveHold(session_handle_);
|
||||
}
|
||||
@ -409,16 +410,21 @@ Status DirectSession::Run(const NamedTensorList& inputs,
|
||||
}
|
||||
|
||||
Status DirectSession::CreateDebuggerState(
|
||||
const DebugOptions& debug_options, int64 session_run_index,
|
||||
int64 executor_step_index, const std::vector<string>& input_names,
|
||||
const std::vector<string>& output_names,
|
||||
const std::vector<string>& target_names,
|
||||
const CallableOptions& callable_options, int64 global_step,
|
||||
int64 session_run_index, int64 executor_step_index,
|
||||
std::unique_ptr<DebuggerStateInterface>* debugger_state) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
DebuggerStateRegistry::CreateState(debug_options, debugger_state));
|
||||
TF_RETURN_IF_ERROR(DebuggerStateRegistry::CreateState(
|
||||
callable_options.run_options().debug_options(), debugger_state));
|
||||
std::vector<string> input_names(callable_options.feed().begin(),
|
||||
callable_options.feed().end());
|
||||
std::vector<string> output_names(callable_options.fetch().begin(),
|
||||
callable_options.fetch().end());
|
||||
std::vector<string> target_names(callable_options.target().begin(),
|
||||
callable_options.target().end());
|
||||
|
||||
TF_RETURN_IF_ERROR(debugger_state->get()->PublishDebugMetadata(
|
||||
debug_options.global_step(), session_run_index, executor_step_index,
|
||||
input_names, output_names, target_names));
|
||||
global_step, session_run_index, executor_step_index, input_names,
|
||||
output_names, target_names));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -433,84 +439,23 @@ Status DirectSession::DecorateAndPublishGraphForDebug(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DirectSession::Run(const RunOptions& run_options,
|
||||
const NamedTensorList& inputs,
|
||||
const std::vector<string>& output_names,
|
||||
const std::vector<string>& target_nodes,
|
||||
std::vector<Tensor>* outputs,
|
||||
RunMetadata* run_metadata) {
|
||||
TF_RETURN_IF_ERROR(CheckNotClosed());
|
||||
direct_session_runs->GetCell()->IncrementBy(1);
|
||||
{
|
||||
mutex_lock l(graph_def_lock_);
|
||||
if (!graph_created_) {
|
||||
return errors::InvalidArgument(
|
||||
"Session was not created with a graph before Run()!");
|
||||
}
|
||||
}
|
||||
|
||||
// Extract the inputs names for this run of the session.
|
||||
std::vector<string> input_tensor_names;
|
||||
input_tensor_names.reserve(inputs.size());
|
||||
for (const auto& it : inputs) {
|
||||
input_tensor_names.push_back(it.first);
|
||||
}
|
||||
|
||||
if (run_options.inter_op_thread_pool() < 0 ||
|
||||
run_options.inter_op_thread_pool() >= thread_pools_.size()) {
|
||||
return errors::InvalidArgument("Invalid inter_op_thread_pool: ",
|
||||
run_options.inter_op_thread_pool());
|
||||
}
|
||||
thread::ThreadPool* pool =
|
||||
thread_pools_[run_options.inter_op_thread_pool()].first;
|
||||
|
||||
// Check if we already have an executor for these arguments.
|
||||
ExecutorsAndKeys* executors_and_keys;
|
||||
RunStateArgs run_state_args(run_options.debug_options());
|
||||
|
||||
Executor::Args args;
|
||||
args.step_id = step_id_counter_.fetch_add(1);
|
||||
|
||||
TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_tensor_names, output_names,
|
||||
target_nodes, &executors_and_keys,
|
||||
&run_state_args));
|
||||
Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options,
|
||||
CallFrameInterface* call_frame,
|
||||
ExecutorsAndKeys* executors_and_keys,
|
||||
RunMetadata* run_metadata) {
|
||||
const int64 executor_step_count = executors_and_keys->step_count.fetch_add(1);
|
||||
|
||||
std::unique_ptr<DebuggerStateInterface> debugger_state;
|
||||
if (!run_options.debug_options().debug_tensor_watch_opts().empty()) {
|
||||
TF_RETURN_IF_ERROR(CreateDebuggerState(
|
||||
run_options.debug_options(), args.step_id, executor_step_count,
|
||||
input_tensor_names, output_names, target_nodes, &debugger_state));
|
||||
}
|
||||
|
||||
// Configure a call frame for the step, which we use to feed and
|
||||
// fetch values to and from the executors.
|
||||
FunctionCallFrame call_frame(executors_and_keys->input_types,
|
||||
executors_and_keys->output_types);
|
||||
gtl::InlinedVector<Tensor, 4> feed_args(inputs.size());
|
||||
for (const auto& it : inputs) {
|
||||
if (it.second.dtype() == DT_RESOURCE) {
|
||||
Tensor tensor_from_handle;
|
||||
TF_RETURN_IF_ERROR(
|
||||
ResourceHandleToInputTensor(it.second, &tensor_from_handle));
|
||||
feed_args[executors_and_keys->input_name_to_index[it.first]] =
|
||||
tensor_from_handle;
|
||||
} else {
|
||||
feed_args[executors_and_keys->input_name_to_index[it.first]] = it.second;
|
||||
}
|
||||
}
|
||||
const Status s = call_frame.SetArgs(feed_args);
|
||||
if (errors::IsInternal(s)) {
|
||||
return errors::InvalidArgument(s.error_message());
|
||||
} else if (!s.ok()) {
|
||||
return s;
|
||||
TF_RETURN_IF_ERROR(
|
||||
CreateDebuggerState(executors_and_keys->callable_options,
|
||||
run_options.debug_options().global_step(), step_id,
|
||||
executor_step_count, &debugger_state));
|
||||
}
|
||||
|
||||
// Create a run state and start execution.
|
||||
RunState run_state(args.step_id, &devices_);
|
||||
RunState run_state(step_id, &devices_);
|
||||
run_state.rendez = new IntraProcessRendezvous(device_mgr_.get());
|
||||
CancellationManager step_cancellation_manager;
|
||||
args.call_frame = &call_frame;
|
||||
|
||||
// Start parallel Executors.
|
||||
const size_t num_executors = executors_and_keys->items.size();
|
||||
@ -523,15 +468,15 @@ Status DirectSession::Run(const RunOptions& run_options,
|
||||
run_state.executors_done.Notify();
|
||||
});
|
||||
|
||||
Executor::Args args;
|
||||
args.step_id = step_id;
|
||||
args.call_frame = call_frame;
|
||||
args.rendezvous = run_state.rendez;
|
||||
CancellationManager step_cancellation_manager;
|
||||
args.cancellation_manager = &step_cancellation_manager;
|
||||
|
||||
args.session_state = &session_state_;
|
||||
args.tensor_store = &run_state.tensor_store;
|
||||
args.step_container = &run_state.step_container;
|
||||
if (LogMemory::IsEnabled()) {
|
||||
LogMemory::RecordStep(args.step_id, run_state_args.handle);
|
||||
}
|
||||
args.sync_on_finish = sync_on_finish_;
|
||||
|
||||
const bool do_trace = (run_options.trace_level() > RunOptions::NO_TRACE);
|
||||
@ -569,6 +514,14 @@ Status DirectSession::Run(const RunOptions& run_options,
|
||||
}
|
||||
}
|
||||
|
||||
if (run_options.inter_op_thread_pool() < 0 ||
|
||||
run_options.inter_op_thread_pool() >= thread_pools_.size()) {
|
||||
run_state.executors_done.Notify();
|
||||
delete barrier;
|
||||
return errors::InvalidArgument("Invalid inter_op_thread_pool: ",
|
||||
run_options.inter_op_thread_pool());
|
||||
}
|
||||
|
||||
// Register this step with session's cancellation manager, so that
|
||||
// `Session::Close()` will cancel the step.
|
||||
const CancellationToken cancellation_token =
|
||||
@ -586,6 +539,9 @@ Status DirectSession::Run(const RunOptions& run_options,
|
||||
return errors::Cancelled("Run call was cancelled");
|
||||
}
|
||||
|
||||
thread::ThreadPool* pool =
|
||||
thread_pools_[run_options.inter_op_thread_pool()].first;
|
||||
|
||||
Executor::Args::Runner default_runner = [this,
|
||||
pool](Executor::Args::Closure c) {
|
||||
SchedClosure(pool, std::move(c));
|
||||
@ -628,6 +584,111 @@ Status DirectSession::Run(const RunOptions& run_options,
|
||||
TF_RETURN_IF_ERROR(run_state.status);
|
||||
}
|
||||
|
||||
// Save the output tensors of this run we choose to keep.
|
||||
if (!run_state.tensor_store.empty()) {
|
||||
TF_RETURN_IF_ERROR(run_state.tensor_store.SaveTensors(
|
||||
{executors_and_keys->callable_options.fetch().begin(),
|
||||
executors_and_keys->callable_options.fetch().end()},
|
||||
&session_state_));
|
||||
}
|
||||
|
||||
if (args.stats_collector) {
|
||||
args.stats_collector->Finalize();
|
||||
}
|
||||
|
||||
// Build and return the cost model as instructed.
|
||||
if (update_cost_model) {
|
||||
// Build the cost model
|
||||
std::unordered_map<string, const Graph*> device_to_graph;
|
||||
for (const PerPartitionExecutorsAndLib& partition :
|
||||
executors_and_keys->items) {
|
||||
const Graph* graph = partition.graph;
|
||||
const string device = partition.flib->device()->name();
|
||||
device_to_graph[device] = graph;
|
||||
}
|
||||
|
||||
mutex_lock l(executor_lock_);
|
||||
args.stats_collector->BuildCostModel(&cost_model_manager_, device_to_graph);
|
||||
|
||||
// annotate stats onto cost graph.
|
||||
CostGraphDef* cost_graph = run_metadata->mutable_cost_graph();
|
||||
for (const auto& item : executors_and_keys->items) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
cost_model_manager_.AddToCostGraphDef(item.graph, cost_graph));
|
||||
}
|
||||
}
|
||||
|
||||
// If requested via RunOptions, output the partition graphs.
|
||||
if (run_options.output_partition_graphs()) {
|
||||
protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs =
|
||||
run_metadata->mutable_partition_graphs();
|
||||
for (const PerPartitionExecutorsAndLib& exec_and_lib :
|
||||
executors_and_keys->items) {
|
||||
GraphDef* partition_graph_def = partition_graph_defs->Add();
|
||||
exec_and_lib.graph->ToGraphDef(partition_graph_def);
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DirectSession::Run(const RunOptions& run_options,
|
||||
const NamedTensorList& inputs,
|
||||
const std::vector<string>& output_names,
|
||||
const std::vector<string>& target_nodes,
|
||||
std::vector<Tensor>* outputs,
|
||||
RunMetadata* run_metadata) {
|
||||
TF_RETURN_IF_ERROR(CheckNotClosed());
|
||||
TF_RETURN_IF_ERROR(CheckGraphCreated("Run()"));
|
||||
direct_session_runs->GetCell()->IncrementBy(1);
|
||||
|
||||
// Extract the inputs names for this run of the session.
|
||||
std::vector<string> input_tensor_names;
|
||||
input_tensor_names.reserve(inputs.size());
|
||||
for (const auto& it : inputs) {
|
||||
input_tensor_names.push_back(it.first);
|
||||
}
|
||||
|
||||
// Check if we already have an executor for these arguments.
|
||||
ExecutorsAndKeys* executors_and_keys;
|
||||
RunStateArgs run_state_args(run_options.debug_options());
|
||||
|
||||
TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_tensor_names, output_names,
|
||||
target_nodes, &executors_and_keys,
|
||||
&run_state_args));
|
||||
|
||||
// Configure a call frame for the step, which we use to feed and
|
||||
// fetch values to and from the executors.
|
||||
FunctionCallFrame call_frame(executors_and_keys->input_types,
|
||||
executors_and_keys->output_types);
|
||||
gtl::InlinedVector<Tensor, 4> feed_args(inputs.size());
|
||||
for (const auto& it : inputs) {
|
||||
if (it.second.dtype() == DT_RESOURCE) {
|
||||
Tensor tensor_from_handle;
|
||||
TF_RETURN_IF_ERROR(
|
||||
ResourceHandleToInputTensor(it.second, &tensor_from_handle));
|
||||
feed_args[executors_and_keys->input_name_to_index[it.first]] =
|
||||
tensor_from_handle;
|
||||
} else {
|
||||
feed_args[executors_and_keys->input_name_to_index[it.first]] = it.second;
|
||||
}
|
||||
}
|
||||
const Status s = call_frame.SetArgs(feed_args);
|
||||
if (errors::IsInternal(s)) {
|
||||
return errors::InvalidArgument(s.error_message());
|
||||
} else if (!s.ok()) {
|
||||
return s;
|
||||
}
|
||||
|
||||
const int64 step_id = step_id_counter_.fetch_add(1);
|
||||
|
||||
if (LogMemory::IsEnabled()) {
|
||||
LogMemory::RecordStep(step_id, run_state_args.handle);
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(RunInternal(step_id, run_options, &call_frame,
|
||||
executors_and_keys, run_metadata));
|
||||
|
||||
// Receive outputs.
|
||||
if (outputs) {
|
||||
std::vector<Tensor> sorted_outputs;
|
||||
@ -667,45 +728,6 @@ Status DirectSession::Run(const RunOptions& run_options,
|
||||
}
|
||||
}
|
||||
|
||||
// Save the output tensors of this run we choose to keep.
|
||||
TF_RETURN_IF_ERROR(
|
||||
run_state.tensor_store.SaveTensors(output_names, &session_state_));
|
||||
if (args.stats_collector) {
|
||||
args.stats_collector->Finalize();
|
||||
}
|
||||
|
||||
// Build and return the cost model as instructed.
|
||||
mutex_lock l(executor_lock_);
|
||||
if (update_cost_model) {
|
||||
// Build the cost model
|
||||
std::unordered_map<string, const Graph*> device_to_graph;
|
||||
for (const PerPartitionExecutorsAndLib& partition :
|
||||
executors_and_keys->items) {
|
||||
const Graph* graph = partition.graph;
|
||||
const string device = partition.flib->device()->name();
|
||||
device_to_graph[device] = graph;
|
||||
}
|
||||
args.stats_collector->BuildCostModel(&cost_model_manager_, device_to_graph);
|
||||
|
||||
// annotate stats onto cost graph.
|
||||
CostGraphDef* cost_graph = run_metadata->mutable_cost_graph();
|
||||
for (const auto& item : executors_and_keys->items) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
cost_model_manager_.AddToCostGraphDef(item.graph, cost_graph));
|
||||
}
|
||||
}
|
||||
|
||||
// If requested via RunOptions, output the partition graphs.
|
||||
if (run_options.output_partition_graphs()) {
|
||||
protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs =
|
||||
run_metadata->mutable_partition_graphs();
|
||||
for (const PerPartitionExecutorsAndLib& exec_and_lib :
|
||||
executors_and_keys->items) {
|
||||
GraphDef* partition_graph_def = partition_graph_defs->Add();
|
||||
exec_and_lib.graph->ToGraphDef(partition_graph_def);
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -714,13 +736,7 @@ Status DirectSession::PRunSetup(const std::vector<string>& input_names,
|
||||
const std::vector<string>& target_nodes,
|
||||
string* handle) {
|
||||
TF_RETURN_IF_ERROR(CheckNotClosed());
|
||||
{
|
||||
mutex_lock l(graph_def_lock_);
|
||||
if (!graph_created_) {
|
||||
return errors::InvalidArgument(
|
||||
"Session was not created with a graph before PRunSetup()!");
|
||||
}
|
||||
}
|
||||
TF_RETURN_IF_ERROR(CheckGraphCreated("PRunSetup()"));
|
||||
|
||||
// RunOptions is not available in PRunSetup, so use thread pool 0.
|
||||
thread::ThreadPool* pool = thread_pools_[0].first;
|
||||
@ -1061,92 +1077,31 @@ Status DirectSession::CheckFetch(const NamedTensorList& feeds,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DirectSession::GetOrCreateExecutors(
|
||||
gtl::ArraySlice<string> inputs, gtl::ArraySlice<string> outputs,
|
||||
gtl::ArraySlice<string> target_nodes, ExecutorsAndKeys** executors_and_keys,
|
||||
Status DirectSession::CreateExecutors(
|
||||
const CallableOptions& callable_options,
|
||||
std::unique_ptr<ExecutorsAndKeys>* out_executors_and_keys,
|
||||
std::unique_ptr<FunctionInfo>* out_func_info,
|
||||
RunStateArgs* run_state_args) {
|
||||
int64 handle_name_counter_value = -1;
|
||||
if (LogMemory::IsEnabled() || run_state_args->is_partial_run) {
|
||||
handle_name_counter_value = handle_name_counter_.fetch_add(1);
|
||||
}
|
||||
|
||||
string debug_tensor_watches_summary;
|
||||
if (!run_state_args->debug_options.debug_tensor_watch_opts().empty()) {
|
||||
debug_tensor_watches_summary = SummarizeDebugTensorWatches(
|
||||
run_state_args->debug_options.debug_tensor_watch_opts());
|
||||
}
|
||||
|
||||
// Fast lookup path, no sorting.
|
||||
const string key = strings::StrCat(
|
||||
str_util::Join(inputs, ","), "->", str_util::Join(outputs, ","), "/",
|
||||
str_util::Join(target_nodes, ","), "/", run_state_args->is_partial_run,
|
||||
"/", debug_tensor_watches_summary);
|
||||
// Set the handle, if it's needed to log memory or for partial run.
|
||||
if (handle_name_counter_value >= 0) {
|
||||
run_state_args->handle =
|
||||
strings::StrCat(key, ";", handle_name_counter_value);
|
||||
}
|
||||
|
||||
// See if we already have the executors for this run.
|
||||
{
|
||||
mutex_lock l(executor_lock_); // could use reader lock
|
||||
auto it = executors_.find(key);
|
||||
if (it != executors_.end()) {
|
||||
*executors_and_keys = it->second.get();
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
||||
// Slow lookup path, the unsorted key missed the cache.
|
||||
// Sort the inputs and outputs, and look up with the sorted key in case an
|
||||
// earlier call used a different order of inputs and outputs.
|
||||
//
|
||||
// We could consider some other signature instead of sorting that
|
||||
// preserves the same property to avoid the sort in the future.
|
||||
std::vector<string> inputs_sorted(inputs.begin(), inputs.end());
|
||||
std::sort(inputs_sorted.begin(), inputs_sorted.end());
|
||||
std::vector<string> outputs_sorted(outputs.begin(), outputs.end());
|
||||
std::sort(outputs_sorted.begin(), outputs_sorted.end());
|
||||
std::vector<string> tn_sorted(target_nodes.begin(), target_nodes.end());
|
||||
std::sort(tn_sorted.begin(), tn_sorted.end());
|
||||
|
||||
const string sorted_key = strings::StrCat(
|
||||
str_util::Join(inputs_sorted, ","), "->",
|
||||
str_util::Join(outputs_sorted, ","), "/", str_util::Join(tn_sorted, ","),
|
||||
"/", run_state_args->is_partial_run, "/", debug_tensor_watches_summary);
|
||||
// Set the handle, if its needed to log memory or for partial run.
|
||||
if (handle_name_counter_value >= 0) {
|
||||
run_state_args->handle =
|
||||
strings::StrCat(sorted_key, ";", handle_name_counter_value);
|
||||
}
|
||||
|
||||
// See if we already have the executors for this run.
|
||||
{
|
||||
mutex_lock l(executor_lock_);
|
||||
auto it = executors_.find(sorted_key);
|
||||
if (it != executors_.end()) {
|
||||
*executors_and_keys = it->second.get();
|
||||
// Insert this under the original key.
|
||||
executors_.emplace(key, it->second);
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
||||
// Nothing found, so create the executors and store in the cache.
|
||||
BuildGraphOptions options;
|
||||
options.feed_endpoints = inputs_sorted;
|
||||
options.fetch_endpoints = outputs_sorted;
|
||||
options.target_nodes = tn_sorted;
|
||||
options.feed_endpoints = std::vector<string>(callable_options.feed().begin(),
|
||||
callable_options.feed().end());
|
||||
options.fetch_endpoints = std::vector<string>(
|
||||
callable_options.fetch().begin(), callable_options.fetch().end());
|
||||
options.target_nodes = std::vector<string>(callable_options.target().begin(),
|
||||
callable_options.target().end());
|
||||
options.use_function_convention = !run_state_args->is_partial_run;
|
||||
if (!run_state_args->debug_options.debug_tensor_watch_opts().empty()) {
|
||||
options.debug_options = run_state_args->debug_options;
|
||||
if (!callable_options.run_options()
|
||||
.debug_options()
|
||||
.debug_tensor_watch_opts()
|
||||
.empty()) {
|
||||
options.debug_options = callable_options.run_options().debug_options();
|
||||
}
|
||||
|
||||
std::unique_ptr<FunctionInfo> func_info(new FunctionInfo);
|
||||
std::shared_ptr<ExecutorsAndKeys> ek(new ExecutorsAndKeys);
|
||||
std::unique_ptr<ExecutorsAndKeys> ek(new ExecutorsAndKeys);
|
||||
|
||||
ek->callable_options = callable_options;
|
||||
|
||||
// The executor_lock_ is intentionally released while executor is
|
||||
// being created.
|
||||
std::unordered_map<string, std::unique_ptr<Graph>> graphs;
|
||||
TF_RETURN_IF_ERROR(CreateGraphs(options, &graphs, &func_info->flib_def,
|
||||
run_state_args, &ek->input_types,
|
||||
@ -1155,11 +1110,11 @@ Status DirectSession::GetOrCreateExecutors(
|
||||
if (run_state_args->is_partial_run) {
|
||||
ek->graph = std::move(run_state_args->graph);
|
||||
std::unordered_set<StringPiece, StringPieceHasher> names;
|
||||
for (const string& input : inputs) {
|
||||
for (const string& input : callable_options.feed()) {
|
||||
TensorId id(ParseTensorName(input));
|
||||
names.emplace(id.first);
|
||||
}
|
||||
for (const string& output : outputs) {
|
||||
for (const string& output : callable_options.fetch()) {
|
||||
TensorId id(ParseTensorName(output));
|
||||
names.emplace(id.first);
|
||||
}
|
||||
@ -1260,12 +1215,12 @@ Status DirectSession::GetOrCreateExecutors(
|
||||
// For regular `Run()`, we use the function calling convention, and so
|
||||
// maintain a mapping from input/output names to
|
||||
// argument/return-value ordinal index.
|
||||
for (size_t i = 0; i < inputs_sorted.size(); ++i) {
|
||||
const string& input = inputs_sorted[i];
|
||||
for (int i = 0; i < callable_options.feed().size(); ++i) {
|
||||
const string& input = callable_options.feed(i);
|
||||
ek->input_name_to_index[input] = i;
|
||||
}
|
||||
for (size_t i = 0; i < outputs_sorted.size(); ++i) {
|
||||
const string& output = outputs_sorted[i];
|
||||
for (int i = 0; i < callable_options.fetch().size(); ++i) {
|
||||
const string& output = callable_options.fetch(i);
|
||||
ek->output_name_to_index[output] = i;
|
||||
}
|
||||
} else {
|
||||
@ -1274,26 +1229,123 @@ Status DirectSession::GetOrCreateExecutors(
|
||||
//
|
||||
// We always use the first device as the device name portion of the
|
||||
// key, even if we're feeding another graph.
|
||||
for (size_t i = 0; i < inputs_sorted.size(); ++i) {
|
||||
const string& input = inputs_sorted[i];
|
||||
for (int i = 0; i < callable_options.feed().size(); ++i) {
|
||||
const string& input = callable_options.feed(i);
|
||||
ek->input_name_to_rendezvous_key[input] = GetRendezvousKey(
|
||||
input, device_set_.client_device()->attributes(), FrameAndIter(0, 0));
|
||||
}
|
||||
for (size_t i = 0; i < outputs_sorted.size(); ++i) {
|
||||
const string& output = outputs_sorted[i];
|
||||
for (int i = 0; i < callable_options.fetch().size(); ++i) {
|
||||
const string& output = callable_options.fetch(i);
|
||||
ek->output_name_to_rendezvous_key[output] =
|
||||
GetRendezvousKey(output, device_set_.client_device()->attributes(),
|
||||
FrameAndIter(0, 0));
|
||||
}
|
||||
}
|
||||
|
||||
*out_executors_and_keys = std::move(ek);
|
||||
*out_func_info = std::move(func_info);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DirectSession::GetOrCreateExecutors(
|
||||
gtl::ArraySlice<string> inputs, gtl::ArraySlice<string> outputs,
|
||||
gtl::ArraySlice<string> target_nodes, ExecutorsAndKeys** executors_and_keys,
|
||||
RunStateArgs* run_state_args) {
|
||||
int64 handle_name_counter_value = -1;
|
||||
if (LogMemory::IsEnabled() || run_state_args->is_partial_run) {
|
||||
handle_name_counter_value = handle_name_counter_.fetch_add(1);
|
||||
}
|
||||
|
||||
string debug_tensor_watches_summary;
|
||||
if (!run_state_args->debug_options.debug_tensor_watch_opts().empty()) {
|
||||
debug_tensor_watches_summary = SummarizeDebugTensorWatches(
|
||||
run_state_args->debug_options.debug_tensor_watch_opts());
|
||||
}
|
||||
|
||||
// Fast lookup path, no sorting.
|
||||
const string key = strings::StrCat(
|
||||
str_util::Join(inputs, ","), "->", str_util::Join(outputs, ","), "/",
|
||||
str_util::Join(target_nodes, ","), "/", run_state_args->is_partial_run,
|
||||
"/", debug_tensor_watches_summary);
|
||||
// Set the handle, if it's needed to log memory or for partial run.
|
||||
if (handle_name_counter_value >= 0) {
|
||||
run_state_args->handle =
|
||||
strings::StrCat(key, ";", handle_name_counter_value);
|
||||
}
|
||||
|
||||
// See if we already have the executors for this run.
|
||||
{
|
||||
mutex_lock l(executor_lock_); // could use reader lock
|
||||
auto it = executors_.find(key);
|
||||
if (it != executors_.end()) {
|
||||
*executors_and_keys = it->second.get();
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
||||
// Slow lookup path, the unsorted key missed the cache.
|
||||
// Sort the inputs and outputs, and look up with the sorted key in case an
|
||||
// earlier call used a different order of inputs and outputs.
|
||||
//
|
||||
// We could consider some other signature instead of sorting that
|
||||
// preserves the same property to avoid the sort in the future.
|
||||
std::vector<string> inputs_sorted(inputs.begin(), inputs.end());
|
||||
std::sort(inputs_sorted.begin(), inputs_sorted.end());
|
||||
std::vector<string> outputs_sorted(outputs.begin(), outputs.end());
|
||||
std::sort(outputs_sorted.begin(), outputs_sorted.end());
|
||||
std::vector<string> tn_sorted(target_nodes.begin(), target_nodes.end());
|
||||
std::sort(tn_sorted.begin(), tn_sorted.end());
|
||||
|
||||
const string sorted_key = strings::StrCat(
|
||||
str_util::Join(inputs_sorted, ","), "->",
|
||||
str_util::Join(outputs_sorted, ","), "/", str_util::Join(tn_sorted, ","),
|
||||
"/", run_state_args->is_partial_run, "/", debug_tensor_watches_summary);
|
||||
// Set the handle, if its needed to log memory or for partial run.
|
||||
if (handle_name_counter_value >= 0) {
|
||||
run_state_args->handle =
|
||||
strings::StrCat(sorted_key, ";", handle_name_counter_value);
|
||||
}
|
||||
|
||||
// See if we already have the executors for this run.
|
||||
{
|
||||
mutex_lock l(executor_lock_);
|
||||
auto it = executors_.find(sorted_key);
|
||||
if (it != executors_.end()) {
|
||||
*executors_and_keys = it->second.get();
|
||||
// Insert this under the original key.
|
||||
executors_.emplace(key, it->second);
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
||||
// Nothing found, so create the executors and store in the cache.
|
||||
// The executor_lock_ is intentionally released while executors are
|
||||
// being created.
|
||||
CallableOptions callable_options;
|
||||
for (const string& input : inputs_sorted) {
|
||||
callable_options.add_feed(input);
|
||||
}
|
||||
for (const string& output : outputs_sorted) {
|
||||
callable_options.add_fetch(output);
|
||||
}
|
||||
for (const string& target : tn_sorted) {
|
||||
callable_options.add_target(target);
|
||||
}
|
||||
*callable_options.mutable_run_options()->mutable_debug_options() =
|
||||
run_state_args->debug_options;
|
||||
std::unique_ptr<ExecutorsAndKeys> ek;
|
||||
std::unique_ptr<FunctionInfo> func_info;
|
||||
TF_RETURN_IF_ERROR(
|
||||
CreateExecutors(callable_options, &ek, &func_info, run_state_args));
|
||||
|
||||
// Reacquire the lock, try to insert into the map.
|
||||
mutex_lock l(executor_lock_);
|
||||
functions_.push_back(std::move(func_info));
|
||||
|
||||
// Another thread may have created the entry before us, in which case we will
|
||||
// reuse the already created one.
|
||||
auto insert_result = executors_.emplace(sorted_key, ek);
|
||||
auto insert_result = executors_.emplace(
|
||||
sorted_key, std::shared_ptr<ExecutorsAndKeys>(std::move(ek)));
|
||||
// Insert the value under the original key, so the fast path lookup will work
|
||||
// if the user uses the same order of inputs, outputs, and targets again.
|
||||
executors_.emplace(key, insert_result.first->second);
|
||||
@ -1562,4 +1614,156 @@ void DirectSession::WaitForNotification(RunState* run_state,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DirectSession::MakeCallable(const CallableOptions& callable_options,
|
||||
CallableHandle* out_handle) {
|
||||
TF_RETURN_IF_ERROR(CheckNotClosed());
|
||||
TF_RETURN_IF_ERROR(CheckGraphCreated("MakeCallable()"));
|
||||
|
||||
if (!callable_options.run_options()
|
||||
.debug_options()
|
||||
.debug_tensor_watch_opts()
|
||||
.empty()) {
|
||||
return errors::Unimplemented(
|
||||
"Debug options are not currently supported via the C++ MakeCallable "
|
||||
"interface.");
|
||||
}
|
||||
|
||||
std::unique_ptr<ExecutorsAndKeys> ek;
|
||||
std::unique_ptr<FunctionInfo> func_info;
|
||||
RunStateArgs run_state_args(callable_options.run_options().debug_options());
|
||||
TF_RETURN_IF_ERROR(
|
||||
CreateExecutors(callable_options, &ek, &func_info, &run_state_args));
|
||||
{
|
||||
mutex_lock l(callables_lock_);
|
||||
*out_handle = next_callable_handle_++;
|
||||
callables_[*out_handle] = {std::move(ek), std::move(func_info)};
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
class DirectSession::RunCallableCallFrame : public CallFrameInterface {
|
||||
public:
|
||||
RunCallableCallFrame(DirectSession* session,
|
||||
ExecutorsAndKeys* executors_and_keys,
|
||||
const std::vector<Tensor>* feed_tensors,
|
||||
std::vector<Tensor>* fetch_tensors)
|
||||
: session_(session),
|
||||
executors_and_keys_(executors_and_keys),
|
||||
feed_tensors_(feed_tensors),
|
||||
fetch_tensors_(fetch_tensors) {}
|
||||
|
||||
size_t num_args() const override {
|
||||
return executors_and_keys_->input_types.size();
|
||||
}
|
||||
size_t num_retvals() const override {
|
||||
return executors_and_keys_->output_types.size();
|
||||
}
|
||||
|
||||
Status GetArg(int index, Tensor* val) const override {
|
||||
if (index > feed_tensors_->size()) {
|
||||
return errors::Internal("Args index out of bounds: ", index);
|
||||
} else if (executors_and_keys_->input_types[index] == DT_RESOURCE) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
session_->ResourceHandleToInputTensor((*feed_tensors_)[index], val));
|
||||
} else {
|
||||
*val = (*feed_tensors_)[index];
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status SetRetval(int index, const Tensor& val) override {
|
||||
if (index > fetch_tensors_->size()) {
|
||||
return errors::Internal("RetVal index out of bounds: ", index);
|
||||
}
|
||||
(*fetch_tensors_)[index] = val;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
DirectSession* const session_; // Not owned.
|
||||
ExecutorsAndKeys* const executors_and_keys_; // Not owned.
|
||||
const std::vector<Tensor>* const feed_tensors_; // Not owned.
|
||||
std::vector<Tensor>* const fetch_tensors_; // Not owned.
|
||||
};
|
||||
|
||||
::tensorflow::Status DirectSession::RunCallable(
|
||||
CallableHandle handle, const std::vector<Tensor>& feed_tensors,
|
||||
std::vector<Tensor>* fetch_tensors, RunMetadata* run_metadata) {
|
||||
TF_RETURN_IF_ERROR(CheckNotClosed());
|
||||
TF_RETURN_IF_ERROR(CheckGraphCreated("RunCallable()"));
|
||||
direct_session_runs->GetCell()->IncrementBy(1);
|
||||
|
||||
// Check if we already have an executor for these arguments.
|
||||
std::shared_ptr<ExecutorsAndKeys> executors_and_keys;
|
||||
const int64 step_id = step_id_counter_.fetch_add(1);
|
||||
|
||||
{
|
||||
tf_shared_lock l(callables_lock_);
|
||||
if (handle >= next_callable_handle_) {
|
||||
return errors::InvalidArgument("No such callable handle: ", handle);
|
||||
}
|
||||
executors_and_keys = callables_[handle].executors_and_keys;
|
||||
}
|
||||
|
||||
if (!executors_and_keys) {
|
||||
return errors::InvalidArgument(
|
||||
"Attempted to run callable after handle was released: ", handle);
|
||||
}
|
||||
|
||||
// NOTE(mrry): Debug options are not currently supported in the
|
||||
// callable interface.
|
||||
DebugOptions debug_options;
|
||||
RunStateArgs run_state_args(debug_options);
|
||||
|
||||
// Configure a call frame for the step, which we use to feed and
|
||||
// fetch values to and from the executors.
|
||||
if (feed_tensors.size() != executors_and_keys->input_types.size()) {
|
||||
return errors::InvalidArgument(
|
||||
"Expected ", executors_and_keys->input_types.size(),
|
||||
" feed tensors, but got ", feed_tensors.size());
|
||||
}
|
||||
if (fetch_tensors != nullptr) {
|
||||
fetch_tensors->resize(executors_and_keys->output_types.size());
|
||||
} else if (!executors_and_keys->output_types.empty()) {
|
||||
return errors::InvalidArgument(
|
||||
"`fetch_tensors` must be provided when the callable has one or more "
|
||||
"outputs.");
|
||||
}
|
||||
|
||||
// A specialized CallFrame implementation that takes advantage of the
|
||||
// optimized RunCallable interface.
|
||||
|
||||
RunCallableCallFrame call_frame(this, executors_and_keys.get(), &feed_tensors,
|
||||
fetch_tensors);
|
||||
|
||||
if (LogMemory::IsEnabled()) {
|
||||
LogMemory::RecordStep(step_id, run_state_args.handle);
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
RunInternal(step_id, executors_and_keys->callable_options.run_options(),
|
||||
&call_frame, executors_and_keys.get(), run_metadata));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
::tensorflow::Status DirectSession::ReleaseCallable(CallableHandle handle) {
|
||||
mutex_lock l(callables_lock_);
|
||||
if (handle >= next_callable_handle_) {
|
||||
return errors::InvalidArgument("No such callable handle: ", handle);
|
||||
}
|
||||
callables_.erase(handle);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
DirectSession::Callable::~Callable() {
|
||||
// We must delete the fields in this order, because the destructor
|
||||
// of `executors_and_keys` will call into an object owned by
|
||||
// `function_info` (in particular, when deleting a kernel, it relies
|
||||
// on the `FunctionLibraryRuntime` to know if the kernel is stateful
|
||||
// or not).
|
||||
executors_and_keys.reset();
|
||||
function_info.reset();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -107,6 +107,14 @@ class DirectSession : public Session {
|
||||
cost_model_manager_.ExportCostModels(cost_models);
|
||||
}
|
||||
|
||||
::tensorflow::Status MakeCallable(const CallableOptions& callable_options,
|
||||
CallableHandle* out_handle) override;
|
||||
::tensorflow::Status RunCallable(CallableHandle handle,
|
||||
const std::vector<Tensor>& feed_tensors,
|
||||
std::vector<Tensor>* fetch_tensors,
|
||||
RunMetadata* run_metadata) override;
|
||||
::tensorflow::Status ReleaseCallable(CallableHandle handle) override;
|
||||
|
||||
private:
|
||||
// We create one executor and its dependent library runtime for
|
||||
// every partition.
|
||||
@ -139,6 +147,8 @@ class DirectSession : public Session {
|
||||
|
||||
DataTypeVector input_types;
|
||||
DataTypeVector output_types;
|
||||
|
||||
CallableOptions callable_options;
|
||||
};
|
||||
|
||||
// A FunctionInfo object is created for every unique set of feeds/fetches.
|
||||
@ -206,6 +216,14 @@ class DirectSession : public Session {
|
||||
gtl::ArraySlice<string> target_nodes,
|
||||
ExecutorsAndKeys** executors_and_keys, RunStateArgs* run_state_args);
|
||||
|
||||
// Creates a set of executors to run the subgraph defined by
|
||||
// `callable_options`.
|
||||
::tensorflow::Status CreateExecutors(
|
||||
const CallableOptions& callable_options,
|
||||
std::unique_ptr<ExecutorsAndKeys>* out_executors_and_keys,
|
||||
std::unique_ptr<FunctionInfo>* out_func_info,
|
||||
RunStateArgs* run_state_args);
|
||||
|
||||
// Creates several graphs given the existing graph_def_ and the
|
||||
// input feeds and fetches, given 'devices'. The graphs share a common
|
||||
// function library 'flib_def'.
|
||||
@ -216,6 +234,11 @@ class DirectSession : public Session {
|
||||
RunStateArgs* run_state_args, DataTypeVector* input_types,
|
||||
DataTypeVector* output_types);
|
||||
|
||||
::tensorflow::Status RunInternal(int64 step_id, const RunOptions& run_options,
|
||||
CallFrameInterface* call_frame,
|
||||
ExecutorsAndKeys* executors_and_keys,
|
||||
RunMetadata* run_metadata);
|
||||
|
||||
::tensorflow::Status ExtendLocked(const GraphDef& graph)
|
||||
EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_);
|
||||
|
||||
@ -257,11 +280,18 @@ class DirectSession : public Session {
|
||||
return ::tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
::tensorflow::Status CheckGraphCreated(const char* method) {
|
||||
mutex_lock l(graph_def_lock_);
|
||||
if (!graph_created_) {
|
||||
return errors::InvalidArgument(
|
||||
"Session was not created with a graph before ", method, "!");
|
||||
}
|
||||
return ::tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
::tensorflow::Status CreateDebuggerState(
|
||||
const DebugOptions& debug_options, int64 session_run_index,
|
||||
int64 executor_step_index, const std::vector<string>& input_names,
|
||||
const std::vector<string>& output_names,
|
||||
const std::vector<string>& target_names,
|
||||
const CallableOptions& options, int64 global_step,
|
||||
int64 session_run_index, int64 executor_step_index,
|
||||
std::unique_ptr<DebuggerStateInterface>* debugger_state);
|
||||
|
||||
::tensorflow::Status DecorateAndPublishGraphForDebug(
|
||||
@ -303,6 +333,16 @@ class DirectSession : public Session {
|
||||
std::unordered_map<string, std::shared_ptr<ExecutorsAndKeys>> executors_
|
||||
GUARDED_BY(executor_lock_);
|
||||
|
||||
class RunCallableCallFrame;
|
||||
struct Callable {
|
||||
std::shared_ptr<ExecutorsAndKeys> executors_and_keys;
|
||||
std::shared_ptr<FunctionInfo> function_info;
|
||||
~Callable();
|
||||
};
|
||||
mutex callables_lock_;
|
||||
int64 next_callable_handle_ GUARDED_BY(callables_lock_) = 0;
|
||||
std::unordered_map<int64, Callable> callables_ GUARDED_BY(callables_lock_);
|
||||
|
||||
// Holds mappings from handle to partial run state.
|
||||
std::unordered_map<string, std::unique_ptr<RunState>> partial_runs_
|
||||
GUARDED_BY(executor_lock_);
|
||||
|
@ -49,6 +49,22 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
CallableOptions MakeCallableOptions(gtl::ArraySlice<string> feeds,
|
||||
gtl::ArraySlice<string> fetches,
|
||||
gtl::ArraySlice<string> targets) {
|
||||
CallableOptions ret;
|
||||
for (const string& feed : feeds) {
|
||||
ret.add_feed(feed);
|
||||
}
|
||||
for (const string& fetch : fetches) {
|
||||
ret.add_fetch(fetch);
|
||||
}
|
||||
for (const string& target : targets) {
|
||||
ret.add_target(target);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
std::unique_ptr<Session> CreateSession() {
|
||||
SessionOptions options;
|
||||
(*options.config.mutable_device_count())["CPU"] = 2;
|
||||
@ -111,6 +127,53 @@ TEST_F(DirectSessionMinusAXTest, RunSimpleNetwork) {
|
||||
EXPECT_FLOAT_EQ(5.0, mat(0, 0));
|
||||
}
|
||||
|
||||
TEST_F(DirectSessionMinusAXTest, RunSimpleNetwork_Callable) {
|
||||
Initialize({3, 2, -1, 0});
|
||||
auto session = CreateSession();
|
||||
ASSERT_TRUE(session != nullptr);
|
||||
TF_ASSERT_OK(session->Create(def_));
|
||||
std::vector<std::pair<string, Tensor>> inputs;
|
||||
|
||||
// Run the test twice to ensure that the Make/Run/Release cycle is hermetic.
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
// Request two targets: one fetch output and one non-fetched output.
|
||||
Session::CallableHandle handle;
|
||||
TF_ASSERT_OK(session->MakeCallable(
|
||||
MakeCallableOptions({}, {y_ + ":0"}, {y_neg_}), &handle));
|
||||
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
std::vector<Tensor> outputs;
|
||||
TF_ASSERT_OK(session->RunCallable(handle, {}, &outputs, nullptr));
|
||||
|
||||
ASSERT_EQ(1, outputs.size());
|
||||
// The first output should be initialized and have the correct
|
||||
// output.
|
||||
auto mat = outputs[0].matrix<float>();
|
||||
ASSERT_TRUE(outputs[0].IsInitialized());
|
||||
EXPECT_FLOAT_EQ(5.0, mat(0, 0));
|
||||
}
|
||||
|
||||
Status s = session->RunCallable(handle, {}, nullptr, nullptr);
|
||||
EXPECT_TRUE(errors::IsInvalidArgument(s));
|
||||
EXPECT_TRUE(StringPiece(s.error_message())
|
||||
.contains("`fetch_tensors` must be provided"));
|
||||
|
||||
TF_ASSERT_OK(session->ReleaseCallable(handle));
|
||||
|
||||
std::vector<Tensor> outputs;
|
||||
s = session->RunCallable(handle, {}, &outputs, nullptr);
|
||||
EXPECT_TRUE(errors::IsInvalidArgument(s));
|
||||
EXPECT_TRUE(
|
||||
StringPiece(s.error_message())
|
||||
.contains("Attempted to run callable after handle was released"));
|
||||
|
||||
s = session->RunCallable(handle + 1, {}, &outputs, nullptr);
|
||||
EXPECT_TRUE(errors::IsInvalidArgument(s));
|
||||
EXPECT_TRUE(
|
||||
StringPiece(s.error_message()).contains("No such callable handle"));
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(DirectSessionMinusAXTest, TestFeed) {
|
||||
Initialize({1, 2, 3, 4});
|
||||
auto session = CreateSession();
|
||||
@ -140,6 +203,39 @@ TEST_F(DirectSessionMinusAXTest, TestFeed) {
|
||||
EXPECT_FLOAT_EQ(39.0, mat(1, 0));
|
||||
}
|
||||
|
||||
TEST_F(DirectSessionMinusAXTest, TestFeed_Callable) {
|
||||
Initialize({1, 2, 3, 4});
|
||||
auto session = CreateSession();
|
||||
ASSERT_TRUE(session != nullptr);
|
||||
|
||||
TF_ASSERT_OK(session->Create(def_));
|
||||
|
||||
// Fill in the input and ask for the output
|
||||
//
|
||||
// Note that the input being fed is on the second device.
|
||||
CallableOptions callable_options;
|
||||
callable_options.add_feed(x_);
|
||||
callable_options.add_fetch(y_ + ":0");
|
||||
Session::CallableHandle handle;
|
||||
TF_ASSERT_OK(session->MakeCallable(MakeCallableOptions({x_}, {y_ + ":0"}, {}),
|
||||
&handle));
|
||||
Tensor t(DT_FLOAT, TensorShape({2, 1}));
|
||||
t.matrix<float>()(0, 0) = 5;
|
||||
t.matrix<float>()(1, 0) = 6;
|
||||
std::vector<Tensor> inputs = {t};
|
||||
std::vector<Tensor> outputs;
|
||||
|
||||
// Run the callable
|
||||
TF_ASSERT_OK(session->RunCallable(handle, inputs, &outputs, nullptr));
|
||||
|
||||
ASSERT_EQ(1, outputs.size());
|
||||
auto mat = outputs[0].matrix<float>();
|
||||
|
||||
// Expect outputs to be; 1*5 + 2*6, 3*5 + 4*6
|
||||
EXPECT_FLOAT_EQ(17.0, mat(0, 0));
|
||||
EXPECT_FLOAT_EQ(39.0, mat(1, 0));
|
||||
}
|
||||
|
||||
TEST_F(DirectSessionMinusAXTest, TestConcurrency) {
|
||||
Initialize({1, 2, 3, 4});
|
||||
auto session = CreateSession();
|
||||
@ -172,6 +268,39 @@ TEST_F(DirectSessionMinusAXTest, TestConcurrency) {
|
||||
delete tp;
|
||||
}
|
||||
|
||||
TEST_F(DirectSessionMinusAXTest, TestConcurrency_Callable) {
|
||||
Initialize({1, 2, 3, 4});
|
||||
auto session = CreateSession();
|
||||
ASSERT_TRUE(session != nullptr);
|
||||
TF_ASSERT_OK(session->Create(def_));
|
||||
|
||||
// Fill in the input and ask for the output
|
||||
thread::ThreadPool* tp = new thread::ThreadPool(Env::Default(), "test", 4);
|
||||
|
||||
Session::CallableHandle handle;
|
||||
TF_ASSERT_OK(
|
||||
session->MakeCallable(MakeCallableOptions({}, {y_ + ":0"}, {}), &handle));
|
||||
|
||||
// Run the callable 1000 times in 4 different threads concurrently.
|
||||
auto fn = [&session, handle]() {
|
||||
for (int i = 0; i < 1000; ++i) {
|
||||
std::vector<Tensor> outputs;
|
||||
// Run the graph
|
||||
TF_ASSERT_OK(session->RunCallable(handle, {}, &outputs, nullptr));
|
||||
ASSERT_EQ(1, outputs.size());
|
||||
auto mat = outputs[0].matrix<float>();
|
||||
EXPECT_FLOAT_EQ(3.0, mat(0, 0));
|
||||
}
|
||||
};
|
||||
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
tp->Schedule(fn);
|
||||
}
|
||||
|
||||
// Wait for the functions to finish.
|
||||
delete tp;
|
||||
}
|
||||
|
||||
TEST_F(DirectSessionMinusAXTest, TestPerSessionThreads) {
|
||||
Initialize({1, 2, 3, 4});
|
||||
|
||||
@ -297,6 +426,38 @@ TEST_F(DirectSessionMinusAXTest, RunSimpleNetworkWithOpts) {
|
||||
EXPECT_EQ(run_metadata.step_stats().dev_stats_size(), 2);
|
||||
}
|
||||
|
||||
TEST_F(DirectSessionMinusAXTest, RunSimpleNetworkWithOpts_Callable) {
|
||||
Initialize({3, 2, -1, 0});
|
||||
auto session = CreateSession();
|
||||
ASSERT_TRUE(session != nullptr);
|
||||
TF_ASSERT_OK(session->Create(def_));
|
||||
|
||||
// Request two targets: one fetch output and one non-fetched output.
|
||||
Session::CallableHandle handle;
|
||||
CallableOptions callable_options =
|
||||
MakeCallableOptions({}, {y_ + ":0"}, {y_neg_});
|
||||
callable_options.mutable_run_options()->set_trace_level(
|
||||
RunOptions::FULL_TRACE);
|
||||
TF_ASSERT_OK(session->MakeCallable(callable_options, &handle));
|
||||
|
||||
RunMetadata run_metadata;
|
||||
EXPECT_EQ(run_metadata.step_stats().dev_stats_size(), 0);
|
||||
|
||||
std::vector<Tensor> outputs;
|
||||
TF_ASSERT_OK(session->RunCallable(handle, {}, &outputs, &run_metadata));
|
||||
|
||||
ASSERT_EQ(1, outputs.size());
|
||||
// The first output should be initialized and have the correct
|
||||
// output.
|
||||
auto mat = outputs[0].matrix<float>();
|
||||
ASSERT_TRUE(outputs[0].IsInitialized());
|
||||
EXPECT_FLOAT_EQ(5.0, mat(0, 0));
|
||||
|
||||
// Checks RunMetadata is well-formed
|
||||
ASSERT_TRUE(run_metadata.has_step_stats());
|
||||
EXPECT_EQ(run_metadata.step_stats().dev_stats_size(), 2);
|
||||
}
|
||||
|
||||
TEST(DirectSessionTest, KeepsStateAcrossRunsOfSession) {
|
||||
GraphDef def;
|
||||
Graph g(OpRegistry::Global());
|
||||
@ -409,6 +570,89 @@ TEST(DirectSessionTest, MultipleFeedTest) {
|
||||
EXPECT_TRUE(StringPiece(s.error_message()).contains("fed more than once"));
|
||||
}
|
||||
|
||||
TEST(DirectSessionTest, MultipleFeedTest_Callable) {
|
||||
GraphDef def;
|
||||
Graph g(OpRegistry::Global());
|
||||
|
||||
Tensor first_value(DT_FLOAT, TensorShape({}));
|
||||
first_value.scalar<float>()() = 1.0;
|
||||
Node* first_const = test::graph::Constant(&g, first_value);
|
||||
Node* first_identity = test::graph::Identity(&g, first_const);
|
||||
|
||||
Tensor second_value(DT_FLOAT, TensorShape({}));
|
||||
second_value.scalar<float>()() = 2.0;
|
||||
Node* second_const = test::graph::Constant(&g, second_value);
|
||||
Node* second_identity = test::graph::Identity(&g, second_const);
|
||||
|
||||
test::graph::ToGraphDef(&g, &def);
|
||||
|
||||
auto session = CreateSession();
|
||||
ASSERT_TRUE(session != nullptr);
|
||||
TF_ASSERT_OK(session->Create(def));
|
||||
|
||||
Session::CallableHandle handle;
|
||||
std::vector<Tensor> outputs;
|
||||
|
||||
// Fetch without feeding.
|
||||
TF_ASSERT_OK(session->MakeCallable(
|
||||
MakeCallableOptions(
|
||||
{}, {first_identity->name() + ":0", second_identity->name() + ":0"},
|
||||
{}),
|
||||
&handle));
|
||||
TF_ASSERT_OK(session->RunCallable(handle, {}, &outputs, nullptr));
|
||||
ASSERT_EQ(2, outputs.size());
|
||||
ASSERT_EQ(1.0, outputs[0].flat<float>()(0));
|
||||
ASSERT_EQ(2.0, outputs[1].flat<float>()(0));
|
||||
|
||||
TF_ASSERT_OK(session->MakeCallable(
|
||||
MakeCallableOptions(
|
||||
{}, {second_identity->name() + ":0", first_identity->name() + ":0"},
|
||||
{}),
|
||||
&handle));
|
||||
TF_ASSERT_OK(session->RunCallable(handle, {}, &outputs, nullptr));
|
||||
ASSERT_EQ(2, outputs.size());
|
||||
ASSERT_EQ(2.0, outputs[0].flat<float>()(0));
|
||||
ASSERT_EQ(1.0, outputs[1].flat<float>()(0));
|
||||
|
||||
Tensor value_11(DT_FLOAT, TensorShape({}));
|
||||
value_11.scalar<float>()() = 11.0;
|
||||
Tensor value_22(DT_FLOAT, TensorShape({}));
|
||||
value_22.scalar<float>()() = 22.0;
|
||||
|
||||
// Feed [first_const, second_const]
|
||||
TF_ASSERT_OK(session->MakeCallable(
|
||||
MakeCallableOptions(
|
||||
{first_const->name(), second_const->name()},
|
||||
{first_identity->name() + ":0", second_identity->name() + ":0"}, {}),
|
||||
&handle));
|
||||
TF_ASSERT_OK(
|
||||
session->RunCallable(handle, {value_11, value_22}, &outputs, nullptr));
|
||||
ASSERT_EQ(2, outputs.size());
|
||||
ASSERT_EQ(11.0, outputs[0].flat<float>()(0));
|
||||
ASSERT_EQ(22.0, outputs[1].flat<float>()(0));
|
||||
|
||||
// Feed [second_const, first_const]
|
||||
TF_ASSERT_OK(session->MakeCallable(
|
||||
MakeCallableOptions(
|
||||
{second_const->name(), first_const->name()},
|
||||
{first_identity->name() + ":0", second_identity->name() + ":0"}, {}),
|
||||
&handle));
|
||||
TF_ASSERT_OK(
|
||||
session->RunCallable(handle, {value_22, value_11}, &outputs, nullptr));
|
||||
ASSERT_EQ(2, outputs.size());
|
||||
ASSERT_EQ(11.0, outputs[0].flat<float>()(0));
|
||||
ASSERT_EQ(22.0, outputs[1].flat<float>()(0));
|
||||
|
||||
// Feed [first_const, first_const]
|
||||
Status s = session->MakeCallable(
|
||||
MakeCallableOptions(
|
||||
{first_const->name(), first_const->name()},
|
||||
{first_identity->name() + ":0", second_identity->name() + ":0"}, {}),
|
||||
&handle);
|
||||
EXPECT_TRUE(errors::IsInvalidArgument(s));
|
||||
EXPECT_TRUE(StringPiece(s.error_message()).contains("fed more than once"));
|
||||
}
|
||||
|
||||
TEST(DirectSessionTest, FetchMultipleTimes) {
|
||||
Graph g(OpRegistry::Global());
|
||||
Tensor seven_tensor(DT_INT32, TensorShape());
|
||||
@ -695,6 +939,59 @@ TEST(DirectSessionTest, RunHandleTest) {
|
||||
ASSERT_TRUE(s.ok());
|
||||
}
|
||||
|
||||
TEST(DirectSessionTest, RunHandleTest_Callable) {
|
||||
GraphDef def;
|
||||
Graph g(OpRegistry::Global());
|
||||
|
||||
Tensor value0(DT_FLOAT, TensorShape({}));
|
||||
value0.scalar<float>()() = 1.0;
|
||||
Node* const0 = test::graph::Constant(&g, value0);
|
||||
Node* identity0 = test::graph::Identity(&g, const0);
|
||||
|
||||
Tensor value1(DT_FLOAT, TensorShape({}));
|
||||
value1.scalar<float>()() = 2.0;
|
||||
Node* const1 = test::graph::Constant(&g, value1);
|
||||
Node* node3 = test::graph::Add(&g, identity0, const1);
|
||||
Node* node4 = test::graph::Unary(&g, "GetSessionHandleV2", node3);
|
||||
|
||||
Tensor value2(DT_STRING, TensorShape({}));
|
||||
Node* const2 = test::graph::Constant(&g, value2);
|
||||
Node* node5 = test::graph::GetSessionTensor(&g, const2);
|
||||
Node* node6 = test::graph::Add(&g, node5, const1);
|
||||
|
||||
Node* node7 = test::graph::Unary(&g, "DeleteSessionTensor", const2);
|
||||
|
||||
test::graph::ToGraphDef(&g, &def);
|
||||
|
||||
auto session = CreateSession();
|
||||
ASSERT_TRUE(session != nullptr);
|
||||
TF_ASSERT_OK(session->Create(def));
|
||||
|
||||
// First run call: Create a handle.
|
||||
std::vector<Tensor> outputs;
|
||||
Status s = session->Run({}, {node4->name() + ":0"}, {}, &outputs);
|
||||
ASSERT_TRUE(s.ok());
|
||||
ASSERT_EQ(1, outputs.size());
|
||||
|
||||
ResourceHandle resource_handle = outputs[0].scalar<ResourceHandle>()();
|
||||
Tensor string_handle(DT_STRING, {});
|
||||
string_handle.flat<string>().setConstant(resource_handle.name());
|
||||
|
||||
// Second run call: Use a handle.
|
||||
std::vector<Tensor> outputs1;
|
||||
s = session->Run({{const2->name(), string_handle}}, {node6->name() + ":0"},
|
||||
{}, &outputs1);
|
||||
ASSERT_TRUE(s.ok());
|
||||
ASSERT_EQ(1, outputs1.size());
|
||||
ASSERT_EQ(5.0, outputs1[0].flat<float>()(0));
|
||||
|
||||
// Third run call: Delete a handle.
|
||||
std::vector<Tensor> outputs2;
|
||||
s = session->Run({{const2->name(), string_handle}}, {}, {node7->name()},
|
||||
&outputs2);
|
||||
ASSERT_TRUE(s.ok());
|
||||
}
|
||||
|
||||
TEST(DirectSessionTest, CreateGraphFailsWhenAssigningAFedVar) {
|
||||
Graph graph(OpRegistry::Global());
|
||||
|
||||
@ -1109,6 +1406,11 @@ TEST(DirectSessionTest, TestDirectSessionRunClose) {
|
||||
EXPECT_EQ(t.scalar<float>()(), outputs[0].scalar<float>()());
|
||||
outputs.clear();
|
||||
|
||||
// Make a callable handle before closing the session.
|
||||
Session::CallableHandle handle;
|
||||
TF_ASSERT_OK(session->MakeCallable(
|
||||
MakeCallableOptions({}, {}, {var_assign->name()}), &handle));
|
||||
|
||||
// Close the session.
|
||||
TF_ASSERT_OK(session->Close());
|
||||
|
||||
@ -1116,6 +1418,10 @@ TEST(DirectSessionTest, TestDirectSessionRunClose) {
|
||||
Status s = session->Run({} /* inputs */, {},
|
||||
{var_assign->name()} /* target_nodes */, nullptr);
|
||||
EXPECT_EQ("Cancelled: Session has been closed.", s.ToString());
|
||||
|
||||
// Run the read as a callable to verify that we get the same error.
|
||||
s = session->RunCallable(handle, {}, {}, nullptr);
|
||||
EXPECT_EQ("Cancelled: Session has been closed.", s.ToString());
|
||||
}
|
||||
|
||||
TEST(DirectSessionTest, TestDirectSessionPRunClose) {
|
||||
@ -1217,7 +1523,8 @@ TEST(DirectSessionTest, LocalDeviceManager) {
|
||||
|
||||
// A simple benchmark for the overhead of `DirectSession::Run()` calls
|
||||
// with varying numbers of feeds/fetches.
|
||||
void FeedFetchBenchmarkHelper(int iters, int num_feeds) {
|
||||
void FeedFetchBenchmarkHelper(int iters, int num_feeds,
|
||||
bool use_make_callable) {
|
||||
testing::StopTiming();
|
||||
|
||||
Tensor value(DT_FLOAT, TensorShape());
|
||||
@ -1253,29 +1560,55 @@ void FeedFetchBenchmarkHelper(int iters, int num_feeds) {
|
||||
SessionOptions opts;
|
||||
std::unique_ptr<Session> session(NewSession(opts));
|
||||
TF_CHECK_OK(session->Create(gd));
|
||||
{
|
||||
// NOTE(mrry): Ignore the first run, which will incur the graph
|
||||
// partitioning/pruning overhead and skew the results.
|
||||
//
|
||||
// Note that we should also optimize and monitor the overhead on
|
||||
// the first run, which will impact application startup times, but
|
||||
// that is not the object of study in this benchmark.
|
||||
std::vector<Tensor> output_values;
|
||||
TF_CHECK_OK(session->Run(inputs, outputs, {}, &output_values));
|
||||
if (use_make_callable) {
|
||||
Session::CallableHandle handle;
|
||||
CallableOptions callable_options;
|
||||
std::vector<Tensor> input_tensors;
|
||||
for (const auto& input : inputs) {
|
||||
callable_options.add_feed(input.first);
|
||||
input_tensors.push_back(input.second);
|
||||
}
|
||||
for (const string& output : outputs) {
|
||||
callable_options.add_fetch(output);
|
||||
}
|
||||
TF_CHECK_OK(session->MakeCallable(callable_options, &handle));
|
||||
|
||||
testing::StartTiming();
|
||||
for (int i = 0; i < iters; ++i) {
|
||||
std::vector<Tensor> output_values;
|
||||
TF_CHECK_OK(
|
||||
session->RunCallable(handle, input_tensors, &output_values, nullptr));
|
||||
}
|
||||
testing::StopTiming();
|
||||
} else {
|
||||
{
|
||||
// NOTE(mrry): Ignore the first run, which will incur the graph
|
||||
// partitioning/pruning overhead and skew the results.
|
||||
//
|
||||
// Note that we should also optimize and monitor the overhead on
|
||||
// the first run, which will impact application startup times, but
|
||||
// that is not the object of study in this benchmark.
|
||||
std::vector<Tensor> output_values;
|
||||
TF_CHECK_OK(session->Run(inputs, outputs, {}, &output_values));
|
||||
}
|
||||
testing::StartTiming();
|
||||
for (int i = 0; i < iters; ++i) {
|
||||
std::vector<Tensor> output_values;
|
||||
TF_CHECK_OK(session->Run(inputs, outputs, {}, &output_values));
|
||||
}
|
||||
testing::StopTiming();
|
||||
}
|
||||
testing::StartTiming();
|
||||
for (int i = 0; i < iters; ++i) {
|
||||
std::vector<Tensor> output_values;
|
||||
TF_CHECK_OK(session->Run(inputs, outputs, {}, &output_values));
|
||||
}
|
||||
testing::StopTiming();
|
||||
}
|
||||
|
||||
void BM_FeedFetch(int iters, int num_feeds) {
|
||||
FeedFetchBenchmarkHelper(iters, num_feeds);
|
||||
FeedFetchBenchmarkHelper(iters, num_feeds, /* use_make_callable */ false);
|
||||
}
|
||||
void BM_FeedFetchCallable(int iters, int num_feeds) {
|
||||
FeedFetchBenchmarkHelper(iters, num_feeds, /* use_make_callable */ true);
|
||||
}
|
||||
|
||||
BENCHMARK(BM_FeedFetch)->Arg(1)->Arg(2)->Arg(5)->Arg(10);
|
||||
BENCHMARK(BM_FeedFetchCallable)->Arg(1)->Arg(2)->Arg(5)->Arg(10);
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -74,6 +74,12 @@ class TensorStore {
|
||||
Status SaveTensors(const std::vector<string>& output_names,
|
||||
SessionState* session_state);
|
||||
|
||||
// Returns true if no tensors have been added to this store.
|
||||
bool empty() {
|
||||
mutex_lock l(lock_);
|
||||
return tensors_.empty();
|
||||
}
|
||||
|
||||
private:
|
||||
mutex lock_;
|
||||
|
||||
|
@ -410,3 +410,26 @@ message RunMetadata {
|
||||
// Graphs of the partitions executed by executors.
|
||||
repeated GraphDef partition_graphs = 3;
|
||||
}
|
||||
|
||||
// Defines a subgraph in another `GraphDef` as a set of feed points and nodes
|
||||
// to be fetched or executed.
|
||||
//
|
||||
// Compare with the arguments to `Session::Run()`.
|
||||
message CallableOptions {
|
||||
// Tensors to be fed in the callable. Each feed is the name of a tensor.
|
||||
repeated string feed = 1;
|
||||
|
||||
// Fetches. A list of tensor names. The caller of the callable expects a
|
||||
// tensor to be returned for each fetch[i] (see RunStepResponse.tensor). The
|
||||
// order of specified fetches does not change the execution order.
|
||||
repeated string fetch = 2;
|
||||
|
||||
// Target Nodes. A list of node names. The named nodes will be run by the
|
||||
// callable but their outputs will not be returned.
|
||||
repeated string target = 3;
|
||||
|
||||
// Options that will be applied to each run.
|
||||
RunOptions run_options = 4;
|
||||
|
||||
// Next: 5
|
||||
}
|
||||
|
@ -195,6 +195,41 @@ class Session {
|
||||
return errors::Unimplemented(
|
||||
"LocalDeviceManager is not supported for this session.");
|
||||
}
|
||||
|
||||
/// \brief A handle to a subgraph, created with `Session::MakeCallable()`.
|
||||
typedef int64 CallableHandle;
|
||||
|
||||
/// \brief Creates a `handle` for invoking the subgraph defined by
|
||||
/// `callable_options`.
|
||||
/// NOTE: This API is still experimental and may change.
|
||||
virtual Status MakeCallable(const CallableOptions& callable_options,
|
||||
CallableHandle* out_handle) {
|
||||
return errors::Unimplemented(
|
||||
"MakeCallable is not supported for this session.");
|
||||
}
|
||||
|
||||
/// \brief Invokes the subgraph named by `handle` with the given options and
|
||||
/// input tensors.
|
||||
///
|
||||
/// The order of tensors in `feed_tensors` must and `fetch_tensors` will
|
||||
/// match the order of names in `CallableOptions::feed()` and
|
||||
/// `CallableOptions::fetch()` when this subgraph was created.
|
||||
/// NOTE: This API is still experimental and may change.
|
||||
virtual Status RunCallable(CallableHandle handle,
|
||||
const std::vector<Tensor>& feed_tensors,
|
||||
std::vector<Tensor>* fetch_tensors,
|
||||
RunMetadata* run_metadata) {
|
||||
return errors::Unimplemented(
|
||||
"RunCallable is not supported for this session.");
|
||||
}
|
||||
|
||||
/// \brief Releases resources associated with the given `handle` in this
|
||||
/// session.
|
||||
/// NOTE: This API is still experimental and may change.
|
||||
virtual Status ReleaseCallable(CallableHandle handle) {
|
||||
return errors::Unimplemented(
|
||||
"ReleaseCallable is not supported for this session.");
|
||||
}
|
||||
};
|
||||
|
||||
/// \brief Create a new session with the given options.
|
||||
|
Loading…
x
Reference in New Issue
Block a user