Add experimental Session::MakeCallable() API and implement it for DirectSession.

The intent of this new API matches the Python `tf.Session.make_callable()`
method: it splits the two roles of the `Session::Run()` method into separate
methods:

1. `Session::MakeCallable()` takes information about a subgraph (such as the
   names of nodes to feed and fetch), and prunes and optimizes that graph,
   returning a simple handle.
2. `Session::RunCallable()` takes that handle, plus any values to be fed,
   and executes the graph, returning whatever outputs are produced.

This split moves string processing off the critical path of running a
step. We also add a new method `Session::ReleaseCallable()` that makes
it possible to free the resources associated with a cached subgraph,
and could be useful for seldom-executed graphs such as initializers.

PiperOrigin-RevId: 188566635
This commit is contained in:
Derek Murray 2018-03-09 18:12:02 -08:00 committed by TensorFlower Gardener
parent 05aa4e58c8
commit 2426308fa5
6 changed files with 880 additions and 239 deletions

View File

@ -318,6 +318,7 @@ DirectSession::~DirectSession() {
for (auto& it : executors_) {
it.second.reset();
}
callables_.clear();
for (auto d : device_mgr_->ListDevices()) {
d->op_segment()->RemoveHold(session_handle_);
}
@ -409,16 +410,21 @@ Status DirectSession::Run(const NamedTensorList& inputs,
}
Status DirectSession::CreateDebuggerState(
const DebugOptions& debug_options, int64 session_run_index,
int64 executor_step_index, const std::vector<string>& input_names,
const std::vector<string>& output_names,
const std::vector<string>& target_names,
const CallableOptions& callable_options, int64 global_step,
int64 session_run_index, int64 executor_step_index,
std::unique_ptr<DebuggerStateInterface>* debugger_state) {
TF_RETURN_IF_ERROR(
DebuggerStateRegistry::CreateState(debug_options, debugger_state));
TF_RETURN_IF_ERROR(DebuggerStateRegistry::CreateState(
callable_options.run_options().debug_options(), debugger_state));
std::vector<string> input_names(callable_options.feed().begin(),
callable_options.feed().end());
std::vector<string> output_names(callable_options.fetch().begin(),
callable_options.fetch().end());
std::vector<string> target_names(callable_options.target().begin(),
callable_options.target().end());
TF_RETURN_IF_ERROR(debugger_state->get()->PublishDebugMetadata(
debug_options.global_step(), session_run_index, executor_step_index,
input_names, output_names, target_names));
global_step, session_run_index, executor_step_index, input_names,
output_names, target_names));
return Status::OK();
}
@ -433,84 +439,23 @@ Status DirectSession::DecorateAndPublishGraphForDebug(
return Status::OK();
}
Status DirectSession::Run(const RunOptions& run_options,
const NamedTensorList& inputs,
const std::vector<string>& output_names,
const std::vector<string>& target_nodes,
std::vector<Tensor>* outputs,
RunMetadata* run_metadata) {
TF_RETURN_IF_ERROR(CheckNotClosed());
direct_session_runs->GetCell()->IncrementBy(1);
{
mutex_lock l(graph_def_lock_);
if (!graph_created_) {
return errors::InvalidArgument(
"Session was not created with a graph before Run()!");
}
}
// Extract the inputs names for this run of the session.
std::vector<string> input_tensor_names;
input_tensor_names.reserve(inputs.size());
for (const auto& it : inputs) {
input_tensor_names.push_back(it.first);
}
if (run_options.inter_op_thread_pool() < 0 ||
run_options.inter_op_thread_pool() >= thread_pools_.size()) {
return errors::InvalidArgument("Invalid inter_op_thread_pool: ",
run_options.inter_op_thread_pool());
}
thread::ThreadPool* pool =
thread_pools_[run_options.inter_op_thread_pool()].first;
// Check if we already have an executor for these arguments.
ExecutorsAndKeys* executors_and_keys;
RunStateArgs run_state_args(run_options.debug_options());
Executor::Args args;
args.step_id = step_id_counter_.fetch_add(1);
TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_tensor_names, output_names,
target_nodes, &executors_and_keys,
&run_state_args));
Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options,
CallFrameInterface* call_frame,
ExecutorsAndKeys* executors_and_keys,
RunMetadata* run_metadata) {
const int64 executor_step_count = executors_and_keys->step_count.fetch_add(1);
std::unique_ptr<DebuggerStateInterface> debugger_state;
if (!run_options.debug_options().debug_tensor_watch_opts().empty()) {
TF_RETURN_IF_ERROR(CreateDebuggerState(
run_options.debug_options(), args.step_id, executor_step_count,
input_tensor_names, output_names, target_nodes, &debugger_state));
}
// Configure a call frame for the step, which we use to feed and
// fetch values to and from the executors.
FunctionCallFrame call_frame(executors_and_keys->input_types,
executors_and_keys->output_types);
gtl::InlinedVector<Tensor, 4> feed_args(inputs.size());
for (const auto& it : inputs) {
if (it.second.dtype() == DT_RESOURCE) {
Tensor tensor_from_handle;
TF_RETURN_IF_ERROR(
ResourceHandleToInputTensor(it.second, &tensor_from_handle));
feed_args[executors_and_keys->input_name_to_index[it.first]] =
tensor_from_handle;
} else {
feed_args[executors_and_keys->input_name_to_index[it.first]] = it.second;
}
}
const Status s = call_frame.SetArgs(feed_args);
if (errors::IsInternal(s)) {
return errors::InvalidArgument(s.error_message());
} else if (!s.ok()) {
return s;
TF_RETURN_IF_ERROR(
CreateDebuggerState(executors_and_keys->callable_options,
run_options.debug_options().global_step(), step_id,
executor_step_count, &debugger_state));
}
// Create a run state and start execution.
RunState run_state(args.step_id, &devices_);
RunState run_state(step_id, &devices_);
run_state.rendez = new IntraProcessRendezvous(device_mgr_.get());
CancellationManager step_cancellation_manager;
args.call_frame = &call_frame;
// Start parallel Executors.
const size_t num_executors = executors_and_keys->items.size();
@ -523,15 +468,15 @@ Status DirectSession::Run(const RunOptions& run_options,
run_state.executors_done.Notify();
});
Executor::Args args;
args.step_id = step_id;
args.call_frame = call_frame;
args.rendezvous = run_state.rendez;
CancellationManager step_cancellation_manager;
args.cancellation_manager = &step_cancellation_manager;
args.session_state = &session_state_;
args.tensor_store = &run_state.tensor_store;
args.step_container = &run_state.step_container;
if (LogMemory::IsEnabled()) {
LogMemory::RecordStep(args.step_id, run_state_args.handle);
}
args.sync_on_finish = sync_on_finish_;
const bool do_trace = (run_options.trace_level() > RunOptions::NO_TRACE);
@ -569,6 +514,14 @@ Status DirectSession::Run(const RunOptions& run_options,
}
}
if (run_options.inter_op_thread_pool() < 0 ||
run_options.inter_op_thread_pool() >= thread_pools_.size()) {
run_state.executors_done.Notify();
delete barrier;
return errors::InvalidArgument("Invalid inter_op_thread_pool: ",
run_options.inter_op_thread_pool());
}
// Register this step with session's cancellation manager, so that
// `Session::Close()` will cancel the step.
const CancellationToken cancellation_token =
@ -586,6 +539,9 @@ Status DirectSession::Run(const RunOptions& run_options,
return errors::Cancelled("Run call was cancelled");
}
thread::ThreadPool* pool =
thread_pools_[run_options.inter_op_thread_pool()].first;
Executor::Args::Runner default_runner = [this,
pool](Executor::Args::Closure c) {
SchedClosure(pool, std::move(c));
@ -628,6 +584,111 @@ Status DirectSession::Run(const RunOptions& run_options,
TF_RETURN_IF_ERROR(run_state.status);
}
// Save the output tensors of this run we choose to keep.
if (!run_state.tensor_store.empty()) {
TF_RETURN_IF_ERROR(run_state.tensor_store.SaveTensors(
{executors_and_keys->callable_options.fetch().begin(),
executors_and_keys->callable_options.fetch().end()},
&session_state_));
}
if (args.stats_collector) {
args.stats_collector->Finalize();
}
// Build and return the cost model as instructed.
if (update_cost_model) {
// Build the cost model
std::unordered_map<string, const Graph*> device_to_graph;
for (const PerPartitionExecutorsAndLib& partition :
executors_and_keys->items) {
const Graph* graph = partition.graph;
const string device = partition.flib->device()->name();
device_to_graph[device] = graph;
}
mutex_lock l(executor_lock_);
args.stats_collector->BuildCostModel(&cost_model_manager_, device_to_graph);
// annotate stats onto cost graph.
CostGraphDef* cost_graph = run_metadata->mutable_cost_graph();
for (const auto& item : executors_and_keys->items) {
TF_RETURN_IF_ERROR(
cost_model_manager_.AddToCostGraphDef(item.graph, cost_graph));
}
}
// If requested via RunOptions, output the partition graphs.
if (run_options.output_partition_graphs()) {
protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs =
run_metadata->mutable_partition_graphs();
for (const PerPartitionExecutorsAndLib& exec_and_lib :
executors_and_keys->items) {
GraphDef* partition_graph_def = partition_graph_defs->Add();
exec_and_lib.graph->ToGraphDef(partition_graph_def);
}
}
return Status::OK();
}
Status DirectSession::Run(const RunOptions& run_options,
const NamedTensorList& inputs,
const std::vector<string>& output_names,
const std::vector<string>& target_nodes,
std::vector<Tensor>* outputs,
RunMetadata* run_metadata) {
TF_RETURN_IF_ERROR(CheckNotClosed());
TF_RETURN_IF_ERROR(CheckGraphCreated("Run()"));
direct_session_runs->GetCell()->IncrementBy(1);
// Extract the inputs names for this run of the session.
std::vector<string> input_tensor_names;
input_tensor_names.reserve(inputs.size());
for (const auto& it : inputs) {
input_tensor_names.push_back(it.first);
}
// Check if we already have an executor for these arguments.
ExecutorsAndKeys* executors_and_keys;
RunStateArgs run_state_args(run_options.debug_options());
TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_tensor_names, output_names,
target_nodes, &executors_and_keys,
&run_state_args));
// Configure a call frame for the step, which we use to feed and
// fetch values to and from the executors.
FunctionCallFrame call_frame(executors_and_keys->input_types,
executors_and_keys->output_types);
gtl::InlinedVector<Tensor, 4> feed_args(inputs.size());
for (const auto& it : inputs) {
if (it.second.dtype() == DT_RESOURCE) {
Tensor tensor_from_handle;
TF_RETURN_IF_ERROR(
ResourceHandleToInputTensor(it.second, &tensor_from_handle));
feed_args[executors_and_keys->input_name_to_index[it.first]] =
tensor_from_handle;
} else {
feed_args[executors_and_keys->input_name_to_index[it.first]] = it.second;
}
}
const Status s = call_frame.SetArgs(feed_args);
if (errors::IsInternal(s)) {
return errors::InvalidArgument(s.error_message());
} else if (!s.ok()) {
return s;
}
const int64 step_id = step_id_counter_.fetch_add(1);
if (LogMemory::IsEnabled()) {
LogMemory::RecordStep(step_id, run_state_args.handle);
}
TF_RETURN_IF_ERROR(RunInternal(step_id, run_options, &call_frame,
executors_and_keys, run_metadata));
// Receive outputs.
if (outputs) {
std::vector<Tensor> sorted_outputs;
@ -667,45 +728,6 @@ Status DirectSession::Run(const RunOptions& run_options,
}
}
// Save the output tensors of this run we choose to keep.
TF_RETURN_IF_ERROR(
run_state.tensor_store.SaveTensors(output_names, &session_state_));
if (args.stats_collector) {
args.stats_collector->Finalize();
}
// Build and return the cost model as instructed.
mutex_lock l(executor_lock_);
if (update_cost_model) {
// Build the cost model
std::unordered_map<string, const Graph*> device_to_graph;
for (const PerPartitionExecutorsAndLib& partition :
executors_and_keys->items) {
const Graph* graph = partition.graph;
const string device = partition.flib->device()->name();
device_to_graph[device] = graph;
}
args.stats_collector->BuildCostModel(&cost_model_manager_, device_to_graph);
// annotate stats onto cost graph.
CostGraphDef* cost_graph = run_metadata->mutable_cost_graph();
for (const auto& item : executors_and_keys->items) {
TF_RETURN_IF_ERROR(
cost_model_manager_.AddToCostGraphDef(item.graph, cost_graph));
}
}
// If requested via RunOptions, output the partition graphs.
if (run_options.output_partition_graphs()) {
protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs =
run_metadata->mutable_partition_graphs();
for (const PerPartitionExecutorsAndLib& exec_and_lib :
executors_and_keys->items) {
GraphDef* partition_graph_def = partition_graph_defs->Add();
exec_and_lib.graph->ToGraphDef(partition_graph_def);
}
}
return Status::OK();
}
@ -714,13 +736,7 @@ Status DirectSession::PRunSetup(const std::vector<string>& input_names,
const std::vector<string>& target_nodes,
string* handle) {
TF_RETURN_IF_ERROR(CheckNotClosed());
{
mutex_lock l(graph_def_lock_);
if (!graph_created_) {
return errors::InvalidArgument(
"Session was not created with a graph before PRunSetup()!");
}
}
TF_RETURN_IF_ERROR(CheckGraphCreated("PRunSetup()"));
// RunOptions is not available in PRunSetup, so use thread pool 0.
thread::ThreadPool* pool = thread_pools_[0].first;
@ -1061,92 +1077,31 @@ Status DirectSession::CheckFetch(const NamedTensorList& feeds,
return Status::OK();
}
Status DirectSession::GetOrCreateExecutors(
gtl::ArraySlice<string> inputs, gtl::ArraySlice<string> outputs,
gtl::ArraySlice<string> target_nodes, ExecutorsAndKeys** executors_and_keys,
Status DirectSession::CreateExecutors(
const CallableOptions& callable_options,
std::unique_ptr<ExecutorsAndKeys>* out_executors_and_keys,
std::unique_ptr<FunctionInfo>* out_func_info,
RunStateArgs* run_state_args) {
int64 handle_name_counter_value = -1;
if (LogMemory::IsEnabled() || run_state_args->is_partial_run) {
handle_name_counter_value = handle_name_counter_.fetch_add(1);
}
string debug_tensor_watches_summary;
if (!run_state_args->debug_options.debug_tensor_watch_opts().empty()) {
debug_tensor_watches_summary = SummarizeDebugTensorWatches(
run_state_args->debug_options.debug_tensor_watch_opts());
}
// Fast lookup path, no sorting.
const string key = strings::StrCat(
str_util::Join(inputs, ","), "->", str_util::Join(outputs, ","), "/",
str_util::Join(target_nodes, ","), "/", run_state_args->is_partial_run,
"/", debug_tensor_watches_summary);
// Set the handle, if it's needed to log memory or for partial run.
if (handle_name_counter_value >= 0) {
run_state_args->handle =
strings::StrCat(key, ";", handle_name_counter_value);
}
// See if we already have the executors for this run.
{
mutex_lock l(executor_lock_); // could use reader lock
auto it = executors_.find(key);
if (it != executors_.end()) {
*executors_and_keys = it->second.get();
return Status::OK();
}
}
// Slow lookup path, the unsorted key missed the cache.
// Sort the inputs and outputs, and look up with the sorted key in case an
// earlier call used a different order of inputs and outputs.
//
// We could consider some other signature instead of sorting that
// preserves the same property to avoid the sort in the future.
std::vector<string> inputs_sorted(inputs.begin(), inputs.end());
std::sort(inputs_sorted.begin(), inputs_sorted.end());
std::vector<string> outputs_sorted(outputs.begin(), outputs.end());
std::sort(outputs_sorted.begin(), outputs_sorted.end());
std::vector<string> tn_sorted(target_nodes.begin(), target_nodes.end());
std::sort(tn_sorted.begin(), tn_sorted.end());
const string sorted_key = strings::StrCat(
str_util::Join(inputs_sorted, ","), "->",
str_util::Join(outputs_sorted, ","), "/", str_util::Join(tn_sorted, ","),
"/", run_state_args->is_partial_run, "/", debug_tensor_watches_summary);
// Set the handle, if its needed to log memory or for partial run.
if (handle_name_counter_value >= 0) {
run_state_args->handle =
strings::StrCat(sorted_key, ";", handle_name_counter_value);
}
// See if we already have the executors for this run.
{
mutex_lock l(executor_lock_);
auto it = executors_.find(sorted_key);
if (it != executors_.end()) {
*executors_and_keys = it->second.get();
// Insert this under the original key.
executors_.emplace(key, it->second);
return Status::OK();
}
}
// Nothing found, so create the executors and store in the cache.
BuildGraphOptions options;
options.feed_endpoints = inputs_sorted;
options.fetch_endpoints = outputs_sorted;
options.target_nodes = tn_sorted;
options.feed_endpoints = std::vector<string>(callable_options.feed().begin(),
callable_options.feed().end());
options.fetch_endpoints = std::vector<string>(
callable_options.fetch().begin(), callable_options.fetch().end());
options.target_nodes = std::vector<string>(callable_options.target().begin(),
callable_options.target().end());
options.use_function_convention = !run_state_args->is_partial_run;
if (!run_state_args->debug_options.debug_tensor_watch_opts().empty()) {
options.debug_options = run_state_args->debug_options;
if (!callable_options.run_options()
.debug_options()
.debug_tensor_watch_opts()
.empty()) {
options.debug_options = callable_options.run_options().debug_options();
}
std::unique_ptr<FunctionInfo> func_info(new FunctionInfo);
std::shared_ptr<ExecutorsAndKeys> ek(new ExecutorsAndKeys);
std::unique_ptr<ExecutorsAndKeys> ek(new ExecutorsAndKeys);
ek->callable_options = callable_options;
// The executor_lock_ is intentionally released while executor is
// being created.
std::unordered_map<string, std::unique_ptr<Graph>> graphs;
TF_RETURN_IF_ERROR(CreateGraphs(options, &graphs, &func_info->flib_def,
run_state_args, &ek->input_types,
@ -1155,11 +1110,11 @@ Status DirectSession::GetOrCreateExecutors(
if (run_state_args->is_partial_run) {
ek->graph = std::move(run_state_args->graph);
std::unordered_set<StringPiece, StringPieceHasher> names;
for (const string& input : inputs) {
for (const string& input : callable_options.feed()) {
TensorId id(ParseTensorName(input));
names.emplace(id.first);
}
for (const string& output : outputs) {
for (const string& output : callable_options.fetch()) {
TensorId id(ParseTensorName(output));
names.emplace(id.first);
}
@ -1260,12 +1215,12 @@ Status DirectSession::GetOrCreateExecutors(
// For regular `Run()`, we use the function calling convention, and so
// maintain a mapping from input/output names to
// argument/return-value ordinal index.
for (size_t i = 0; i < inputs_sorted.size(); ++i) {
const string& input = inputs_sorted[i];
for (int i = 0; i < callable_options.feed().size(); ++i) {
const string& input = callable_options.feed(i);
ek->input_name_to_index[input] = i;
}
for (size_t i = 0; i < outputs_sorted.size(); ++i) {
const string& output = outputs_sorted[i];
for (int i = 0; i < callable_options.fetch().size(); ++i) {
const string& output = callable_options.fetch(i);
ek->output_name_to_index[output] = i;
}
} else {
@ -1274,26 +1229,123 @@ Status DirectSession::GetOrCreateExecutors(
//
// We always use the first device as the device name portion of the
// key, even if we're feeding another graph.
for (size_t i = 0; i < inputs_sorted.size(); ++i) {
const string& input = inputs_sorted[i];
for (int i = 0; i < callable_options.feed().size(); ++i) {
const string& input = callable_options.feed(i);
ek->input_name_to_rendezvous_key[input] = GetRendezvousKey(
input, device_set_.client_device()->attributes(), FrameAndIter(0, 0));
}
for (size_t i = 0; i < outputs_sorted.size(); ++i) {
const string& output = outputs_sorted[i];
for (int i = 0; i < callable_options.fetch().size(); ++i) {
const string& output = callable_options.fetch(i);
ek->output_name_to_rendezvous_key[output] =
GetRendezvousKey(output, device_set_.client_device()->attributes(),
FrameAndIter(0, 0));
}
}
*out_executors_and_keys = std::move(ek);
*out_func_info = std::move(func_info);
return Status::OK();
}
Status DirectSession::GetOrCreateExecutors(
gtl::ArraySlice<string> inputs, gtl::ArraySlice<string> outputs,
gtl::ArraySlice<string> target_nodes, ExecutorsAndKeys** executors_and_keys,
RunStateArgs* run_state_args) {
int64 handle_name_counter_value = -1;
if (LogMemory::IsEnabled() || run_state_args->is_partial_run) {
handle_name_counter_value = handle_name_counter_.fetch_add(1);
}
string debug_tensor_watches_summary;
if (!run_state_args->debug_options.debug_tensor_watch_opts().empty()) {
debug_tensor_watches_summary = SummarizeDebugTensorWatches(
run_state_args->debug_options.debug_tensor_watch_opts());
}
// Fast lookup path, no sorting.
const string key = strings::StrCat(
str_util::Join(inputs, ","), "->", str_util::Join(outputs, ","), "/",
str_util::Join(target_nodes, ","), "/", run_state_args->is_partial_run,
"/", debug_tensor_watches_summary);
// Set the handle, if it's needed to log memory or for partial run.
if (handle_name_counter_value >= 0) {
run_state_args->handle =
strings::StrCat(key, ";", handle_name_counter_value);
}
// See if we already have the executors for this run.
{
mutex_lock l(executor_lock_); // could use reader lock
auto it = executors_.find(key);
if (it != executors_.end()) {
*executors_and_keys = it->second.get();
return Status::OK();
}
}
// Slow lookup path, the unsorted key missed the cache.
// Sort the inputs and outputs, and look up with the sorted key in case an
// earlier call used a different order of inputs and outputs.
//
// We could consider some other signature instead of sorting that
// preserves the same property to avoid the sort in the future.
std::vector<string> inputs_sorted(inputs.begin(), inputs.end());
std::sort(inputs_sorted.begin(), inputs_sorted.end());
std::vector<string> outputs_sorted(outputs.begin(), outputs.end());
std::sort(outputs_sorted.begin(), outputs_sorted.end());
std::vector<string> tn_sorted(target_nodes.begin(), target_nodes.end());
std::sort(tn_sorted.begin(), tn_sorted.end());
const string sorted_key = strings::StrCat(
str_util::Join(inputs_sorted, ","), "->",
str_util::Join(outputs_sorted, ","), "/", str_util::Join(tn_sorted, ","),
"/", run_state_args->is_partial_run, "/", debug_tensor_watches_summary);
// Set the handle, if its needed to log memory or for partial run.
if (handle_name_counter_value >= 0) {
run_state_args->handle =
strings::StrCat(sorted_key, ";", handle_name_counter_value);
}
// See if we already have the executors for this run.
{
mutex_lock l(executor_lock_);
auto it = executors_.find(sorted_key);
if (it != executors_.end()) {
*executors_and_keys = it->second.get();
// Insert this under the original key.
executors_.emplace(key, it->second);
return Status::OK();
}
}
// Nothing found, so create the executors and store in the cache.
// The executor_lock_ is intentionally released while executors are
// being created.
CallableOptions callable_options;
for (const string& input : inputs_sorted) {
callable_options.add_feed(input);
}
for (const string& output : outputs_sorted) {
callable_options.add_fetch(output);
}
for (const string& target : tn_sorted) {
callable_options.add_target(target);
}
*callable_options.mutable_run_options()->mutable_debug_options() =
run_state_args->debug_options;
std::unique_ptr<ExecutorsAndKeys> ek;
std::unique_ptr<FunctionInfo> func_info;
TF_RETURN_IF_ERROR(
CreateExecutors(callable_options, &ek, &func_info, run_state_args));
// Reacquire the lock, try to insert into the map.
mutex_lock l(executor_lock_);
functions_.push_back(std::move(func_info));
// Another thread may have created the entry before us, in which case we will
// reuse the already created one.
auto insert_result = executors_.emplace(sorted_key, ek);
auto insert_result = executors_.emplace(
sorted_key, std::shared_ptr<ExecutorsAndKeys>(std::move(ek)));
// Insert the value under the original key, so the fast path lookup will work
// if the user uses the same order of inputs, outputs, and targets again.
executors_.emplace(key, insert_result.first->second);
@ -1562,4 +1614,156 @@ void DirectSession::WaitForNotification(RunState* run_state,
return Status::OK();
}
Status DirectSession::MakeCallable(const CallableOptions& callable_options,
CallableHandle* out_handle) {
TF_RETURN_IF_ERROR(CheckNotClosed());
TF_RETURN_IF_ERROR(CheckGraphCreated("MakeCallable()"));
if (!callable_options.run_options()
.debug_options()
.debug_tensor_watch_opts()
.empty()) {
return errors::Unimplemented(
"Debug options are not currently supported via the C++ MakeCallable "
"interface.");
}
std::unique_ptr<ExecutorsAndKeys> ek;
std::unique_ptr<FunctionInfo> func_info;
RunStateArgs run_state_args(callable_options.run_options().debug_options());
TF_RETURN_IF_ERROR(
CreateExecutors(callable_options, &ek, &func_info, &run_state_args));
{
mutex_lock l(callables_lock_);
*out_handle = next_callable_handle_++;
callables_[*out_handle] = {std::move(ek), std::move(func_info)};
}
return Status::OK();
}
class DirectSession::RunCallableCallFrame : public CallFrameInterface {
public:
RunCallableCallFrame(DirectSession* session,
ExecutorsAndKeys* executors_and_keys,
const std::vector<Tensor>* feed_tensors,
std::vector<Tensor>* fetch_tensors)
: session_(session),
executors_and_keys_(executors_and_keys),
feed_tensors_(feed_tensors),
fetch_tensors_(fetch_tensors) {}
size_t num_args() const override {
return executors_and_keys_->input_types.size();
}
size_t num_retvals() const override {
return executors_and_keys_->output_types.size();
}
Status GetArg(int index, Tensor* val) const override {
if (index > feed_tensors_->size()) {
return errors::Internal("Args index out of bounds: ", index);
} else if (executors_and_keys_->input_types[index] == DT_RESOURCE) {
TF_RETURN_IF_ERROR(
session_->ResourceHandleToInputTensor((*feed_tensors_)[index], val));
} else {
*val = (*feed_tensors_)[index];
}
return Status::OK();
}
Status SetRetval(int index, const Tensor& val) override {
if (index > fetch_tensors_->size()) {
return errors::Internal("RetVal index out of bounds: ", index);
}
(*fetch_tensors_)[index] = val;
return Status::OK();
}
private:
DirectSession* const session_; // Not owned.
ExecutorsAndKeys* const executors_and_keys_; // Not owned.
const std::vector<Tensor>* const feed_tensors_; // Not owned.
std::vector<Tensor>* const fetch_tensors_; // Not owned.
};
::tensorflow::Status DirectSession::RunCallable(
CallableHandle handle, const std::vector<Tensor>& feed_tensors,
std::vector<Tensor>* fetch_tensors, RunMetadata* run_metadata) {
TF_RETURN_IF_ERROR(CheckNotClosed());
TF_RETURN_IF_ERROR(CheckGraphCreated("RunCallable()"));
direct_session_runs->GetCell()->IncrementBy(1);
// Check if we already have an executor for these arguments.
std::shared_ptr<ExecutorsAndKeys> executors_and_keys;
const int64 step_id = step_id_counter_.fetch_add(1);
{
tf_shared_lock l(callables_lock_);
if (handle >= next_callable_handle_) {
return errors::InvalidArgument("No such callable handle: ", handle);
}
executors_and_keys = callables_[handle].executors_and_keys;
}
if (!executors_and_keys) {
return errors::InvalidArgument(
"Attempted to run callable after handle was released: ", handle);
}
// NOTE(mrry): Debug options are not currently supported in the
// callable interface.
DebugOptions debug_options;
RunStateArgs run_state_args(debug_options);
// Configure a call frame for the step, which we use to feed and
// fetch values to and from the executors.
if (feed_tensors.size() != executors_and_keys->input_types.size()) {
return errors::InvalidArgument(
"Expected ", executors_and_keys->input_types.size(),
" feed tensors, but got ", feed_tensors.size());
}
if (fetch_tensors != nullptr) {
fetch_tensors->resize(executors_and_keys->output_types.size());
} else if (!executors_and_keys->output_types.empty()) {
return errors::InvalidArgument(
"`fetch_tensors` must be provided when the callable has one or more "
"outputs.");
}
// A specialized CallFrame implementation that takes advantage of the
// optimized RunCallable interface.
RunCallableCallFrame call_frame(this, executors_and_keys.get(), &feed_tensors,
fetch_tensors);
if (LogMemory::IsEnabled()) {
LogMemory::RecordStep(step_id, run_state_args.handle);
}
TF_RETURN_IF_ERROR(
RunInternal(step_id, executors_and_keys->callable_options.run_options(),
&call_frame, executors_and_keys.get(), run_metadata));
return Status::OK();
}
::tensorflow::Status DirectSession::ReleaseCallable(CallableHandle handle) {
mutex_lock l(callables_lock_);
if (handle >= next_callable_handle_) {
return errors::InvalidArgument("No such callable handle: ", handle);
}
callables_.erase(handle);
return Status::OK();
}
DirectSession::Callable::~Callable() {
// We must delete the fields in this order, because the destructor
// of `executors_and_keys` will call into an object owned by
// `function_info` (in particular, when deleting a kernel, it relies
// on the `FunctionLibraryRuntime` to know if the kernel is stateful
// or not).
executors_and_keys.reset();
function_info.reset();
}
} // namespace tensorflow

View File

@ -107,6 +107,14 @@ class DirectSession : public Session {
cost_model_manager_.ExportCostModels(cost_models);
}
::tensorflow::Status MakeCallable(const CallableOptions& callable_options,
CallableHandle* out_handle) override;
::tensorflow::Status RunCallable(CallableHandle handle,
const std::vector<Tensor>& feed_tensors,
std::vector<Tensor>* fetch_tensors,
RunMetadata* run_metadata) override;
::tensorflow::Status ReleaseCallable(CallableHandle handle) override;
private:
// We create one executor and its dependent library runtime for
// every partition.
@ -139,6 +147,8 @@ class DirectSession : public Session {
DataTypeVector input_types;
DataTypeVector output_types;
CallableOptions callable_options;
};
// A FunctionInfo object is created for every unique set of feeds/fetches.
@ -206,6 +216,14 @@ class DirectSession : public Session {
gtl::ArraySlice<string> target_nodes,
ExecutorsAndKeys** executors_and_keys, RunStateArgs* run_state_args);
// Creates a set of executors to run the subgraph defined by
// `callable_options`.
::tensorflow::Status CreateExecutors(
const CallableOptions& callable_options,
std::unique_ptr<ExecutorsAndKeys>* out_executors_and_keys,
std::unique_ptr<FunctionInfo>* out_func_info,
RunStateArgs* run_state_args);
// Creates several graphs given the existing graph_def_ and the
// input feeds and fetches, given 'devices'. The graphs share a common
// function library 'flib_def'.
@ -216,6 +234,11 @@ class DirectSession : public Session {
RunStateArgs* run_state_args, DataTypeVector* input_types,
DataTypeVector* output_types);
::tensorflow::Status RunInternal(int64 step_id, const RunOptions& run_options,
CallFrameInterface* call_frame,
ExecutorsAndKeys* executors_and_keys,
RunMetadata* run_metadata);
::tensorflow::Status ExtendLocked(const GraphDef& graph)
EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_);
@ -257,11 +280,18 @@ class DirectSession : public Session {
return ::tensorflow::Status::OK();
}
::tensorflow::Status CheckGraphCreated(const char* method) {
mutex_lock l(graph_def_lock_);
if (!graph_created_) {
return errors::InvalidArgument(
"Session was not created with a graph before ", method, "!");
}
return ::tensorflow::Status::OK();
}
::tensorflow::Status CreateDebuggerState(
const DebugOptions& debug_options, int64 session_run_index,
int64 executor_step_index, const std::vector<string>& input_names,
const std::vector<string>& output_names,
const std::vector<string>& target_names,
const CallableOptions& options, int64 global_step,
int64 session_run_index, int64 executor_step_index,
std::unique_ptr<DebuggerStateInterface>* debugger_state);
::tensorflow::Status DecorateAndPublishGraphForDebug(
@ -303,6 +333,16 @@ class DirectSession : public Session {
std::unordered_map<string, std::shared_ptr<ExecutorsAndKeys>> executors_
GUARDED_BY(executor_lock_);
class RunCallableCallFrame;
struct Callable {
std::shared_ptr<ExecutorsAndKeys> executors_and_keys;
std::shared_ptr<FunctionInfo> function_info;
~Callable();
};
mutex callables_lock_;
int64 next_callable_handle_ GUARDED_BY(callables_lock_) = 0;
std::unordered_map<int64, Callable> callables_ GUARDED_BY(callables_lock_);
// Holds mappings from handle to partial run state.
std::unordered_map<string, std::unique_ptr<RunState>> partial_runs_
GUARDED_BY(executor_lock_);

View File

@ -49,6 +49,22 @@ limitations under the License.
namespace tensorflow {
namespace {
CallableOptions MakeCallableOptions(gtl::ArraySlice<string> feeds,
gtl::ArraySlice<string> fetches,
gtl::ArraySlice<string> targets) {
CallableOptions ret;
for (const string& feed : feeds) {
ret.add_feed(feed);
}
for (const string& fetch : fetches) {
ret.add_fetch(fetch);
}
for (const string& target : targets) {
ret.add_target(target);
}
return ret;
}
std::unique_ptr<Session> CreateSession() {
SessionOptions options;
(*options.config.mutable_device_count())["CPU"] = 2;
@ -111,6 +127,53 @@ TEST_F(DirectSessionMinusAXTest, RunSimpleNetwork) {
EXPECT_FLOAT_EQ(5.0, mat(0, 0));
}
TEST_F(DirectSessionMinusAXTest, RunSimpleNetwork_Callable) {
Initialize({3, 2, -1, 0});
auto session = CreateSession();
ASSERT_TRUE(session != nullptr);
TF_ASSERT_OK(session->Create(def_));
std::vector<std::pair<string, Tensor>> inputs;
// Run the test twice to ensure that the Make/Run/Release cycle is hermetic.
for (int i = 0; i < 2; ++i) {
// Request two targets: one fetch output and one non-fetched output.
Session::CallableHandle handle;
TF_ASSERT_OK(session->MakeCallable(
MakeCallableOptions({}, {y_ + ":0"}, {y_neg_}), &handle));
for (int i = 0; i < 2; ++i) {
std::vector<Tensor> outputs;
TF_ASSERT_OK(session->RunCallable(handle, {}, &outputs, nullptr));
ASSERT_EQ(1, outputs.size());
// The first output should be initialized and have the correct
// output.
auto mat = outputs[0].matrix<float>();
ASSERT_TRUE(outputs[0].IsInitialized());
EXPECT_FLOAT_EQ(5.0, mat(0, 0));
}
Status s = session->RunCallable(handle, {}, nullptr, nullptr);
EXPECT_TRUE(errors::IsInvalidArgument(s));
EXPECT_TRUE(StringPiece(s.error_message())
.contains("`fetch_tensors` must be provided"));
TF_ASSERT_OK(session->ReleaseCallable(handle));
std::vector<Tensor> outputs;
s = session->RunCallable(handle, {}, &outputs, nullptr);
EXPECT_TRUE(errors::IsInvalidArgument(s));
EXPECT_TRUE(
StringPiece(s.error_message())
.contains("Attempted to run callable after handle was released"));
s = session->RunCallable(handle + 1, {}, &outputs, nullptr);
EXPECT_TRUE(errors::IsInvalidArgument(s));
EXPECT_TRUE(
StringPiece(s.error_message()).contains("No such callable handle"));
}
}
TEST_F(DirectSessionMinusAXTest, TestFeed) {
Initialize({1, 2, 3, 4});
auto session = CreateSession();
@ -140,6 +203,39 @@ TEST_F(DirectSessionMinusAXTest, TestFeed) {
EXPECT_FLOAT_EQ(39.0, mat(1, 0));
}
TEST_F(DirectSessionMinusAXTest, TestFeed_Callable) {
Initialize({1, 2, 3, 4});
auto session = CreateSession();
ASSERT_TRUE(session != nullptr);
TF_ASSERT_OK(session->Create(def_));
// Fill in the input and ask for the output
//
// Note that the input being fed is on the second device.
CallableOptions callable_options;
callable_options.add_feed(x_);
callable_options.add_fetch(y_ + ":0");
Session::CallableHandle handle;
TF_ASSERT_OK(session->MakeCallable(MakeCallableOptions({x_}, {y_ + ":0"}, {}),
&handle));
Tensor t(DT_FLOAT, TensorShape({2, 1}));
t.matrix<float>()(0, 0) = 5;
t.matrix<float>()(1, 0) = 6;
std::vector<Tensor> inputs = {t};
std::vector<Tensor> outputs;
// Run the callable
TF_ASSERT_OK(session->RunCallable(handle, inputs, &outputs, nullptr));
ASSERT_EQ(1, outputs.size());
auto mat = outputs[0].matrix<float>();
// Expect outputs to be; 1*5 + 2*6, 3*5 + 4*6
EXPECT_FLOAT_EQ(17.0, mat(0, 0));
EXPECT_FLOAT_EQ(39.0, mat(1, 0));
}
TEST_F(DirectSessionMinusAXTest, TestConcurrency) {
Initialize({1, 2, 3, 4});
auto session = CreateSession();
@ -172,6 +268,39 @@ TEST_F(DirectSessionMinusAXTest, TestConcurrency) {
delete tp;
}
TEST_F(DirectSessionMinusAXTest, TestConcurrency_Callable) {
Initialize({1, 2, 3, 4});
auto session = CreateSession();
ASSERT_TRUE(session != nullptr);
TF_ASSERT_OK(session->Create(def_));
// Fill in the input and ask for the output
thread::ThreadPool* tp = new thread::ThreadPool(Env::Default(), "test", 4);
Session::CallableHandle handle;
TF_ASSERT_OK(
session->MakeCallable(MakeCallableOptions({}, {y_ + ":0"}, {}), &handle));
// Run the callable 1000 times in 4 different threads concurrently.
auto fn = [&session, handle]() {
for (int i = 0; i < 1000; ++i) {
std::vector<Tensor> outputs;
// Run the graph
TF_ASSERT_OK(session->RunCallable(handle, {}, &outputs, nullptr));
ASSERT_EQ(1, outputs.size());
auto mat = outputs[0].matrix<float>();
EXPECT_FLOAT_EQ(3.0, mat(0, 0));
}
};
for (int i = 0; i < 4; ++i) {
tp->Schedule(fn);
}
// Wait for the functions to finish.
delete tp;
}
TEST_F(DirectSessionMinusAXTest, TestPerSessionThreads) {
Initialize({1, 2, 3, 4});
@ -297,6 +426,38 @@ TEST_F(DirectSessionMinusAXTest, RunSimpleNetworkWithOpts) {
EXPECT_EQ(run_metadata.step_stats().dev_stats_size(), 2);
}
TEST_F(DirectSessionMinusAXTest, RunSimpleNetworkWithOpts_Callable) {
Initialize({3, 2, -1, 0});
auto session = CreateSession();
ASSERT_TRUE(session != nullptr);
TF_ASSERT_OK(session->Create(def_));
// Request two targets: one fetch output and one non-fetched output.
Session::CallableHandle handle;
CallableOptions callable_options =
MakeCallableOptions({}, {y_ + ":0"}, {y_neg_});
callable_options.mutable_run_options()->set_trace_level(
RunOptions::FULL_TRACE);
TF_ASSERT_OK(session->MakeCallable(callable_options, &handle));
RunMetadata run_metadata;
EXPECT_EQ(run_metadata.step_stats().dev_stats_size(), 0);
std::vector<Tensor> outputs;
TF_ASSERT_OK(session->RunCallable(handle, {}, &outputs, &run_metadata));
ASSERT_EQ(1, outputs.size());
// The first output should be initialized and have the correct
// output.
auto mat = outputs[0].matrix<float>();
ASSERT_TRUE(outputs[0].IsInitialized());
EXPECT_FLOAT_EQ(5.0, mat(0, 0));
// Checks RunMetadata is well-formed
ASSERT_TRUE(run_metadata.has_step_stats());
EXPECT_EQ(run_metadata.step_stats().dev_stats_size(), 2);
}
TEST(DirectSessionTest, KeepsStateAcrossRunsOfSession) {
GraphDef def;
Graph g(OpRegistry::Global());
@ -409,6 +570,89 @@ TEST(DirectSessionTest, MultipleFeedTest) {
EXPECT_TRUE(StringPiece(s.error_message()).contains("fed more than once"));
}
TEST(DirectSessionTest, MultipleFeedTest_Callable) {
GraphDef def;
Graph g(OpRegistry::Global());
Tensor first_value(DT_FLOAT, TensorShape({}));
first_value.scalar<float>()() = 1.0;
Node* first_const = test::graph::Constant(&g, first_value);
Node* first_identity = test::graph::Identity(&g, first_const);
Tensor second_value(DT_FLOAT, TensorShape({}));
second_value.scalar<float>()() = 2.0;
Node* second_const = test::graph::Constant(&g, second_value);
Node* second_identity = test::graph::Identity(&g, second_const);
test::graph::ToGraphDef(&g, &def);
auto session = CreateSession();
ASSERT_TRUE(session != nullptr);
TF_ASSERT_OK(session->Create(def));
Session::CallableHandle handle;
std::vector<Tensor> outputs;
// Fetch without feeding.
TF_ASSERT_OK(session->MakeCallable(
MakeCallableOptions(
{}, {first_identity->name() + ":0", second_identity->name() + ":0"},
{}),
&handle));
TF_ASSERT_OK(session->RunCallable(handle, {}, &outputs, nullptr));
ASSERT_EQ(2, outputs.size());
ASSERT_EQ(1.0, outputs[0].flat<float>()(0));
ASSERT_EQ(2.0, outputs[1].flat<float>()(0));
TF_ASSERT_OK(session->MakeCallable(
MakeCallableOptions(
{}, {second_identity->name() + ":0", first_identity->name() + ":0"},
{}),
&handle));
TF_ASSERT_OK(session->RunCallable(handle, {}, &outputs, nullptr));
ASSERT_EQ(2, outputs.size());
ASSERT_EQ(2.0, outputs[0].flat<float>()(0));
ASSERT_EQ(1.0, outputs[1].flat<float>()(0));
Tensor value_11(DT_FLOAT, TensorShape({}));
value_11.scalar<float>()() = 11.0;
Tensor value_22(DT_FLOAT, TensorShape({}));
value_22.scalar<float>()() = 22.0;
// Feed [first_const, second_const]
TF_ASSERT_OK(session->MakeCallable(
MakeCallableOptions(
{first_const->name(), second_const->name()},
{first_identity->name() + ":0", second_identity->name() + ":0"}, {}),
&handle));
TF_ASSERT_OK(
session->RunCallable(handle, {value_11, value_22}, &outputs, nullptr));
ASSERT_EQ(2, outputs.size());
ASSERT_EQ(11.0, outputs[0].flat<float>()(0));
ASSERT_EQ(22.0, outputs[1].flat<float>()(0));
// Feed [second_const, first_const]
TF_ASSERT_OK(session->MakeCallable(
MakeCallableOptions(
{second_const->name(), first_const->name()},
{first_identity->name() + ":0", second_identity->name() + ":0"}, {}),
&handle));
TF_ASSERT_OK(
session->RunCallable(handle, {value_22, value_11}, &outputs, nullptr));
ASSERT_EQ(2, outputs.size());
ASSERT_EQ(11.0, outputs[0].flat<float>()(0));
ASSERT_EQ(22.0, outputs[1].flat<float>()(0));
// Feed [first_const, first_const]
Status s = session->MakeCallable(
MakeCallableOptions(
{first_const->name(), first_const->name()},
{first_identity->name() + ":0", second_identity->name() + ":0"}, {}),
&handle);
EXPECT_TRUE(errors::IsInvalidArgument(s));
EXPECT_TRUE(StringPiece(s.error_message()).contains("fed more than once"));
}
TEST(DirectSessionTest, FetchMultipleTimes) {
Graph g(OpRegistry::Global());
Tensor seven_tensor(DT_INT32, TensorShape());
@ -695,6 +939,59 @@ TEST(DirectSessionTest, RunHandleTest) {
ASSERT_TRUE(s.ok());
}
TEST(DirectSessionTest, RunHandleTest_Callable) {
GraphDef def;
Graph g(OpRegistry::Global());
Tensor value0(DT_FLOAT, TensorShape({}));
value0.scalar<float>()() = 1.0;
Node* const0 = test::graph::Constant(&g, value0);
Node* identity0 = test::graph::Identity(&g, const0);
Tensor value1(DT_FLOAT, TensorShape({}));
value1.scalar<float>()() = 2.0;
Node* const1 = test::graph::Constant(&g, value1);
Node* node3 = test::graph::Add(&g, identity0, const1);
Node* node4 = test::graph::Unary(&g, "GetSessionHandleV2", node3);
Tensor value2(DT_STRING, TensorShape({}));
Node* const2 = test::graph::Constant(&g, value2);
Node* node5 = test::graph::GetSessionTensor(&g, const2);
Node* node6 = test::graph::Add(&g, node5, const1);
Node* node7 = test::graph::Unary(&g, "DeleteSessionTensor", const2);
test::graph::ToGraphDef(&g, &def);
auto session = CreateSession();
ASSERT_TRUE(session != nullptr);
TF_ASSERT_OK(session->Create(def));
// First run call: Create a handle.
std::vector<Tensor> outputs;
Status s = session->Run({}, {node4->name() + ":0"}, {}, &outputs);
ASSERT_TRUE(s.ok());
ASSERT_EQ(1, outputs.size());
ResourceHandle resource_handle = outputs[0].scalar<ResourceHandle>()();
Tensor string_handle(DT_STRING, {});
string_handle.flat<string>().setConstant(resource_handle.name());
// Second run call: Use a handle.
std::vector<Tensor> outputs1;
s = session->Run({{const2->name(), string_handle}}, {node6->name() + ":0"},
{}, &outputs1);
ASSERT_TRUE(s.ok());
ASSERT_EQ(1, outputs1.size());
ASSERT_EQ(5.0, outputs1[0].flat<float>()(0));
// Third run call: Delete a handle.
std::vector<Tensor> outputs2;
s = session->Run({{const2->name(), string_handle}}, {}, {node7->name()},
&outputs2);
ASSERT_TRUE(s.ok());
}
TEST(DirectSessionTest, CreateGraphFailsWhenAssigningAFedVar) {
Graph graph(OpRegistry::Global());
@ -1109,6 +1406,11 @@ TEST(DirectSessionTest, TestDirectSessionRunClose) {
EXPECT_EQ(t.scalar<float>()(), outputs[0].scalar<float>()());
outputs.clear();
// Make a callable handle before closing the session.
Session::CallableHandle handle;
TF_ASSERT_OK(session->MakeCallable(
MakeCallableOptions({}, {}, {var_assign->name()}), &handle));
// Close the session.
TF_ASSERT_OK(session->Close());
@ -1116,6 +1418,10 @@ TEST(DirectSessionTest, TestDirectSessionRunClose) {
Status s = session->Run({} /* inputs */, {},
{var_assign->name()} /* target_nodes */, nullptr);
EXPECT_EQ("Cancelled: Session has been closed.", s.ToString());
// Run the read as a callable to verify that we get the same error.
s = session->RunCallable(handle, {}, {}, nullptr);
EXPECT_EQ("Cancelled: Session has been closed.", s.ToString());
}
TEST(DirectSessionTest, TestDirectSessionPRunClose) {
@ -1217,7 +1523,8 @@ TEST(DirectSessionTest, LocalDeviceManager) {
// A simple benchmark for the overhead of `DirectSession::Run()` calls
// with varying numbers of feeds/fetches.
void FeedFetchBenchmarkHelper(int iters, int num_feeds) {
void FeedFetchBenchmarkHelper(int iters, int num_feeds,
bool use_make_callable) {
testing::StopTiming();
Tensor value(DT_FLOAT, TensorShape());
@ -1253,29 +1560,55 @@ void FeedFetchBenchmarkHelper(int iters, int num_feeds) {
SessionOptions opts;
std::unique_ptr<Session> session(NewSession(opts));
TF_CHECK_OK(session->Create(gd));
{
// NOTE(mrry): Ignore the first run, which will incur the graph
// partitioning/pruning overhead and skew the results.
//
// Note that we should also optimize and monitor the overhead on
// the first run, which will impact application startup times, but
// that is not the object of study in this benchmark.
std::vector<Tensor> output_values;
TF_CHECK_OK(session->Run(inputs, outputs, {}, &output_values));
if (use_make_callable) {
Session::CallableHandle handle;
CallableOptions callable_options;
std::vector<Tensor> input_tensors;
for (const auto& input : inputs) {
callable_options.add_feed(input.first);
input_tensors.push_back(input.second);
}
for (const string& output : outputs) {
callable_options.add_fetch(output);
}
TF_CHECK_OK(session->MakeCallable(callable_options, &handle));
testing::StartTiming();
for (int i = 0; i < iters; ++i) {
std::vector<Tensor> output_values;
TF_CHECK_OK(
session->RunCallable(handle, input_tensors, &output_values, nullptr));
}
testing::StopTiming();
} else {
{
// NOTE(mrry): Ignore the first run, which will incur the graph
// partitioning/pruning overhead and skew the results.
//
// Note that we should also optimize and monitor the overhead on
// the first run, which will impact application startup times, but
// that is not the object of study in this benchmark.
std::vector<Tensor> output_values;
TF_CHECK_OK(session->Run(inputs, outputs, {}, &output_values));
}
testing::StartTiming();
for (int i = 0; i < iters; ++i) {
std::vector<Tensor> output_values;
TF_CHECK_OK(session->Run(inputs, outputs, {}, &output_values));
}
testing::StopTiming();
}
testing::StartTiming();
for (int i = 0; i < iters; ++i) {
std::vector<Tensor> output_values;
TF_CHECK_OK(session->Run(inputs, outputs, {}, &output_values));
}
testing::StopTiming();
}
void BM_FeedFetch(int iters, int num_feeds) {
FeedFetchBenchmarkHelper(iters, num_feeds);
FeedFetchBenchmarkHelper(iters, num_feeds, /* use_make_callable */ false);
}
void BM_FeedFetchCallable(int iters, int num_feeds) {
FeedFetchBenchmarkHelper(iters, num_feeds, /* use_make_callable */ true);
}
BENCHMARK(BM_FeedFetch)->Arg(1)->Arg(2)->Arg(5)->Arg(10);
BENCHMARK(BM_FeedFetchCallable)->Arg(1)->Arg(2)->Arg(5)->Arg(10);
} // namespace
} // namespace tensorflow

View File

@ -74,6 +74,12 @@ class TensorStore {
Status SaveTensors(const std::vector<string>& output_names,
SessionState* session_state);
// Returns true if no tensors have been added to this store.
bool empty() {
mutex_lock l(lock_);
return tensors_.empty();
}
private:
mutex lock_;

View File

@ -410,3 +410,26 @@ message RunMetadata {
// Graphs of the partitions executed by executors.
repeated GraphDef partition_graphs = 3;
}
// Defines a subgraph in another `GraphDef` as a set of feed points and nodes
// to be fetched or executed.
//
// Compare with the arguments to `Session::Run()`.
message CallableOptions {
// Tensors to be fed in the callable. Each feed is the name of a tensor.
repeated string feed = 1;
// Fetches. A list of tensor names. The caller of the callable expects a
// tensor to be returned for each fetch[i] (see RunStepResponse.tensor). The
// order of specified fetches does not change the execution order.
repeated string fetch = 2;
// Target Nodes. A list of node names. The named nodes will be run by the
// callable but their outputs will not be returned.
repeated string target = 3;
// Options that will be applied to each run.
RunOptions run_options = 4;
// Next: 5
}

View File

@ -195,6 +195,41 @@ class Session {
return errors::Unimplemented(
"LocalDeviceManager is not supported for this session.");
}
/// \brief A handle to a subgraph, created with `Session::MakeCallable()`.
typedef int64 CallableHandle;
/// \brief Creates a `handle` for invoking the subgraph defined by
/// `callable_options`.
/// NOTE: This API is still experimental and may change.
virtual Status MakeCallable(const CallableOptions& callable_options,
CallableHandle* out_handle) {
return errors::Unimplemented(
"MakeCallable is not supported for this session.");
}
/// \brief Invokes the subgraph named by `handle` with the given options and
/// input tensors.
///
/// The order of tensors in `feed_tensors` must and `fetch_tensors` will
/// match the order of names in `CallableOptions::feed()` and
/// `CallableOptions::fetch()` when this subgraph was created.
/// NOTE: This API is still experimental and may change.
virtual Status RunCallable(CallableHandle handle,
const std::vector<Tensor>& feed_tensors,
std::vector<Tensor>* fetch_tensors,
RunMetadata* run_metadata) {
return errors::Unimplemented(
"RunCallable is not supported for this session.");
}
/// \brief Releases resources associated with the given `handle` in this
/// session.
/// NOTE: This API is still experimental and may change.
virtual Status ReleaseCallable(CallableHandle handle) {
return errors::Unimplemented(
"ReleaseCallable is not supported for this session.");
}
};
/// \brief Create a new session with the given options.