Switch DirectSession to use _Arg and _Retval ops for feeding and fetching.
This change reduces the overhead imposed by string processing and rendezvous invocation in the DirectSession::Run() call by 1–2 microseconds per value fed or fetched.

RELNOTES: Improved DirectSession::Run() overhead and error checking. Feeding a value of the wrong type will now synchronously raise an INVALID_ARGUMENT error instead of asynchronously raising an INTERNAL error. Code that depends on the (undefined) behavior when feeding a tensor of the wrong type may need to be updated.

Change: 153797943
This commit is contained in:
parent
b0594e1b82
commit
858e0afcc4
@ -1563,6 +1563,7 @@ tf_cuda_library(
|
|||||||
":lib_internal",
|
":lib_internal",
|
||||||
":proto_text",
|
":proto_text",
|
||||||
":protos_all_cc",
|
":protos_all_cc",
|
||||||
|
"//tensorflow/core/kernels:function_ops",
|
||||||
],
|
],
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
@ -30,6 +30,11 @@ struct BuildGraphOptions {
|
|||||||
// the former via "ref" fetch_endpoints.
|
// the former via "ref" fetch_endpoints.
|
||||||
std::vector<string> target_nodes;
|
std::vector<string> target_nodes;
|
||||||
|
|
||||||
|
// If `true`, uses Arg/Retval to implement feeds/fetches; otherwise
|
||||||
|
// uses Recv/Send to implement feeds/fetches.
|
||||||
|
// TODO(mrry): Remove this when the distributed runtime supports Arg/Retval.
|
||||||
|
bool use_function_convention = false;
|
||||||
|
|
||||||
string DebugString() const;
|
string DebugString() const;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -361,7 +361,6 @@ Status DirectSession::ExtendLocked(const GraphDef& graph) {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(yuanbyu): Simplify by treating Run() as "PRunSetup(); PRun()".
|
|
||||||
Status DirectSession::Run(const NamedTensorList& inputs,
|
Status DirectSession::Run(const NamedTensorList& inputs,
|
||||||
const std::vector<string>& output_names,
|
const std::vector<string>& output_names,
|
||||||
const std::vector<string>& target_nodes,
|
const std::vector<string>& target_nodes,
|
||||||
@ -426,13 +425,34 @@ Status DirectSession::Run(const RunOptions& run_options,
|
|||||||
executor_step_count, input_tensor_names, output_names, target_nodes));
|
executor_step_count, input_tensor_names, output_names, target_nodes));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Configure a call frame for the step, which we use to feed and
|
||||||
|
// fetch values to and from the executors.
|
||||||
|
FunctionCallFrame call_frame(executors_and_keys->input_types,
|
||||||
|
executors_and_keys->output_types);
|
||||||
|
gtl::InlinedVector<Tensor, 4> feed_args(inputs.size());
|
||||||
|
for (const auto& it : inputs) {
|
||||||
|
if (it.second.dtype() == DT_RESOURCE) {
|
||||||
|
Tensor tensor_from_handle;
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
ResourceHandleToInputTensor(it.second, &tensor_from_handle));
|
||||||
|
feed_args[executors_and_keys->input_name_to_index[it.first]] =
|
||||||
|
tensor_from_handle;
|
||||||
|
} else {
|
||||||
|
feed_args[executors_and_keys->input_name_to_index[it.first]] = it.second;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Status s = call_frame.SetArgs(feed_args);
|
||||||
|
if (errors::IsInternal(s)) {
|
||||||
|
return errors::InvalidArgument(s.error_message());
|
||||||
|
} else if (!s.ok()) {
|
||||||
|
return s;
|
||||||
|
}
|
||||||
|
|
||||||
// Create a run state and start execution.
|
// Create a run state and start execution.
|
||||||
RunState run_state(args.step_id, &devices_);
|
RunState run_state(args.step_id, &devices_);
|
||||||
run_state.rendez = new IntraProcessRendezvous(device_mgr_.get());
|
run_state.rendez = new IntraProcessRendezvous(device_mgr_.get());
|
||||||
CancellationManager step_cancellation_manager;
|
CancellationManager step_cancellation_manager;
|
||||||
|
args.call_frame = &call_frame;
|
||||||
// Send inputs.
|
|
||||||
TF_RETURN_IF_ERROR(SendInputs(inputs, executors_and_keys, run_state.rendez));
|
|
||||||
|
|
||||||
// Start parallel Executors.
|
// Start parallel Executors.
|
||||||
const size_t num_executors = executors_and_keys->items.size();
|
const size_t num_executors = executors_and_keys->items.size();
|
||||||
@ -535,8 +555,22 @@ Status DirectSession::Run(const RunOptions& run_options,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Receive outputs.
|
// Receive outputs.
|
||||||
TF_RETURN_IF_ERROR(
|
if (outputs) {
|
||||||
RecvOutputs(output_names, executors_and_keys, &run_state, outputs));
|
std::vector<Tensor> sorted_outputs;
|
||||||
|
Status s = call_frame.ConsumeRetvals(&sorted_outputs);
|
||||||
|
if (errors::IsInternal(s)) {
|
||||||
|
return errors::InvalidArgument(s.error_message());
|
||||||
|
} else if (!s.ok()) {
|
||||||
|
return s;
|
||||||
|
}
|
||||||
|
outputs->clear();
|
||||||
|
outputs->reserve(sorted_outputs.size());
|
||||||
|
for (const string& output_name : output_names) {
|
||||||
|
outputs->emplace_back(
|
||||||
|
std::move(sorted_outputs[executors_and_keys
|
||||||
|
->output_name_to_index[output_name]]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Save the output tensors of this run we choose to keep.
|
// Save the output tensors of this run we choose to keep.
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
@ -706,11 +740,11 @@ Status DirectSession::PRun(const string& handle, const NamedTensorList& inputs,
|
|||||||
CheckFetch(inputs, output_names, executors_and_keys, run_state));
|
CheckFetch(inputs, output_names, executors_and_keys, run_state));
|
||||||
|
|
||||||
// Send inputs.
|
// Send inputs.
|
||||||
Status s = SendInputs(inputs, executors_and_keys, run_state->rendez);
|
Status s = SendPRunInputs(inputs, executors_and_keys, run_state->rendez);
|
||||||
|
|
||||||
// Receive outputs.
|
// Receive outputs.
|
||||||
if (s.ok()) {
|
if (s.ok()) {
|
||||||
s = RecvOutputs(output_names, executors_and_keys, run_state, outputs);
|
s = RecvPRunOutputs(output_names, executors_and_keys, run_state, outputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Save the output tensors of this run we choose to keep.
|
// Save the output tensors of this run we choose to keep.
|
||||||
@ -770,16 +804,17 @@ Status DirectSession::ResourceHandleToInputTensor(const Tensor& resource_tensor,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Status DirectSession::SendInputs(const NamedTensorList& inputs,
|
Status DirectSession::SendPRunInputs(const NamedTensorList& inputs,
|
||||||
const ExecutorsAndKeys* executors_and_keys,
|
const ExecutorsAndKeys* executors_and_keys,
|
||||||
IntraProcessRendezvous* rendez) {
|
IntraProcessRendezvous* rendez) {
|
||||||
Status s;
|
Status s;
|
||||||
Rendezvous::ParsedKey parsed;
|
Rendezvous::ParsedKey parsed;
|
||||||
// Insert the input tensors into the local rendezvous by their
|
// Insert the input tensors into the local rendezvous by their
|
||||||
// rendezvous key.
|
// rendezvous key.
|
||||||
for (const auto& input : inputs) {
|
for (const auto& input : inputs) {
|
||||||
auto it = executors_and_keys->input_keys.find(input.first);
|
auto it =
|
||||||
if (it == executors_and_keys->input_keys.end()) {
|
executors_and_keys->input_name_to_rendezvous_key.find(input.first);
|
||||||
|
if (it == executors_and_keys->input_name_to_rendezvous_key.end()) {
|
||||||
return errors::Internal("'", input.first, "' is not a pre-defined feed.");
|
return errors::Internal("'", input.first, "' is not a pre-defined feed.");
|
||||||
}
|
}
|
||||||
const string& input_key = it->second;
|
const string& input_key = it->second;
|
||||||
@ -808,10 +843,10 @@ Status DirectSession::SendInputs(const NamedTensorList& inputs,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status DirectSession::RecvOutputs(const std::vector<string>& output_names,
|
Status DirectSession::RecvPRunOutputs(
|
||||||
const ExecutorsAndKeys* executors_and_keys,
|
const std::vector<string>& output_names,
|
||||||
RunState* run_state,
|
const ExecutorsAndKeys* executors_and_keys, RunState* run_state,
|
||||||
std::vector<Tensor>* outputs) {
|
std::vector<Tensor>* outputs) {
|
||||||
Status s;
|
Status s;
|
||||||
if (!output_names.empty()) {
|
if (!output_names.empty()) {
|
||||||
outputs->resize(output_names.size());
|
outputs->resize(output_names.size());
|
||||||
@ -822,8 +857,9 @@ Status DirectSession::RecvOutputs(const std::vector<string>& output_names,
|
|||||||
for (size_t output_offset = 0; output_offset < output_names.size();
|
for (size_t output_offset = 0; output_offset < output_names.size();
|
||||||
++output_offset) {
|
++output_offset) {
|
||||||
const string& output_name = output_names[output_offset];
|
const string& output_name = output_names[output_offset];
|
||||||
auto it = executors_and_keys->output_keys.find(output_name);
|
auto it =
|
||||||
if (it == executors_and_keys->output_keys.end()) {
|
executors_and_keys->output_name_to_rendezvous_key.find(output_name);
|
||||||
|
if (it == executors_and_keys->output_name_to_rendezvous_key.end()) {
|
||||||
return errors::Internal("'", output_name,
|
return errors::Internal("'", output_name,
|
||||||
"' is not a pre-defined fetch.");
|
"' is not a pre-defined fetch.");
|
||||||
}
|
}
|
||||||
@ -987,14 +1023,16 @@ Status DirectSession::GetOrCreateExecutors(
|
|||||||
options.feed_endpoints = inputs_sorted;
|
options.feed_endpoints = inputs_sorted;
|
||||||
options.fetch_endpoints = outputs_sorted;
|
options.fetch_endpoints = outputs_sorted;
|
||||||
options.target_nodes = tn_sorted;
|
options.target_nodes = tn_sorted;
|
||||||
|
options.use_function_convention = !run_state_args->is_partial_run;
|
||||||
|
|
||||||
std::shared_ptr<ExecutorsAndKeys> ek(new ExecutorsAndKeys);
|
std::shared_ptr<ExecutorsAndKeys> ek(new ExecutorsAndKeys);
|
||||||
|
|
||||||
// The executor_lock_ is intentionally released while executor is
|
// The executor_lock_ is intentionally released while executor is
|
||||||
// being created.
|
// being created.
|
||||||
std::unordered_map<string, std::unique_ptr<Graph>> graphs;
|
std::unordered_map<string, std::unique_ptr<Graph>> graphs;
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(CreateGraphs(options, &graphs, &ek->flib_def,
|
||||||
CreateGraphs(options, &graphs, &ek->flib_def, run_state_args));
|
run_state_args, &ek->input_types,
|
||||||
|
&ek->output_types));
|
||||||
|
|
||||||
if (run_state_args->is_partial_run) {
|
if (run_state_args->is_partial_run) {
|
||||||
ek->graph = std::move(run_state_args->graph);
|
ek->graph = std::move(run_state_args->graph);
|
||||||
@ -1079,17 +1117,37 @@ Status DirectSession::GetOrCreateExecutors(
|
|||||||
item->executor.reset(executor);
|
item->executor.reset(executor);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Compute the rendezvous keys to avoid recomputing them every time.
|
// Cache the mapping from input/output names to graph elements to
|
||||||
//
|
// avoid recomputing it every time.
|
||||||
// We always use the first device as the device name portion of the
|
if (!run_state_args->is_partial_run) {
|
||||||
// key, even if we're feeding another graph.
|
// For regular `Run()`, we use the function calling convention, and so
|
||||||
for (const string& input : inputs) {
|
// maintain a mapping from input/output names to
|
||||||
ek->input_keys[input] = GetRendezvousKey(
|
// argument/return-value ordinal index.
|
||||||
input, device_set_.client_device()->attributes(), FrameAndIter(0, 0));
|
for (size_t i = 0; i < inputs_sorted.size(); ++i) {
|
||||||
}
|
const string& input = inputs_sorted[i];
|
||||||
for (const string& output : outputs) {
|
ek->input_name_to_index[input] = i;
|
||||||
ek->output_keys[output] = GetRendezvousKey(
|
}
|
||||||
output, device_set_.client_device()->attributes(), FrameAndIter(0, 0));
|
for (size_t i = 0; i < outputs_sorted.size(); ++i) {
|
||||||
|
const string& output = outputs_sorted[i];
|
||||||
|
ek->output_name_to_index[output] = i;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// For `PRun()`, we use the rendezvous calling convention, and so
|
||||||
|
// maintain a mapping from input/output names to rendezvous keys.
|
||||||
|
//
|
||||||
|
// We always use the first device as the device name portion of the
|
||||||
|
// key, even if we're feeding another graph.
|
||||||
|
for (size_t i = 0; i < inputs_sorted.size(); ++i) {
|
||||||
|
const string& input = inputs_sorted[i];
|
||||||
|
ek->input_name_to_rendezvous_key[input] = GetRendezvousKey(
|
||||||
|
input, device_set_.client_device()->attributes(), FrameAndIter(0, 0));
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < outputs_sorted.size(); ++i) {
|
||||||
|
const string& output = outputs_sorted[i];
|
||||||
|
ek->output_name_to_rendezvous_key[output] =
|
||||||
|
GetRendezvousKey(output, device_set_.client_device()->attributes(),
|
||||||
|
FrameAndIter(0, 0));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reacquire the lock, try to insert into the map.
|
// Reacquire the lock, try to insert into the map.
|
||||||
@ -1110,7 +1168,8 @@ Status DirectSession::CreateGraphs(
|
|||||||
const BuildGraphOptions& subgraph_options,
|
const BuildGraphOptions& subgraph_options,
|
||||||
std::unordered_map<string, std::unique_ptr<Graph>>* outputs,
|
std::unordered_map<string, std::unique_ptr<Graph>>* outputs,
|
||||||
std::unique_ptr<FunctionLibraryDefinition>* flib_def,
|
std::unique_ptr<FunctionLibraryDefinition>* flib_def,
|
||||||
RunStateArgs* run_state_args) {
|
RunStateArgs* run_state_args, DataTypeVector* input_types,
|
||||||
|
DataTypeVector* output_types) {
|
||||||
mutex_lock l(graph_def_lock_);
|
mutex_lock l(graph_def_lock_);
|
||||||
std::unique_ptr<SimpleClientGraph> client_graph;
|
std::unique_ptr<SimpleClientGraph> client_graph;
|
||||||
|
|
||||||
@ -1135,6 +1194,23 @@ Status DirectSession::CreateGraphs(
|
|||||||
execution_state->BuildGraph(subgraph_options, &client_graph));
|
execution_state->BuildGraph(subgraph_options, &client_graph));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (subgraph_options.feed_endpoints.size() !=
|
||||||
|
client_graph->feed_types.size()) {
|
||||||
|
return errors::Internal(
|
||||||
|
"Graph pruning failed: requested number of feed endpoints = ",
|
||||||
|
subgraph_options.feed_endpoints.size(),
|
||||||
|
" versus number of pruned feed endpoints = ",
|
||||||
|
client_graph->feed_types.size());
|
||||||
|
}
|
||||||
|
if (subgraph_options.fetch_endpoints.size() !=
|
||||||
|
client_graph->fetch_types.size()) {
|
||||||
|
return errors::Internal(
|
||||||
|
"Graph pruning failed: requested number of fetch endpoints = ",
|
||||||
|
subgraph_options.fetch_endpoints.size(),
|
||||||
|
" versus number of pruned fetch endpoints = ",
|
||||||
|
client_graph->fetch_types.size());
|
||||||
|
}
|
||||||
|
|
||||||
auto current_stateful_placements = execution_state->GetStatefulPlacements();
|
auto current_stateful_placements = execution_state->GetStatefulPlacements();
|
||||||
// Update our current state based on the execution_state's
|
// Update our current state based on the execution_state's
|
||||||
// placements. If there are any mismatches for a node,
|
// placements. If there are any mismatches for a node,
|
||||||
@ -1240,6 +1316,8 @@ Status DirectSession::CreateGraphs(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
*flib_def = std::move(client_graph->flib_def);
|
*flib_def = std::move(client_graph->flib_def);
|
||||||
|
std::swap(*input_types, client_graph->feed_types);
|
||||||
|
std::swap(*output_types, client_graph->fetch_types);
|
||||||
return s;
|
return s;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -132,8 +132,13 @@ class DirectSession : public Session {
|
|||||||
NameNodeMap name_to_node;
|
NameNodeMap name_to_node;
|
||||||
std::unique_ptr<FunctionLibraryDefinition> flib_def;
|
std::unique_ptr<FunctionLibraryDefinition> flib_def;
|
||||||
std::vector<PerPartitionExecutorsAndLib> items;
|
std::vector<PerPartitionExecutorsAndLib> items;
|
||||||
std::unordered_map<string, string> input_keys;
|
std::unordered_map<string, size_t> input_name_to_index;
|
||||||
std::unordered_map<string, string> output_keys;
|
std::unordered_map<string, string> input_name_to_rendezvous_key;
|
||||||
|
std::unordered_map<string, size_t> output_name_to_index;
|
||||||
|
std::unordered_map<string, string> output_name_to_rendezvous_key;
|
||||||
|
|
||||||
|
DataTypeVector input_types;
|
||||||
|
DataTypeVector output_types;
|
||||||
};
|
};
|
||||||
|
|
||||||
// For each live partial execution, the session maintains a RunState.
|
// For each live partial execution, the session maintains a RunState.
|
||||||
@ -187,7 +192,8 @@ class DirectSession : public Session {
|
|||||||
const BuildGraphOptions& options,
|
const BuildGraphOptions& options,
|
||||||
std::unordered_map<string, std::unique_ptr<Graph>>* outputs,
|
std::unordered_map<string, std::unique_ptr<Graph>>* outputs,
|
||||||
std::unique_ptr<FunctionLibraryDefinition>* flib_def,
|
std::unique_ptr<FunctionLibraryDefinition>* flib_def,
|
||||||
RunStateArgs* run_state_args);
|
RunStateArgs* run_state_args, DataTypeVector* input_types,
|
||||||
|
DataTypeVector* output_types);
|
||||||
|
|
||||||
::tensorflow::Status ExtendLocked(const GraphDef& graph)
|
::tensorflow::Status ExtendLocked(const GraphDef& graph)
|
||||||
EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_);
|
EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_);
|
||||||
@ -196,17 +202,17 @@ class DirectSession : public Session {
|
|||||||
const Tensor& resource_tensor, Tensor* retrieved_tensor);
|
const Tensor& resource_tensor, Tensor* retrieved_tensor);
|
||||||
|
|
||||||
// Feeds more inputs to the executors, triggering further execution.
|
// Feeds more inputs to the executors, triggering further execution.
|
||||||
::tensorflow::Status SendInputs(
|
::tensorflow::Status SendPRunInputs(
|
||||||
const std::vector<std::pair<string, Tensor>>& inputs,
|
const std::vector<std::pair<string, Tensor>>& inputs,
|
||||||
const ExecutorsAndKeys* executors_and_keys,
|
const ExecutorsAndKeys* executors_and_keys,
|
||||||
IntraProcessRendezvous* rendez);
|
IntraProcessRendezvous* rendez);
|
||||||
|
|
||||||
// Fetches more outputs from the executors. It waits until the output
|
// Fetches more outputs from the executors. It waits until the output
|
||||||
// tensors are computed.
|
// tensors are computed.
|
||||||
::tensorflow::Status RecvOutputs(const std::vector<string>& output_names,
|
::tensorflow::Status RecvPRunOutputs(
|
||||||
const ExecutorsAndKeys* executors_and_keys,
|
const std::vector<string>& output_names,
|
||||||
RunState* run_state,
|
const ExecutorsAndKeys* executors_and_keys, RunState* run_state,
|
||||||
std::vector<Tensor>* outputs);
|
std::vector<Tensor>* outputs);
|
||||||
|
|
||||||
// Check if the specified fetches can be computed from the feeds
|
// Check if the specified fetches can be computed from the feeds
|
||||||
// that we have already provided.
|
// that we have already provided.
|
||||||
|
@ -130,9 +130,11 @@ Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Call RewriteGraphForExecution
|
// Call RewriteGraphForExecution
|
||||||
|
subgraph::RewriteGraphMetadata metadata;
|
||||||
TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
|
TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
|
||||||
graph_to_run.get(), input_names, output_names, {} /* target nodes */,
|
graph_to_run.get(), input_names, output_names, {} /* target nodes */,
|
||||||
cpu_device_->attributes()));
|
cpu_device_->attributes(), false /* use_function_convention */,
|
||||||
|
&metadata));
|
||||||
|
|
||||||
// Create the local executor and the Rendezvous for fetching back the
|
// Create the local executor and the Rendezvous for fetching back the
|
||||||
// constants.
|
// constants.
|
||||||
|
@ -21,9 +21,9 @@ limitations under the License.
|
|||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
// Replaces ReadVariableOp nodes which are only used by Sends and sinks with
|
// Replaces ReadVariableOp nodes which are only used by Sends, sinks,
|
||||||
// _UnsafeReadVariable nodes, as this transforamtion is safe and will improve
|
// and function Retvals with _UnsafeReadVariable nodes, as this
|
||||||
// performance.
|
// transformation is safe and will improve performance.
|
||||||
class ResourceVariableReadPass : public GraphOptimizationPass {
|
class ResourceVariableReadPass : public GraphOptimizationPass {
|
||||||
public:
|
public:
|
||||||
Status Run(const GraphOptimizationPassOptions& options) override {
|
Status Run(const GraphOptimizationPassOptions& options) override {
|
||||||
@ -43,7 +43,8 @@ class ResourceVariableReadPass : public GraphOptimizationPass {
|
|||||||
if (n->type_string() == "ReadVariableOp") {
|
if (n->type_string() == "ReadVariableOp") {
|
||||||
bool skip = false;
|
bool skip = false;
|
||||||
for (const Edge* e : n->out_edges()) {
|
for (const Edge* e : n->out_edges()) {
|
||||||
if (!e->dst()->IsSend() && e->dst()->name() != "_SINK") {
|
if (!e->dst()->IsSend() && e->dst()->type_string() != "_Retval" &&
|
||||||
|
e->dst()->name() != "_SINK") {
|
||||||
skip = true;
|
skip = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -284,9 +284,11 @@ Status SimpleGraphExecutionState::InitBaseGraph(
|
|||||||
if (session_options_ &&
|
if (session_options_ &&
|
||||||
session_options_->config.graph_options().place_pruned_graph()) {
|
session_options_->config.graph_options().place_pruned_graph()) {
|
||||||
// Rewrite the graph before placement.
|
// Rewrite the graph before placement.
|
||||||
|
rewrite_metadata_.reset(new subgraph::RewriteGraphMetadata);
|
||||||
TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
|
TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
|
||||||
new_graph.get(), options.feed_endpoints, options.fetch_endpoints,
|
new_graph.get(), options.feed_endpoints, options.fetch_endpoints,
|
||||||
options.target_nodes, device_set_->client_device()->attributes()));
|
options.target_nodes, device_set_->client_device()->attributes(),
|
||||||
|
options.use_function_convention, rewrite_metadata_.get()));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Save stateful placements before placing.
|
// Save stateful placements before placing.
|
||||||
@ -333,15 +335,26 @@ Status SimpleGraphExecutionState::BuildGraph(
|
|||||||
std::unique_ptr<Graph> ng(new Graph(flib_def_.get()));
|
std::unique_ptr<Graph> ng(new Graph(flib_def_.get()));
|
||||||
CopyGraph(*graph_, ng.get());
|
CopyGraph(*graph_, ng.get());
|
||||||
|
|
||||||
|
subgraph::RewriteGraphMetadata rewrite_metadata;
|
||||||
if (session_options_ == nullptr ||
|
if (session_options_ == nullptr ||
|
||||||
!session_options_->config.graph_options().place_pruned_graph()) {
|
!session_options_->config.graph_options().place_pruned_graph()) {
|
||||||
// Extract the subset of the graph that needs to be run, adding feed/fetch
|
// Extract the subset of the graph that needs to be run, adding feed/fetch
|
||||||
// ops as needed.
|
// ops as needed.
|
||||||
TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
|
TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
|
||||||
ng.get(), options.feed_endpoints, options.fetch_endpoints,
|
ng.get(), options.feed_endpoints, options.fetch_endpoints,
|
||||||
options.target_nodes, device_set_->client_device()->attributes()));
|
options.target_nodes, device_set_->client_device()->attributes(),
|
||||||
|
options.use_function_convention, &rewrite_metadata));
|
||||||
|
} else {
|
||||||
|
// This SimpleGraphExecutionState represents a graph that was
|
||||||
|
// pruned when this was constructed, so we copy the metadata from
|
||||||
|
// a member variable.
|
||||||
|
CHECK(rewrite_metadata_);
|
||||||
|
rewrite_metadata = *rewrite_metadata_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
CHECK_EQ(options.feed_endpoints.size(), rewrite_metadata.feed_types.size());
|
||||||
|
CHECK_EQ(options.fetch_endpoints.size(), rewrite_metadata.fetch_types.size());
|
||||||
|
|
||||||
// Make a fresh copy of the function library for the client graph.
|
// Make a fresh copy of the function library for the client graph.
|
||||||
std::unique_ptr<FunctionLibraryDefinition> flib(
|
std::unique_ptr<FunctionLibraryDefinition> flib(
|
||||||
new FunctionLibraryDefinition(*flib_def_));
|
new FunctionLibraryDefinition(*flib_def_));
|
||||||
@ -363,7 +376,8 @@ Status SimpleGraphExecutionState::BuildGraph(
|
|||||||
// since the local CostModel used to record its stats is sized by
|
// since the local CostModel used to record its stats is sized by
|
||||||
// the largest node id.
|
// the largest node id.
|
||||||
std::unique_ptr<SimpleClientGraph> dense_copy(
|
std::unique_ptr<SimpleClientGraph> dense_copy(
|
||||||
new SimpleClientGraph(std::move(flib)));
|
new SimpleClientGraph(std::move(flib), rewrite_metadata.feed_types,
|
||||||
|
rewrite_metadata.fetch_types));
|
||||||
CopyGraph(*ng, &dense_copy->graph);
|
CopyGraph(*ng, &dense_copy->graph);
|
||||||
|
|
||||||
// TODO(vrv): We should check invariants of the graph here.
|
// TODO(vrv): We should check invariants of the graph here.
|
||||||
|
@ -39,6 +39,10 @@ struct SessionOptions;
|
|||||||
class StepStats;
|
class StepStats;
|
||||||
class Timeline;
|
class Timeline;
|
||||||
|
|
||||||
|
namespace subgraph {
|
||||||
|
struct RewriteGraphMetadata;
|
||||||
|
}
|
||||||
|
|
||||||
struct SimpleGraphExecutionStateOptions {
|
struct SimpleGraphExecutionStateOptions {
|
||||||
const DeviceSet* device_set = nullptr;
|
const DeviceSet* device_set = nullptr;
|
||||||
const SessionOptions* session_options = nullptr;
|
const SessionOptions* session_options = nullptr;
|
||||||
@ -50,13 +54,19 @@ struct SimpleGraphExecutionStateOptions {
|
|||||||
// A SimpleClientGraph is simply a sub-graph of the full graph as induced by
|
// A SimpleClientGraph is simply a sub-graph of the full graph as induced by
|
||||||
// BuildGraphOptions.
|
// BuildGraphOptions.
|
||||||
struct SimpleClientGraph {
|
struct SimpleClientGraph {
|
||||||
explicit SimpleClientGraph(std::unique_ptr<FunctionLibraryDefinition> flib)
|
explicit SimpleClientGraph(std::unique_ptr<FunctionLibraryDefinition> flib,
|
||||||
: flib_def(std::move(flib)), graph(flib_def.get()) {}
|
DataTypeVector feed_types,
|
||||||
|
DataTypeVector fetch_types)
|
||||||
|
: flib_def(std::move(flib)),
|
||||||
|
graph(flib_def.get()),
|
||||||
|
feed_types(std::move(feed_types)),
|
||||||
|
fetch_types(std::move(fetch_types)) {}
|
||||||
// Each client-graph gets its own function library since optimization passes
|
// Each client-graph gets its own function library since optimization passes
|
||||||
// post rewrite for execution might want to introduce new functions.
|
// post rewrite for execution might want to introduce new functions.
|
||||||
std::unique_ptr<FunctionLibraryDefinition> flib_def;
|
std::unique_ptr<FunctionLibraryDefinition> flib_def;
|
||||||
Graph graph;
|
Graph graph;
|
||||||
int32 placement_version;
|
DataTypeVector feed_types;
|
||||||
|
DataTypeVector fetch_types;
|
||||||
};
|
};
|
||||||
|
|
||||||
// SimpleGraphExecutionState is responsible for generating an
|
// SimpleGraphExecutionState is responsible for generating an
|
||||||
@ -190,6 +200,10 @@ class SimpleGraphExecutionState {
|
|||||||
// and may be updated by a graph optimization pass.
|
// and may be updated by a graph optimization pass.
|
||||||
std::unique_ptr<FunctionLibraryDefinition> flib_def_;
|
std::unique_ptr<FunctionLibraryDefinition> flib_def_;
|
||||||
|
|
||||||
|
// `rewrite_metadata_` is only set for SimpleGraphExecutionState
|
||||||
|
// objects created by `MakeForPrunedGraph()`.
|
||||||
|
std::unique_ptr<subgraph::RewriteGraphMetadata> rewrite_metadata_;
|
||||||
|
|
||||||
// The dataflow graph owned by this object.
|
// The dataflow graph owned by this object.
|
||||||
Graph* graph_;
|
Graph* graph_;
|
||||||
|
|
||||||
|
@ -789,7 +789,7 @@ Status FunctionCallFrame::GetRetvals(std::vector<Tensor>* rets) const {
|
|||||||
rets->clear();
|
rets->clear();
|
||||||
rets->reserve(rets_.size());
|
rets->reserve(rets_.size());
|
||||||
for (size_t i = 0; i < rets_.size(); ++i) {
|
for (size_t i = 0; i < rets_.size(); ++i) {
|
||||||
auto item = rets_[i];
|
const auto& item = rets_[i];
|
||||||
if (item.has_val) {
|
if (item.has_val) {
|
||||||
rets->push_back(item.val);
|
rets->push_back(item.val);
|
||||||
} else {
|
} else {
|
||||||
@ -799,6 +799,19 @@ Status FunctionCallFrame::GetRetvals(std::vector<Tensor>* rets) const {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status FunctionCallFrame::ConsumeRetvals(std::vector<Tensor>* rets) {
|
||||||
|
rets->clear();
|
||||||
|
rets->reserve(rets_.size());
|
||||||
|
for (size_t i = 0; i < rets_.size(); ++i) {
|
||||||
|
if (rets_[i].has_val) {
|
||||||
|
rets->emplace_back(std::move(rets_[i].val));
|
||||||
|
} else {
|
||||||
|
return errors::Internal("Retval[", i, "] does not have value");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
Status FunctionCallFrame::GetArg(int index, Tensor* val) const {
|
Status FunctionCallFrame::GetArg(int index, Tensor* val) const {
|
||||||
if (index < 0 || static_cast<size_t>(index) >= args_.size()) {
|
if (index < 0 || static_cast<size_t>(index) >= args_.size()) {
|
||||||
return errors::InvalidArgument("GetArg ", index, " is not within [0, ",
|
return errors::InvalidArgument("GetArg ", index, " is not within [0, ",
|
||||||
|
@ -259,6 +259,7 @@ class FunctionCallFrame {
|
|||||||
// Caller methods.
|
// Caller methods.
|
||||||
Status SetArgs(gtl::ArraySlice<Tensor> args);
|
Status SetArgs(gtl::ArraySlice<Tensor> args);
|
||||||
Status GetRetvals(std::vector<Tensor>* rets) const;
|
Status GetRetvals(std::vector<Tensor>* rets) const;
|
||||||
|
Status ConsumeRetvals(std::vector<Tensor>* rets);
|
||||||
|
|
||||||
// Callee methods.
|
// Callee methods.
|
||||||
Status GetArg(int index, Tensor* val) const;
|
Status GetArg(int index, Tensor* val) const;
|
||||||
|
@ -55,8 +55,13 @@ namespace {
|
|||||||
// state).
|
// state).
|
||||||
static Status FeedInputs(Graph* g, const DeviceAttributes& device_info,
|
static Status FeedInputs(Graph* g, const DeviceAttributes& device_info,
|
||||||
const gtl::ArraySlice<string>& fed_outputs,
|
const gtl::ArraySlice<string>& fed_outputs,
|
||||||
subgraph::NameIndex* name_index) {
|
bool use_function_convention,
|
||||||
for (const string& t : fed_outputs) {
|
subgraph::NameIndex* name_index,
|
||||||
|
DataTypeVector* out_feed_types) {
|
||||||
|
out_feed_types->clear();
|
||||||
|
out_feed_types->reserve(fed_outputs.size());
|
||||||
|
for (size_t i = 0; i < fed_outputs.size(); ++i) {
|
||||||
|
const string& t = fed_outputs[i];
|
||||||
TensorId id(ParseTensorName(t));
|
TensorId id(ParseTensorName(t));
|
||||||
|
|
||||||
auto iter = name_index->find(id.first);
|
auto iter = name_index->find(id.first);
|
||||||
@ -71,17 +76,31 @@ static Status FeedInputs(Graph* g, const DeviceAttributes& device_info,
|
|||||||
}
|
}
|
||||||
|
|
||||||
Node* recv_node;
|
Node* recv_node;
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
NodeBuilder(strings::StrCat("_recv_", id.first, "_", id.second),
|
if (!use_function_convention) {
|
||||||
"_Recv")
|
TF_RETURN_IF_ERROR(
|
||||||
.Attr("tensor_type", BaseType(n->output_type(id.second)))
|
NodeBuilder(strings::StrCat("_recv_", id.first, "_", id.second),
|
||||||
.Attr("tensor_name", t)
|
"_Recv")
|
||||||
.Attr("send_device", device_info.name())
|
.Attr("tensor_type", BaseType(n->output_type(id.second)))
|
||||||
.Attr("recv_device", device_info.name())
|
.Attr("tensor_name", t)
|
||||||
.Attr("send_device_incarnation",
|
.Attr("send_device", device_info.name())
|
||||||
static_cast<int64>(device_info.incarnation()))
|
.Attr("recv_device", device_info.name())
|
||||||
.Attr("client_terminated", true)
|
.Attr("send_device_incarnation",
|
||||||
.Finalize(g, &recv_node));
|
static_cast<int64>(device_info.incarnation()))
|
||||||
|
.Attr("client_terminated", true)
|
||||||
|
.Finalize(g, &recv_node));
|
||||||
|
} else {
|
||||||
|
// NOTE(mrry): We must include the index as part of the node
|
||||||
|
// name, because _Arg is a "stateful" kernel and therefore
|
||||||
|
// its name must uniquely identify a kernel instance across all
|
||||||
|
// graphs in the same session.
|
||||||
|
TF_RETURN_IF_ERROR(NodeBuilder(strings::StrCat("_arg_", id.first, "_",
|
||||||
|
id.second, "_", i),
|
||||||
|
"_Arg")
|
||||||
|
.Attr("T", BaseType(n->output_type(id.second)))
|
||||||
|
.Attr("index", static_cast<int32>(i))
|
||||||
|
.Finalize(g, &recv_node));
|
||||||
|
}
|
||||||
recv_node->set_assigned_device_name(device_info.name());
|
recv_node->set_assigned_device_name(device_info.name());
|
||||||
|
|
||||||
// Copy the _output_shapes from the original node to the feed node,
|
// Copy the _output_shapes from the original node to the feed node,
|
||||||
@ -130,6 +149,7 @@ static Status FeedInputs(Graph* g, const DeviceAttributes& device_info,
|
|||||||
}
|
}
|
||||||
g->RemoveEdge(e);
|
g->RemoveEdge(e);
|
||||||
}
|
}
|
||||||
|
out_feed_types->push_back(BaseType(n->output_type(id.second)));
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
@ -181,9 +201,14 @@ namespace subgraph {
|
|||||||
|
|
||||||
Status FetchOutputs(Graph* g, const DeviceAttributes& device_info,
|
Status FetchOutputs(Graph* g, const DeviceAttributes& device_info,
|
||||||
const gtl::ArraySlice<string>& fetch_outputs,
|
const gtl::ArraySlice<string>& fetch_outputs,
|
||||||
NameIndex* name_index, std::vector<Node*>* fetch_nodes) {
|
bool use_function_convention, NameIndex* name_index,
|
||||||
fetch_nodes->clear();
|
std::vector<Node*>* out_fetch_nodes,
|
||||||
for (const string& t : fetch_outputs) {
|
DataTypeVector* out_fetch_types) {
|
||||||
|
out_fetch_nodes->clear();
|
||||||
|
out_fetch_nodes->reserve(fetch_outputs.size());
|
||||||
|
for (size_t i = 0; i < fetch_outputs.size(); ++i) {
|
||||||
|
const string& t = fetch_outputs[i];
|
||||||
|
|
||||||
// Parse t into node_name and output_index.
|
// Parse t into node_name and output_index.
|
||||||
TensorId id(ParseTensorName(t));
|
TensorId id(ParseTensorName(t));
|
||||||
|
|
||||||
@ -213,25 +238,39 @@ Status FetchOutputs(Graph* g, const DeviceAttributes& device_info,
|
|||||||
|
|
||||||
// Create the fetch Node and connect it up
|
// Create the fetch Node and connect it up
|
||||||
Node* send_node;
|
Node* send_node;
|
||||||
TF_RETURN_IF_ERROR(
|
if (!use_function_convention) {
|
||||||
NodeBuilder(strings::StrCat("_send_", id.first, "_", id.second),
|
TF_RETURN_IF_ERROR(
|
||||||
"_Send")
|
NodeBuilder(strings::StrCat("_send_", id.first, "_", id.second),
|
||||||
.Input(n, id.second)
|
"_Send")
|
||||||
.Attr("tensor_name", t)
|
.Input(n, id.second)
|
||||||
.Attr("send_device", device_info.name())
|
.Attr("tensor_name", t)
|
||||||
.Attr("recv_device", device_info.name())
|
.Attr("send_device", device_info.name())
|
||||||
.Attr("send_device_incarnation",
|
.Attr("recv_device", device_info.name())
|
||||||
static_cast<int64>(device_info.incarnation()))
|
.Attr("send_device_incarnation",
|
||||||
.Attr("client_terminated", true)
|
static_cast<int64>(device_info.incarnation()))
|
||||||
.Finalize(g, &send_node));
|
.Attr("client_terminated", true)
|
||||||
|
.Finalize(g, &send_node));
|
||||||
|
} else {
|
||||||
|
// NOTE(mrry): We must include the index as part of the node
|
||||||
|
// name, because _Retval is a "stateful" kernel and therefore
|
||||||
|
// its name must uniquely identify a kernel instance across all
|
||||||
|
// graphs in the same session.
|
||||||
|
TF_RETURN_IF_ERROR(NodeBuilder(strings::StrCat("_retval_", id.first, "_",
|
||||||
|
id.second, "_", i),
|
||||||
|
"_Retval")
|
||||||
|
.Input(n, id.second)
|
||||||
|
.Attr("T", BaseType(n->output_type(id.second)))
|
||||||
|
.Attr("index", static_cast<int32>(i))
|
||||||
|
.Finalize(g, &send_node));
|
||||||
|
}
|
||||||
send_node->set_assigned_device_name(device_info.name());
|
send_node->set_assigned_device_name(device_info.name());
|
||||||
VLOG(1) << "Created fetch node: " << SummarizeNodeDef(send_node->def());
|
|
||||||
|
|
||||||
// Update the index.
|
// Update the index.
|
||||||
(*name_index)[send_node->name()] = send_node;
|
(*name_index)[send_node->name()] = send_node;
|
||||||
|
|
||||||
g->AddControlEdge(send_node, g->sink_node());
|
g->AddControlEdge(send_node, g->sink_node());
|
||||||
fetch_nodes->push_back(send_node);
|
out_fetch_nodes->push_back(send_node);
|
||||||
|
out_fetch_types->push_back(BaseType(n->output_type(id.second)));
|
||||||
}
|
}
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
@ -241,7 +280,8 @@ Status RewriteGraphForExecution(
|
|||||||
Graph* g, const gtl::ArraySlice<string>& fed_outputs,
|
Graph* g, const gtl::ArraySlice<string>& fed_outputs,
|
||||||
const gtl::ArraySlice<string>& fetch_outputs,
|
const gtl::ArraySlice<string>& fetch_outputs,
|
||||||
const gtl::ArraySlice<string>& target_node_names,
|
const gtl::ArraySlice<string>& target_node_names,
|
||||||
const DeviceAttributes& device_info) {
|
const DeviceAttributes& device_info, bool use_function_convention,
|
||||||
|
RewriteGraphMetadata* out_metadata) {
|
||||||
if (fetch_outputs.empty() && target_node_names.empty()) {
|
if (fetch_outputs.empty() && target_node_names.empty()) {
|
||||||
return errors::InvalidArgument(
|
return errors::InvalidArgument(
|
||||||
"Must specify at least one target to fetch or execute.");
|
"Must specify at least one target to fetch or execute.");
|
||||||
@ -274,18 +314,21 @@ Status RewriteGraphForExecution(
|
|||||||
// currently listed in "fetch_nodes". We pass "name_index" so the index is
|
// currently listed in "fetch_nodes". We pass "name_index" so the index is
|
||||||
// kept up to date.
|
// kept up to date.
|
||||||
if (!fed_outputs.empty()) {
|
if (!fed_outputs.empty()) {
|
||||||
TF_RETURN_IF_ERROR(FeedInputs(g, device_info, fed_outputs, &name_index));
|
TF_RETURN_IF_ERROR(FeedInputs(g, device_info, fed_outputs,
|
||||||
|
use_function_convention, &name_index,
|
||||||
|
&out_metadata->feed_types));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add the fetch nodes, also updating "name_index".
|
// Add the fetch nodes, also updating "name_index".
|
||||||
std::vector<Node*> fetch_nodes;
|
std::vector<Node*> fetch_nodes;
|
||||||
if (!fetch_outputs.empty()) {
|
if (!fetch_outputs.empty()) {
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(FetchOutputs(g, device_info, fetch_outputs,
|
||||||
FetchOutputs(g, device_info, fetch_outputs, &name_index, &fetch_nodes));
|
use_function_convention, &name_index,
|
||||||
|
&fetch_nodes, &out_metadata->fetch_types));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prune the graph to only compute what is needed for the fetch nodes and the
|
// Prune the graph to only compute what is needed for the fetch nodes and the
|
||||||
// targets nodes.
|
// target nodes.
|
||||||
if (!fetch_nodes.empty() || !target_node_names.empty()) {
|
if (!fetch_nodes.empty() || !target_node_names.empty()) {
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
PruneForTargets(g, name_index, fetch_nodes, target_node_names));
|
PruneForTargets(g, name_index, fetch_nodes, target_node_names));
|
||||||
|
@ -26,6 +26,18 @@ limitations under the License.
|
|||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace subgraph {
|
namespace subgraph {
|
||||||
|
|
||||||
|
// Information about a graph rewritten by `RewriteGraphForExecution()`.
|
||||||
|
struct RewriteGraphMetadata {
|
||||||
|
// The element type of each tensor fed to this subgraph. The order
|
||||||
|
// of types corresponds to the order of tensor names in
|
||||||
|
// `fed_outputs` when calling `RewriteGraphForExecution()`.
|
||||||
|
DataTypeVector feed_types;
|
||||||
|
// The element type of each tensor fetched from this subgraph. The
|
||||||
|
// order of types corresponds to the order of tensor names in
|
||||||
|
// `fetch_outputs` when calling `RewriteGraphForExecution()`.
|
||||||
|
DataTypeVector fetch_types;
|
||||||
|
};
|
||||||
|
|
||||||
// Rewrite the graph structure of "*g" to deal with feeding node
|
// Rewrite the graph structure of "*g" to deal with feeding node
|
||||||
// outputs, fetching node outputs, and only running a subset of the
|
// outputs, fetching node outputs, and only running a subset of the
|
||||||
// graph. "fed_outputs" and "fetch_outputs" are both lists of
|
// graph. "fed_outputs" and "fetch_outputs" are both lists of
|
||||||
@ -56,7 +68,8 @@ Status RewriteGraphForExecution(
|
|||||||
Graph* g, const gtl::ArraySlice<string>& fed_outputs,
|
Graph* g, const gtl::ArraySlice<string>& fed_outputs,
|
||||||
const gtl::ArraySlice<string>& fetch_outputs,
|
const gtl::ArraySlice<string>& fetch_outputs,
|
||||||
const gtl::ArraySlice<string>& target_node_names,
|
const gtl::ArraySlice<string>& target_node_names,
|
||||||
const DeviceAttributes& device_info);
|
const DeviceAttributes& device_info, bool use_function_convention,
|
||||||
|
RewriteGraphMetadata* out_metadata);
|
||||||
|
|
||||||
typedef std::unordered_map<StringPiece, Node*, StringPiece::Hasher> NameIndex;
|
typedef std::unordered_map<StringPiece, Node*, StringPiece::Hasher> NameIndex;
|
||||||
|
|
||||||
|
@ -104,7 +104,8 @@ class SubgraphTest : public ::testing::Test {
|
|||||||
}
|
}
|
||||||
|
|
||||||
string Subgraph(const string& fed_str, const string& fetch_str,
|
string Subgraph(const string& fed_str, const string& fetch_str,
|
||||||
const string& targets_str) {
|
const string& targets_str,
|
||||||
|
bool use_function_convention = false) {
|
||||||
Graph* subgraph = new Graph(OpRegistry::Global());
|
Graph* subgraph = new Graph(OpRegistry::Global());
|
||||||
CopyGraph(*g_, subgraph);
|
CopyGraph(*g_, subgraph);
|
||||||
std::vector<string> fed =
|
std::vector<string> fed =
|
||||||
@ -114,13 +115,18 @@ class SubgraphTest : public ::testing::Test {
|
|||||||
std::vector<string> targets =
|
std::vector<string> targets =
|
||||||
str_util::Split(targets_str, ',', str_util::SkipEmpty());
|
str_util::Split(targets_str, ',', str_util::SkipEmpty());
|
||||||
|
|
||||||
Status s = subgraph::RewriteGraphForExecution(subgraph, fed, fetch, targets,
|
subgraph::RewriteGraphMetadata metadata;
|
||||||
device_info_);
|
Status s = subgraph::RewriteGraphForExecution(
|
||||||
|
subgraph, fed, fetch, targets, device_info_, use_function_convention,
|
||||||
|
&metadata);
|
||||||
if (!s.ok()) {
|
if (!s.ok()) {
|
||||||
delete subgraph;
|
delete subgraph;
|
||||||
return s.ToString();
|
return s.ToString();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
EXPECT_EQ(fed.size(), metadata.feed_types.size());
|
||||||
|
EXPECT_EQ(fetch.size(), metadata.fetch_types.size());
|
||||||
|
|
||||||
// Replace the graph with the subgraph for the rest of the display program
|
// Replace the graph with the subgraph for the rest of the display program
|
||||||
g_.reset(subgraph);
|
g_.reset(subgraph);
|
||||||
return "OK";
|
return "OK";
|
||||||
@ -178,6 +184,20 @@ TEST_F(SubgraphTest, FedOutputs1) {
|
|||||||
ExpectNodes("W1,W2,_recv_input_1,t1,t2");
|
ExpectNodes("W1,W2,_recv_input_1,t1,t2");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(SubgraphTest, FedOutputs1_FunctionConvention) {
|
||||||
|
ExpectOK(
|
||||||
|
"node { name: 'W1' op: 'TestParams' }"
|
||||||
|
"node { name: 'W2' op: 'TestParams' }"
|
||||||
|
"node { name: 'input' op: 'TestInput' }"
|
||||||
|
"node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }"
|
||||||
|
"node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }"
|
||||||
|
"node { name: 't3_a' op: 'TestRelu' input: 't2' }"
|
||||||
|
"node { name: 't3_b' op: 'TestRelu' input: 't2' }");
|
||||||
|
EXPECT_EQ("OK",
|
||||||
|
Subgraph("input:1", "", "t2", true /* use_function_convention */));
|
||||||
|
ExpectNodes("W1,W2,_arg_input_1_0,t1,t2");
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(SubgraphTest, FedRefNode) {
|
TEST_F(SubgraphTest, FedRefNode) {
|
||||||
ExpectOK(
|
ExpectOK(
|
||||||
"node { name: 'W1' op: 'TestParams' }"
|
"node { name: 'W1' op: 'TestParams' }"
|
||||||
@ -189,7 +209,19 @@ TEST_F(SubgraphTest, FedRefNode) {
|
|||||||
EXPECT_FALSE(IsRefType(CHECK_NOTNULL(n)->output_type(0)));
|
EXPECT_FALSE(IsRefType(CHECK_NOTNULL(n)->output_type(0)));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(SubgraphTest, FedOutputs2) {
|
TEST_F(SubgraphTest, FedRefNode_FunctionConvention) {
|
||||||
|
ExpectOK(
|
||||||
|
"node { name: 'W1' op: 'TestParams' }"
|
||||||
|
"node { name: 'W2' op: 'TestParams' }"
|
||||||
|
"node { name: 't1' op: 'TestMul' input: [ 'W2', 'W1' ] }");
|
||||||
|
EXPECT_EQ("OK",
|
||||||
|
Subgraph("W1:0", "", "t1", true /* use_function_convention */));
|
||||||
|
ExpectNodes("_arg_W1_0_0,W2,t1");
|
||||||
|
Node* n = FindNode("_arg_W1_0_0");
|
||||||
|
EXPECT_FALSE(IsRefType(CHECK_NOTNULL(n)->output_type(0)));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(SubgraphTest, FedOutputs2_FunctionConvention) {
|
||||||
ExpectOK(
|
ExpectOK(
|
||||||
"node { name: 'W1' op: 'TestParams' }"
|
"node { name: 'W1' op: 'TestParams' }"
|
||||||
"node { name: 'W2' op: 'TestParams' }"
|
"node { name: 'W2' op: 'TestParams' }"
|
||||||
@ -200,8 +232,9 @@ TEST_F(SubgraphTest, FedOutputs2) {
|
|||||||
"node { name: 't3_b' op: 'TestRelu' input: 't2' }");
|
"node { name: 't3_b' op: 'TestRelu' input: 't2' }");
|
||||||
// We feed input:1, but nothing connects to it, so the _recv(input:1)
|
// We feed input:1, but nothing connects to it, so the _recv(input:1)
|
||||||
// node also disappears.
|
// node also disappears.
|
||||||
EXPECT_EQ("OK", Subgraph("input:1,t1,W2", "", "t2"));
|
EXPECT_EQ("OK", Subgraph("input:1,t1,W2", "", "t2",
|
||||||
ExpectNodes("_recv_t1_0,_recv_W2_0,t2");
|
true /* use_function_convention */));
|
||||||
|
ExpectNodes("_arg_t1_0_1,_arg_W2_0_2,t2");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(SubgraphTest, FetchOutputs1) {
|
TEST_F(SubgraphTest, FetchOutputs1) {
|
||||||
@ -218,6 +251,22 @@ TEST_F(SubgraphTest, FetchOutputs1) {
|
|||||||
"W1,W2,input,t1,t2,_send_W2_0,_send_input_1,_send_t1_0,_send_t2_0");
|
"W1,W2,input,t1,t2,_send_W2_0,_send_input_1,_send_t1_0,_send_t2_0");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(SubgraphTest, FetchOutputs1_FunctionConvention) {
|
||||||
|
ExpectOK(
|
||||||
|
"node { name: 'W1' op: 'TestParams' }"
|
||||||
|
"node { name: 'W2' op: 'TestParams' }"
|
||||||
|
"node { name: 'input' op: 'TestInput' }"
|
||||||
|
"node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }"
|
||||||
|
"node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }"
|
||||||
|
"node { name: 't3_a' op: 'TestRelu' input: 't2' }"
|
||||||
|
"node { name: 't3_b' op: 'TestRelu' input: 't2' }");
|
||||||
|
EXPECT_EQ("OK", Subgraph("", "W2,input:1,t1,t2", "t2",
|
||||||
|
true /* use_function_convention */));
|
||||||
|
ExpectNodes(
|
||||||
|
"W1,W2,input,t1,t2,_retval_W2_0_0,_retval_input_1_1,_retval_t1_0_2,_"
|
||||||
|
"retval_t2_0_3");
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(SubgraphTest, FetchOutputs2) {
|
TEST_F(SubgraphTest, FetchOutputs2) {
|
||||||
ExpectOK(
|
ExpectOK(
|
||||||
"node { name: 'W1' op: 'TestParams' }"
|
"node { name: 'W1' op: 'TestParams' }"
|
||||||
@ -231,6 +280,20 @@ TEST_F(SubgraphTest, FetchOutputs2) {
|
|||||||
ExpectNodes("W1,W2,input,t1,t2,t3_a,_send_t3_a_0");
|
ExpectNodes("W1,W2,input,t1,t2,t3_a,_send_t3_a_0");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(SubgraphTest, FetchOutputs2_FunctionConvention) {
|
||||||
|
ExpectOK(
|
||||||
|
"node { name: 'W1' op: 'TestParams' }"
|
||||||
|
"node { name: 'W2' op: 'TestParams' }"
|
||||||
|
"node { name: 'input' op: 'TestInput' }"
|
||||||
|
"node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }"
|
||||||
|
"node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }"
|
||||||
|
"node { name: 't3_a' op: 'TestRelu' input: 't2' }"
|
||||||
|
"node { name: 't3_b' op: 'TestRelu' input: 't2' }");
|
||||||
|
EXPECT_EQ("OK",
|
||||||
|
Subgraph("", "t3_a", "t2", true /* use_function_convention */));
|
||||||
|
ExpectNodes("W1,W2,input,t1,t2,t3_a,_retval_t3_a_0_0");
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(SubgraphTest, ChainOfFools) {
|
TEST_F(SubgraphTest, ChainOfFools) {
|
||||||
ExpectOK(
|
ExpectOK(
|
||||||
"node { name: 'a' op: 'TestParams' }"
|
"node { name: 'a' op: 'TestParams' }"
|
||||||
@ -315,7 +378,8 @@ TEST_F(SubgraphTest, FedOutputsPreservesOutputShapes) {
|
|||||||
REGISTER_OP("In").Output("o: float");
|
REGISTER_OP("In").Output("o: float");
|
||||||
REGISTER_OP("Op").Input("i: float").Output("o: float");
|
REGISTER_OP("Op").Input("i: float").Output("o: float");
|
||||||
|
|
||||||
static void BM_Subgraph(int iters, int num_nodes) {
|
static void BM_SubgraphHelper(int iters, int num_nodes,
|
||||||
|
bool use_function_convention) {
|
||||||
DeviceAttributes device_info;
|
DeviceAttributes device_info;
|
||||||
device_info.set_name("/job:a/replica:0/task:0/cpu:0");
|
device_info.set_name("/job:a/replica:0/task:0/cpu:0");
|
||||||
device_info.set_device_type(DeviceType(DEVICE_CPU).type());
|
device_info.set_device_type(DeviceType(DEVICE_CPU).type());
|
||||||
@ -347,12 +411,26 @@ static void BM_Subgraph(int iters, int num_nodes) {
|
|||||||
while (--iters > 0) {
|
while (--iters > 0) {
|
||||||
Graph* subgraph = new Graph(OpRegistry::Global());
|
Graph* subgraph = new Graph(OpRegistry::Global());
|
||||||
CopyGraph(g, subgraph);
|
CopyGraph(g, subgraph);
|
||||||
TF_CHECK_OK(subgraph::RewriteGraphForExecution(subgraph, fed, fetch,
|
subgraph::RewriteGraphMetadata metadata;
|
||||||
targets, device_info));
|
TF_CHECK_OK(subgraph::RewriteGraphForExecution(
|
||||||
|
subgraph, fed, fetch, targets, device_info, use_function_convention,
|
||||||
|
&metadata));
|
||||||
delete subgraph;
|
delete subgraph;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void BM_Subgraph(int iters, int num_nodes) {
|
||||||
|
BM_SubgraphHelper(iters, num_nodes, false /* use_function_convention */);
|
||||||
|
}
|
||||||
|
static void BM_SubgraphFunctionConvention(int iters, int num_nodes) {
|
||||||
|
BM_SubgraphHelper(iters, num_nodes, true /* use_function_convention */);
|
||||||
|
}
|
||||||
BENCHMARK(BM_Subgraph)->Arg(100)->Arg(1000)->Arg(10000)->Arg(100000);
|
BENCHMARK(BM_Subgraph)->Arg(100)->Arg(1000)->Arg(10000)->Arg(100000);
|
||||||
|
BENCHMARK(BM_SubgraphFunctionConvention)
|
||||||
|
->Arg(100)
|
||||||
|
->Arg(1000)
|
||||||
|
->Arg(10000)
|
||||||
|
->Arg(100000);
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -820,7 +820,7 @@ class DebugDumpDir(object):
|
|||||||
self._node_op_types[node.name] = node.op
|
self._node_op_types[node.name] = node.op
|
||||||
|
|
||||||
for inp in node.input:
|
for inp in node.input:
|
||||||
if is_copy_node(inp) and node.op == "_Send":
|
if is_copy_node(inp) and (node.op == "_Send" or node.op == "_Retval"):
|
||||||
self._copy_send_nodes.append(node.name)
|
self._copy_send_nodes.append(node.name)
|
||||||
|
|
||||||
if inp.startswith("^"):
|
if inp.startswith("^"):
|
||||||
|
@ -196,7 +196,7 @@ class ControlFlowTest(test.TestCase):
|
|||||||
|
|
||||||
with self.assertRaisesWithPredicateMatch(
|
with self.assertRaisesWithPredicateMatch(
|
||||||
errors_impl.InvalidArgumentError,
|
errors_impl.InvalidArgumentError,
|
||||||
lambda e: "The tensor returned for" in str(e)):
|
lambda e: "Retval[0] does not have value" in str(e)):
|
||||||
dead_branch.eval()
|
dead_branch.eval()
|
||||||
|
|
||||||
def testSwitchMergeLess(self):
|
def testSwitchMergeLess(self):
|
||||||
|
@ -147,9 +147,10 @@ Status FoldConstants(const GraphDef& input_graph_def,
|
|||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
ImportGraphDef(import_opts, cleaned_graph_def, &input_graph, nullptr));
|
ImportGraphDef(import_opts, cleaned_graph_def, &input_graph, nullptr));
|
||||||
DeviceAttributes device_attributes;
|
DeviceAttributes device_attributes;
|
||||||
|
subgraph::RewriteGraphMetadata metadata;
|
||||||
TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
|
TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
|
||||||
&input_graph, context.input_names, context.output_names, {},
|
&input_graph, context.input_names, context.output_names, {},
|
||||||
device_attributes));
|
device_attributes, false /* use_function_convention */, &metadata));
|
||||||
bool was_mutated;
|
bool was_mutated;
|
||||||
TF_RETURN_IF_ERROR(DoConstantFoldingWithStatus(
|
TF_RETURN_IF_ERROR(DoConstantFoldingWithStatus(
|
||||||
ConstantFoldingOptions(), nullptr, Env::Default(), nullptr, &input_graph,
|
ConstantFoldingOptions(), nullptr, Env::Default(), nullptr, &input_graph,
|
||||||
|
Loading…
Reference in New Issue
Block a user