Adds RunWithOpts(), which takes a per-step RunOptions and a RunOutputs, to the
C++ Session API. Use it to optionally turn tracing on for a step and to return the profiling info collected via StepStats. It is implemented for DirectSession only. Example usage: RunOptions run_options; run_options.set_trace_level(RunOptions::FULL_TRACE); RunOutputs run_outputs; ASSERT_TRUE(!run_outputs.has_step_stats()); Status s = session->RunWithOpts(run_options, inputs, output_names, target_nodes, &outputs, &run_outputs); ASSERT_TRUE(run_outputs.has_step_stats()); Change: 115693287
This commit is contained in:
parent
c36725ad79
commit
96118731cb
@ -27,6 +27,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/memory_types.h"
|
||||
#include "tensorflow/core/common_runtime/session_factory.h"
|
||||
#include "tensorflow/core/common_runtime/simple_placer.h"
|
||||
#include "tensorflow/core/common_runtime/step_stats_collector.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/graph_def_util.h"
|
||||
@ -251,6 +252,16 @@ Status DirectSession::Run(const NamedTensorList& inputs,
|
||||
const std::vector<string>& output_names,
|
||||
const std::vector<string>& target_nodes,
|
||||
std::vector<Tensor>* outputs) {
|
||||
return RunWithOpts(kEmptyRunOptions, inputs, output_names, target_nodes,
|
||||
outputs, &kEmptyRunOutputs);
|
||||
}
|
||||
|
||||
Status DirectSession::RunWithOpts(const RunOptions& run_options,
|
||||
const NamedTensorList& inputs,
|
||||
const std::vector<string>& output_names,
|
||||
const std::vector<string>& target_nodes,
|
||||
std::vector<Tensor>* outputs,
|
||||
RunOutputs* run_outputs) {
|
||||
{
|
||||
mutex_lock l(graph_def_lock_);
|
||||
if (!graph_created_) {
|
||||
@ -296,11 +307,21 @@ Status DirectSession::Run(const NamedTensorList& inputs,
|
||||
VLOG(1) << "Step " << args.step_id << " is for handle "
|
||||
<< run_state_args.handle;
|
||||
|
||||
if (run_options.trace_level() == RunOptions::FULL_TRACE) {
|
||||
args.stats_collector =
|
||||
new StepStatsCollector(run_outputs->mutable_step_stats());
|
||||
}
|
||||
|
||||
for (const auto& item : executors_and_keys->items) {
|
||||
item.executor->RunAsync(args, barrier->Get());
|
||||
}
|
||||
|
||||
run_state.executors_done.WaitForNotification();
|
||||
|
||||
if (run_options.trace_level() == RunOptions::FULL_TRACE) {
|
||||
delete args.stats_collector;
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(run_state.status);
|
||||
|
||||
// Receive outputs.
|
||||
|
@ -60,6 +60,14 @@ class DirectSession : public Session {
|
||||
const std::vector<string>& target_nodes,
|
||||
std::vector<Tensor>* outputs) override;
|
||||
|
||||
// Variant of Run() that additionally takes per-step `run_options` and a
// `run_outputs` proto for non-Tensor metadata about the step.
// When `run_options.trace_level()` is RunOptions::FULL_TRACE, the
// implementation collects step statistics into
// `run_outputs->mutable_step_stats()`.
// NOTE(review): in that tracing path `run_outputs` is dereferenced; no null
// check is visible in this change, so callers should pass a non-null pointer.
// NOTE: Experimental and subject to change.
|
||||
::tensorflow::Status RunWithOpts(const RunOptions& run_options,
|
||||
const NamedTensorList& inputs,
|
||||
const std::vector<string>& output_names,
|
||||
const std::vector<string>& target_nodes,
|
||||
std::vector<Tensor>* outputs,
|
||||
RunOutputs* run_outputs) override;
|
||||
|
||||
// NOTE: PRunSetup and PRun are added to support partial execution. This
|
||||
// feature is experimental and subject to change.
|
||||
::tensorflow::Status PRunSetup(const std::vector<string>& input_names,
|
||||
@ -147,6 +155,9 @@ class DirectSession : public Session {
|
||||
Graph* graph = nullptr;
|
||||
};
|
||||
|
||||
// Default-constructed options and outputs that Run() forwards to
// RunWithOpts() when the caller supplied neither.
const RunOptions kEmptyRunOptions = RunOptions();
|
||||
// NOTE(review): despite the k-prefix this object is not const; Run() passes
// its address to RunWithOpts() as the output proto. It appears to be shared by
// every Run() call on the session - confirm nothing writes to it concurrently.
RunOutputs kEmptyRunOutputs = RunOutputs();
|
||||
|
||||
// Retrieves an already existing set of executors to run 'inputs' and
|
||||
// 'outputs', or creates and caches them for future use.
|
||||
::tensorflow::Status GetOrCreateExecutors(
|
||||
|
@ -255,6 +255,40 @@ TEST_F(DirectSessionMinusAXTest, InvalidDevice) {
|
||||
TF_ASSERT_OK(session->Run(inputs, output_names, {}, &outputs));
|
||||
}
|
||||
|
||||
// Exercises Session::RunWithOpts() with full tracing enabled: checks that the
// fetched tensor has the expected value and that step statistics are returned
// through RunOutputs.
TEST_F(DirectSessionMinusAXTest, RunSimpleNetworkWithOpts) {
  Initialize({3, 2, -1, 0});
  std::unique_ptr<Session> session(CreateSession());
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def_));
  std::vector<std::pair<string, Tensor>> inputs;

  // Request two targets: one fetch output and one non-fetched output.
  std::vector<string> output_names = {y_ + ":0"};
  std::vector<string> target_nodes = {y_neg_};
  std::vector<Tensor> outputs;

  // Prepares RunOptions and RunOutputs. No device stats may be present before
  // the step has run.
  RunOptions run_options;
  run_options.set_trace_level(RunOptions::FULL_TRACE);
  RunOutputs run_outputs;
  EXPECT_EQ(0, run_outputs.step_stats().dev_stats_size());

  Status s = session->RunWithOpts(run_options, inputs, output_names,
                                  target_nodes, &outputs, &run_outputs);
  TF_ASSERT_OK(s);

  ASSERT_EQ(1, outputs.size());
  // Verify the tensor is initialized BEFORE viewing it as a matrix, so an
  // uninitialized output is reported as a test failure instead of being
  // accessed first.
  ASSERT_TRUE(outputs[0].IsInitialized());
  auto mat = outputs[0].matrix<float>();
  EXPECT_FLOAT_EQ(5.0, mat(0, 0));

  // Checks RunOutputs is well-formed.
  // NOTE(review): the expected count of 2 presumably corresponds to the
  // devices used by this fixture's graph - confirm against the fixture.
  ASSERT_TRUE(run_outputs.has_step_stats());
  EXPECT_EQ(2, run_outputs.step_stats().dev_stats_size());
}
|
||||
|
||||
TEST(DirectSessionTest, KeepsStateAcrossRunsOfSession) {
|
||||
GraphDef def;
|
||||
Graph g(OpRegistry::Global());
|
||||
|
@ -6,6 +6,8 @@ option java_outer_classname = "ConfigProtos";
|
||||
option java_multiple_files = true;
|
||||
option java_package = "org.tensorflow.framework";
|
||||
|
||||
import "tensorflow/core/framework/step_stats.proto";
|
||||
|
||||
message GPUOptions {
|
||||
// A value between 0 and 1 that indicates what fraction of the
|
||||
// available GPU memory to pre-allocate for each process. 1 means
|
||||
@ -130,3 +132,20 @@ message ConfigProto {
|
||||
// Options that apply to all graphs.
|
||||
GraphOptions graph_options = 10;
|
||||
};
|
||||
|
||||
// EXPERIMENTAL. Options for a single Run() call.
// Passed per step to Session::RunWithOpts().
|
||||
message RunOptions {
|
||||
// How much tracing to perform for the step.
enum TraceLevel {
|
||||
// Value 0; presumably what an unset trace_level reads as - confirm against
// this file's proto syntax.
NO_TRACE = 0;
|
||||
// DirectSession::RunWithOpts collects StepStats for the step into
// RunOutputs.step_stats when this level is requested.
FULL_TRACE = 1;
|
||||
}
|
||||
// Trace level requested for this step.
TraceLevel trace_level = 1;
|
||||
}
|
||||
|
||||
// EXPERIMENTAL. Metadata output (i.e., non-Tensor) for a single Run() call.
// Filled in by Session::RunWithOpts() through its `run_outputs` argument.
|
||||
message RunOutputs {
|
||||
// Statistics traced for this step. Populated if tracing is turned on via the
|
||||
// "RunOptions" proto.
// (DirectSession populates it when RunOptions.trace_level is FULL_TRACE.)
|
||||
// EXPERIMENTAL: The format and set of events may change in future versions.
|
||||
StepStats step_stats = 1;
|
||||
}
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/config.pb.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
@ -114,6 +115,20 @@ class Session {
|
||||
const std::vector<string>& target_node_names,
|
||||
std::vector<Tensor>* outputs) = 0;
|
||||
|
||||
/// \brief Like `Run`, but allows users to pass in a `RunOptions` proto and
|
||||
/// to retrieve non-Tensor metadata output via a `RunOutputs` proto for this
|
||||
/// step.
|
||||
/// NOTE: This API is still experimental and may change.
///
/// The default implementation performs no work and returns an
/// `Unimplemented` error; session types that support per-step options
/// override this method.
|
||||
virtual Status RunWithOpts(
|
||||
const RunOptions& run_options,
|
||||
const std::vector<std::pair<string, Tensor> >& inputs,
|
||||
const std::vector<string>& output_tensor_names,
|
||||
const std::vector<string>& target_node_names,
|
||||
std::vector<Tensor>* outputs, RunOutputs* run_outputs) {
|
||||
return errors::Unimplemented(
|
||||
"RunWithOpts() is not supported for this session.");
|
||||
}
|
||||
|
||||
/// \brief Sets up a graph for partial execution. All future feeds and
|
||||
/// fetches are specified by 'input_names' and 'output_names'. Returns
|
||||
/// 'handle' that can be used to perform a sequence of partial feeds and
|
||||
|
@ -124,9 +124,9 @@ def all_libraries(module_to_name, members, documented):
|
||||
|
||||
_hidden_symbols = ["Event", "LogMessage", "Summary", "SessionLog", "xrange",
|
||||
"HistogramProto", "ConfigProto", "NodeDef", "GraphDef",
|
||||
"GPUOptions", "GraphOptions", "SessionInterface",
|
||||
"BaseSession", "NameAttrList", "AttrValue",
|
||||
"TensorArray", "OptimizerOptions",
|
||||
"GPUOptions", "GraphOptions", "RunOptions", "RunOutputs",
|
||||
"SessionInterface", "BaseSession", "NameAttrList",
|
||||
"AttrValue", "TensorArray", "OptimizerOptions",
|
||||
"CollectionDef", "MetaGraphDef", "QueueRunnerDef",
|
||||
"SaverDef", "VariableDef", "TestCase",
|
||||
]
|
||||
|
Loading…
Reference in New Issue
Block a user