From 858e0afcc45c39b6428bf82ab1444323e925cfd8 Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Thu, 20 Apr 2017 22:22:29 -0800 Subject: [PATCH] Switch DirectSession to use _Arg and _Retval ops for feeding and fetching. This change reduces the overhead imposed by string processing and rendezvous invocation in the DirectSession::Run() call by 1--2 microseconds per value fed or fetched. RELNOTES: Improved DirectSession::Run() overhead and error checking. Feeding a value of the wrong type will now synchronously raise an INVALID_ARGUMENT error instead of asynchronously raising an INTERNAL error. Code that depends on the (undefined) behavior when feeding a tensor of the wrong type may need to be updated. Change: 153797943 --- tensorflow/core/BUILD | 1 + .../core/common_runtime/build_graph_options.h | 5 + .../core/common_runtime/direct_session.cc | 144 ++++++++++++++---- .../core/common_runtime/direct_session.h | 22 ++- .../core/common_runtime/graph_runner.cc | 4 +- .../resource_variable_read_optimizer.cc | 9 +- .../simple_graph_execution_state.cc | 20 ++- .../simple_graph_execution_state.h | 20 ++- tensorflow/core/framework/function.cc | 15 +- tensorflow/core/framework/function.h | 1 + tensorflow/core/graph/subgraph.cc | 111 +++++++++----- tensorflow/core/graph/subgraph.h | 15 +- tensorflow/core/graph/subgraph_test.cc | 96 ++++++++++-- tensorflow/python/debug/lib/debug_data.py | 2 +- .../kernel_tests/control_flow_ops_py_test.py | 2 +- .../graph_transforms/fold_constants_lib.cc | 3 +- 16 files changed, 370 insertions(+), 100 deletions(-) diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 1b78b25ff51..d6143493877 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -1563,6 +1563,7 @@ tf_cuda_library( ":lib_internal", ":proto_text", ":protos_all_cc", + "//tensorflow/core/kernels:function_ops", ], alwayslink = 1, ) diff --git a/tensorflow/core/common_runtime/build_graph_options.h b/tensorflow/core/common_runtime/build_graph_options.h index c6d4bdad9c1..49566c8fa8f 100644 --- a/tensorflow/core/common_runtime/build_graph_options.h +++ b/tensorflow/core/common_runtime/build_graph_options.h @@ -30,6 +30,11 @@ struct BuildGraphOptions { // the former via "ref" fetch_endpoints. std::vector target_nodes; + // If `true`, uses Arg/Retval to implement feeds/fetches; otherwise + // uses Recv/Send to implement feeds/fetches. + // TODO(mrry): Remove this when the distributed runtime supports Arg/Retval. + bool use_function_convention = false; + string DebugString() const; }; diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index c05cceced11..002e246b80d 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -361,7 +361,6 @@ Status DirectSession::ExtendLocked(const GraphDef& graph) { return Status::OK(); } -// TODO(yuanbyu): Simplify by treating Run() as "PRunSetup(); PRun()". Status DirectSession::Run(const NamedTensorList& inputs, const std::vector& output_names, const std::vector& target_nodes, @@ -426,13 +425,34 @@ Status DirectSession::Run(const RunOptions& run_options, executor_step_count, input_tensor_names, output_names, target_nodes)); } + // Configure a call frame for the step, which we use to feed and + // fetch values to and from the executors. 
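The release note above describes the one user-visible behavioral change: because feeds are now packed into a call frame before the executors start, a wrongly typed feed is rejected up front, and the call frame's INTERNAL status is remapped to INVALID_ARGUMENT since a bad feed dtype is the caller's mistake. Below is a minimal self-contained sketch of that remapping pattern only, using a toy status type rather than TensorFlow's Status and errors:: helpers.

#include <iostream>
#include <string>

// Toy stand-ins for tensorflow::Status and the errors:: helpers; only the
// remapping pattern is the point here.
enum class Code { kOk, kInternal, kInvalidArgument };
struct ToyStatus {
  Code code = Code::kOk;
  std::string message;
  bool ok() const { return code == Code::kOk; }
};

// A dtype mismatch detected while packing feeds into the call frame is
// reported as an internal error, but it was caused by the caller's feed,
// so the session surfaces it as an invalid-argument error instead.
ToyStatus RemapFeedError(ToyStatus s) {
  if (s.code == Code::kInternal) {
    return {Code::kInvalidArgument, s.message};
  }
  return s;  // OK and other errors pass through unchanged.
}

int main() {
  ToyStatus from_frame{Code::kInternal, "expected float, got int32"};
  ToyStatus seen = RemapFeedError(from_frame);
  std::cout << (seen.code == Code::kInvalidArgument) << "\n";  // prints 1
}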
+ FunctionCallFrame call_frame(executors_and_keys->input_types, + executors_and_keys->output_types); + gtl::InlinedVector feed_args(inputs.size()); + for (const auto& it : inputs) { + if (it.second.dtype() == DT_RESOURCE) { + Tensor tensor_from_handle; + TF_RETURN_IF_ERROR( + ResourceHandleToInputTensor(it.second, &tensor_from_handle)); + feed_args[executors_and_keys->input_name_to_index[it.first]] = + tensor_from_handle; + } else { + feed_args[executors_and_keys->input_name_to_index[it.first]] = it.second; + } + } + Status s = call_frame.SetArgs(feed_args); + if (errors::IsInternal(s)) { + return errors::InvalidArgument(s.error_message()); + } else if (!s.ok()) { + return s; + } + // Create a run state and start execution. RunState run_state(args.step_id, &devices_); run_state.rendez = new IntraProcessRendezvous(device_mgr_.get()); CancellationManager step_cancellation_manager; - - // Send inputs. - TF_RETURN_IF_ERROR(SendInputs(inputs, executors_and_keys, run_state.rendez)); + args.call_frame = &call_frame; // Start parallel Executors. const size_t num_executors = executors_and_keys->items.size(); @@ -535,8 +555,22 @@ Status DirectSession::Run(const RunOptions& run_options, } // Receive outputs. - TF_RETURN_IF_ERROR( - RecvOutputs(output_names, executors_and_keys, &run_state, outputs)); + if (outputs) { + std::vector sorted_outputs; + Status s = call_frame.ConsumeRetvals(&sorted_outputs); + if (errors::IsInternal(s)) { + return errors::InvalidArgument(s.error_message()); + } else if (!s.ok()) { + return s; + } + outputs->clear(); + outputs->reserve(sorted_outputs.size()); + for (const string& output_name : output_names) { + outputs->emplace_back( + std::move(sorted_outputs[executors_and_keys + ->output_name_to_index[output_name]])); + } + } // Save the output tensors of this run we choose to keep. TF_RETURN_IF_ERROR( @@ -706,11 +740,11 @@ Status DirectSession::PRun(const string& handle, const NamedTensorList& inputs, CheckFetch(inputs, output_names, executors_and_keys, run_state)); // Send inputs. - Status s = SendInputs(inputs, executors_and_keys, run_state->rendez); + Status s = SendPRunInputs(inputs, executors_and_keys, run_state->rendez); // Receive outputs. if (s.ok()) { - s = RecvOutputs(output_names, executors_and_keys, run_state, outputs); + s = RecvPRunOutputs(output_names, executors_and_keys, run_state, outputs); } // Save the output tensors of this run we choose to keep. @@ -770,16 +804,17 @@ Status DirectSession::ResourceHandleToInputTensor(const Tensor& resource_tensor, } } -Status DirectSession::SendInputs(const NamedTensorList& inputs, - const ExecutorsAndKeys* executors_and_keys, - IntraProcessRendezvous* rendez) { +Status DirectSession::SendPRunInputs(const NamedTensorList& inputs, + const ExecutorsAndKeys* executors_and_keys, + IntraProcessRendezvous* rendez) { Status s; Rendezvous::ParsedKey parsed; // Insert the input tensors into the local rendezvous by their // rendezvous key. 
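For reference, here is a compact stand-in for the call-frame mechanics that Run() relies on above. It is a sketch only (std::string in place of Tensor, no dtype checking, hypothetical class name), not the real FunctionCallFrame API: the caller sets every argument positionally in one call and later moves the return values out, which is what lets a full Run() avoid per-tensor rendezvous traffic. The rendezvous-keyed path that follows is retained only for partial runs.

#include <cstddef>
#include <optional>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

class MiniCallFrame {
 public:
  MiniCallFrame(std::size_t num_args, std::size_t num_rets)
      : args_(num_args), rets_(num_rets) {}

  // Caller side: feed every argument at once, by position.
  void SetArgs(std::vector<std::string> args) {
    if (args.size() != args_.size()) throw std::invalid_argument("arity");
    args_ = std::move(args);
  }

  // Callee side: an _Arg kernel would read its slot, a _Retval kernel
  // would fill one.
  const std::string& arg(std::size_t i) const { return args_.at(i); }
  void set_retval(std::size_t i, std::string v) { rets_.at(i) = std::move(v); }

  // Caller side: move the results out after the step finishes; an unset
  // slot means the corresponding fetch was never produced.
  std::vector<std::string> ConsumeRetvals() {
    std::vector<std::string> out;
    out.reserve(rets_.size());
    for (auto& r : rets_) {
      if (!r.has_value()) throw std::runtime_error("retval not set");
      out.push_back(std::move(*r));
      r.reset();
    }
    return out;
  }

 private:
  std::vector<std::string> args_;
  std::vector<std::optional<std::string>> rets_;
};

After ConsumeRetvals() returns the results in positional order, the session reorders them back into the order of the user's fetch names via output_name_to_index, as the diff above shows.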
for (const auto& input : inputs) { - auto it = executors_and_keys->input_keys.find(input.first); - if (it == executors_and_keys->input_keys.end()) { + auto it = + executors_and_keys->input_name_to_rendezvous_key.find(input.first); + if (it == executors_and_keys->input_name_to_rendezvous_key.end()) { return errors::Internal("'", input.first, "' is not a pre-defined feed."); } const string& input_key = it->second; @@ -808,10 +843,10 @@ Status DirectSession::SendInputs(const NamedTensorList& inputs, return Status::OK(); } -Status DirectSession::RecvOutputs(const std::vector& output_names, - const ExecutorsAndKeys* executors_and_keys, - RunState* run_state, - std::vector* outputs) { +Status DirectSession::RecvPRunOutputs( + const std::vector& output_names, + const ExecutorsAndKeys* executors_and_keys, RunState* run_state, + std::vector* outputs) { Status s; if (!output_names.empty()) { outputs->resize(output_names.size()); @@ -822,8 +857,9 @@ Status DirectSession::RecvOutputs(const std::vector& output_names, for (size_t output_offset = 0; output_offset < output_names.size(); ++output_offset) { const string& output_name = output_names[output_offset]; - auto it = executors_and_keys->output_keys.find(output_name); - if (it == executors_and_keys->output_keys.end()) { + auto it = + executors_and_keys->output_name_to_rendezvous_key.find(output_name); + if (it == executors_and_keys->output_name_to_rendezvous_key.end()) { return errors::Internal("'", output_name, "' is not a pre-defined fetch."); } @@ -987,14 +1023,16 @@ Status DirectSession::GetOrCreateExecutors( options.feed_endpoints = inputs_sorted; options.fetch_endpoints = outputs_sorted; options.target_nodes = tn_sorted; + options.use_function_convention = !run_state_args->is_partial_run; std::shared_ptr ek(new ExecutorsAndKeys); // The executor_lock_ is intentionally released while executor is // being created. std::unordered_map> graphs; - TF_RETURN_IF_ERROR( - CreateGraphs(options, &graphs, &ek->flib_def, run_state_args)); + TF_RETURN_IF_ERROR(CreateGraphs(options, &graphs, &ek->flib_def, + run_state_args, &ek->input_types, + &ek->output_types)); if (run_state_args->is_partial_run) { ek->graph = std::move(run_state_args->graph); @@ -1079,17 +1117,37 @@ Status DirectSession::GetOrCreateExecutors( item->executor.reset(executor); } - // Compute the rendezvous keys to avoid recomputing them every time. - // - // We always use the first device as the device name portion of the - // key, even if we're feeding another graph. - for (const string& input : inputs) { - ek->input_keys[input] = GetRendezvousKey( - input, device_set_.client_device()->attributes(), FrameAndIter(0, 0)); - } - for (const string& output : outputs) { - ek->output_keys[output] = GetRendezvousKey( - output, device_set_.client_device()->attributes(), FrameAndIter(0, 0)); + // Cache the mapping from input/output names to graph elements to + // avoid recomputing it every time. + if (!run_state_args->is_partial_run) { + // For regular `Run()`, we use the function calling convention, and so + // maintain a mapping from input/output names to + // argument/return-value ordinal index. 
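The mapping described in the comment above can be pictured with a small sketch: feed names are sorted once when the executors are created, each name gets its ordinal, and at Run() time the named feeds are scattered into the positional argument slots the _Arg kernels will read. The helper below is hypothetical (strings instead of Tensors); only the indexing scheme mirrors the code, and the loops that follow build exactly this map for inputs_sorted and outputs_sorted.

#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Build the cached name -> ordinal map once (per executor set), then use it
// to place feeds, supplied by name, into positional argument slots.
std::vector<std::string> OrderFeeds(
    const std::vector<std::string>& sorted_feed_names,
    const std::map<std::string, std::string>& feeds_by_name) {
  std::map<std::string, std::size_t> index;
  for (std::size_t i = 0; i < sorted_feed_names.size(); ++i) {
    index[sorted_feed_names[i]] = i;
  }
  std::vector<std::string> args(sorted_feed_names.size());
  for (const auto& name_and_value : feeds_by_name) {
    args[index.at(name_and_value.first)] = name_and_value.second;
  }
  return args;
}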
+ for (size_t i = 0; i < inputs_sorted.size(); ++i) { + const string& input = inputs_sorted[i]; + ek->input_name_to_index[input] = i; + } + for (size_t i = 0; i < outputs_sorted.size(); ++i) { + const string& output = outputs_sorted[i]; + ek->output_name_to_index[output] = i; + } + } else { + // For `PRun()`, we use the rendezvous calling convention, and so + // maintain a mapping from input/output names to rendezvous keys. + // + // We always use the first device as the device name portion of the + // key, even if we're feeding another graph. + for (size_t i = 0; i < inputs_sorted.size(); ++i) { + const string& input = inputs_sorted[i]; + ek->input_name_to_rendezvous_key[input] = GetRendezvousKey( + input, device_set_.client_device()->attributes(), FrameAndIter(0, 0)); + } + for (size_t i = 0; i < outputs_sorted.size(); ++i) { + const string& output = outputs_sorted[i]; + ek->output_name_to_rendezvous_key[output] = + GetRendezvousKey(output, device_set_.client_device()->attributes(), + FrameAndIter(0, 0)); + } } // Reacquire the lock, try to insert into the map. @@ -1110,7 +1168,8 @@ Status DirectSession::CreateGraphs( const BuildGraphOptions& subgraph_options, std::unordered_map>* outputs, std::unique_ptr* flib_def, - RunStateArgs* run_state_args) { + RunStateArgs* run_state_args, DataTypeVector* input_types, + DataTypeVector* output_types) { mutex_lock l(graph_def_lock_); std::unique_ptr client_graph; @@ -1135,6 +1194,23 @@ Status DirectSession::CreateGraphs( execution_state->BuildGraph(subgraph_options, &client_graph)); } + if (subgraph_options.feed_endpoints.size() != + client_graph->feed_types.size()) { + return errors::Internal( + "Graph pruning failed: requested number of feed endpoints = ", + subgraph_options.feed_endpoints.size(), + " versus number of pruned feed endpoints = ", + client_graph->feed_types.size()); + } + if (subgraph_options.fetch_endpoints.size() != + client_graph->fetch_types.size()) { + return errors::Internal( + "Graph pruning failed: requested number of fetch endpoints = ", + subgraph_options.fetch_endpoints.size(), + " versus number of pruned fetch endpoints = ", + client_graph->fetch_types.size()); + } + auto current_stateful_placements = execution_state->GetStatefulPlacements(); // Update our current state based on the execution_state's // placements. If there are any mismatches for a node, @@ -1240,6 +1316,8 @@ Status DirectSession::CreateGraphs( } } *flib_def = std::move(client_graph->flib_def); + std::swap(*input_types, client_graph->feed_types); + std::swap(*output_types, client_graph->fetch_types); return s; } diff --git a/tensorflow/core/common_runtime/direct_session.h b/tensorflow/core/common_runtime/direct_session.h index b9d22ac522c..848ef3bc62d 100644 --- a/tensorflow/core/common_runtime/direct_session.h +++ b/tensorflow/core/common_runtime/direct_session.h @@ -132,8 +132,13 @@ class DirectSession : public Session { NameNodeMap name_to_node; std::unique_ptr flib_def; std::vector items; - std::unordered_map input_keys; - std::unordered_map output_keys; + std::unordered_map input_name_to_index; + std::unordered_map input_name_to_rendezvous_key; + std::unordered_map output_name_to_index; + std::unordered_map output_name_to_rendezvous_key; + + DataTypeVector input_types; + DataTypeVector output_types; }; // For each live partial execution, the session maintains a RunState. 
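Partial runs keep the rendezvous convention, so the maps above also cache one precomputed key string per feed/fetch name. The sketch below only illustrates why those keys are cached: each one is a composed string (the real TensorFlow key additionally encodes the device incarnation and frame/iteration and uses a different layout), and composing or parsing such a key per value is exactly the overhead that the function convention's integer index avoids.

#include <string>
#include <unordered_map>
#include <vector>

// Toy key format, for illustration only; not TensorFlow's actual
// rendezvous key layout.
std::string MakeToyKey(const std::string& device, const std::string& tensor) {
  return device + ";" + device + ";" + tensor + ";0:0";
}

// Cache one key per feed/fetch name so PRun() never has to rebuild them.
std::unordered_map<std::string, std::string> BuildKeyCache(
    const std::vector<std::string>& names, const std::string& device) {
  std::unordered_map<std::string, std::string> keys;
  for (const std::string& n : names) keys[n] = MakeToyKey(device, n);
  return keys;
}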
@@ -187,7 +192,8 @@ class DirectSession : public Session { const BuildGraphOptions& options, std::unordered_map>* outputs, std::unique_ptr* flib_def, - RunStateArgs* run_state_args); + RunStateArgs* run_state_args, DataTypeVector* input_types, + DataTypeVector* output_types); ::tensorflow::Status ExtendLocked(const GraphDef& graph) EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_); @@ -196,17 +202,17 @@ class DirectSession : public Session { const Tensor& resource_tensor, Tensor* retrieved_tensor); // Feeds more inputs to the executors, triggering further execution. - ::tensorflow::Status SendInputs( + ::tensorflow::Status SendPRunInputs( const std::vector>& inputs, const ExecutorsAndKeys* executors_and_keys, IntraProcessRendezvous* rendez); // Fetches more outputs from the executors. It waits until the output // tensors are computed. - ::tensorflow::Status RecvOutputs(const std::vector& output_names, - const ExecutorsAndKeys* executors_and_keys, - RunState* run_state, - std::vector* outputs); + ::tensorflow::Status RecvPRunOutputs( + const std::vector& output_names, + const ExecutorsAndKeys* executors_and_keys, RunState* run_state, + std::vector* outputs); // Check if the specified fetches can be computed from the feeds // that we have already provided. diff --git a/tensorflow/core/common_runtime/graph_runner.cc b/tensorflow/core/common_runtime/graph_runner.cc index 514a63590b1..a85fbbf88ff 100644 --- a/tensorflow/core/common_runtime/graph_runner.cc +++ b/tensorflow/core/common_runtime/graph_runner.cc @@ -130,9 +130,11 @@ Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library, } // Call RewriteGraphForExecution + subgraph::RewriteGraphMetadata metadata; TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution( graph_to_run.get(), input_names, output_names, {} /* target nodes */, - cpu_device_->attributes())); + cpu_device_->attributes(), false /* use_function_convention */, + &metadata)); // Create the local executor and the Rendezvous for fetching back the // constants. diff --git a/tensorflow/core/common_runtime/resource_variable_read_optimizer.cc b/tensorflow/core/common_runtime/resource_variable_read_optimizer.cc index 85a29e11e23..c179e94c36b 100644 --- a/tensorflow/core/common_runtime/resource_variable_read_optimizer.cc +++ b/tensorflow/core/common_runtime/resource_variable_read_optimizer.cc @@ -21,9 +21,9 @@ limitations under the License. namespace tensorflow { namespace { -// Replaces ReadVariableOp nodes which are only used by Sends and sinks with -// _UnsafeReadVariable nodes, as this transforamtion is safe and will improve -// performance. +// Replaces ReadVariableOp nodes which are only used by Sends, sinks, +// and function Retvals with _UnsafeReadVariable nodes, as this +// transformation is safe and will improve performance. 
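The comment above is the whole eligibility rule of this pass: a ReadVariableOp may be replaced only if every consumer is a fetch-like node, i.e. a Send, a function _Retval (new in this change), or the graph sink. Here is a toy version of that predicate, with a made-up node struct standing in for the real Graph types; the pass below walks the out-edges of each ReadVariableOp and applies the same test.

#include <string>
#include <vector>

struct ToyNode {
  std::string type_string;  // e.g. "_Send", "_Retval", "Mul"
  std::string name;         // e.g. "_SINK"
};

// True iff every consumer is a Send, a _Retval, or the sink, mirroring the
// "skip" check in the pass below.
bool ReadOnlyFeedsFetches(const std::vector<ToyNode>& consumers) {
  for (const ToyNode& n : consumers) {
    const bool fetch_like = n.type_string == "_Send" ||
                            n.type_string == "_Retval" || n.name == "_SINK";
    if (!fetch_like) return false;
  }
  return true;
}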
class ResourceVariableReadPass : public GraphOptimizationPass { public: Status Run(const GraphOptimizationPassOptions& options) override { @@ -43,7 +43,8 @@ class ResourceVariableReadPass : public GraphOptimizationPass { if (n->type_string() == "ReadVariableOp") { bool skip = false; for (const Edge* e : n->out_edges()) { - if (!e->dst()->IsSend() && e->dst()->name() != "_SINK") { + if (!e->dst()->IsSend() && e->dst()->type_string() != "_Retval" && + e->dst()->name() != "_SINK") { skip = true; } } diff --git a/tensorflow/core/common_runtime/simple_graph_execution_state.cc b/tensorflow/core/common_runtime/simple_graph_execution_state.cc index c2ac15b345d..31e63a9ef75 100644 --- a/tensorflow/core/common_runtime/simple_graph_execution_state.cc +++ b/tensorflow/core/common_runtime/simple_graph_execution_state.cc @@ -284,9 +284,11 @@ Status SimpleGraphExecutionState::InitBaseGraph( if (session_options_ && session_options_->config.graph_options().place_pruned_graph()) { // Rewrite the graph before placement. + rewrite_metadata_.reset(new subgraph::RewriteGraphMetadata); TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution( new_graph.get(), options.feed_endpoints, options.fetch_endpoints, - options.target_nodes, device_set_->client_device()->attributes())); + options.target_nodes, device_set_->client_device()->attributes(), + options.use_function_convention, rewrite_metadata_.get())); } // Save stateful placements before placing. @@ -333,15 +335,26 @@ Status SimpleGraphExecutionState::BuildGraph( std::unique_ptr ng(new Graph(flib_def_.get())); CopyGraph(*graph_, ng.get()); + subgraph::RewriteGraphMetadata rewrite_metadata; if (session_options_ == nullptr || !session_options_->config.graph_options().place_pruned_graph()) { // Extract the subset of the graph that needs to be run, adding feed/fetch // ops as needed. TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution( ng.get(), options.feed_endpoints, options.fetch_endpoints, - options.target_nodes, device_set_->client_device()->attributes())); + options.target_nodes, device_set_->client_device()->attributes(), + options.use_function_convention, &rewrite_metadata)); + } else { + // This SimpleGraphExecutionState represents a graph that was + // pruned when this was constructed, so we copy the metadata from + // a member variable. + CHECK(rewrite_metadata_); + rewrite_metadata = *rewrite_metadata_; } + CHECK_EQ(options.feed_endpoints.size(), rewrite_metadata.feed_types.size()); + CHECK_EQ(options.fetch_endpoints.size(), rewrite_metadata.fetch_types.size()); + // Make a fresh copy of the function library for the client graph. std::unique_ptr flib( new FunctionLibraryDefinition(*flib_def_)); @@ -363,7 +376,8 @@ Status SimpleGraphExecutionState::BuildGraph( // since the local CostModel used to record its stats is sized by // the largest node id. std::unique_ptr dense_copy( - new SimpleClientGraph(std::move(flib))); + new SimpleClientGraph(std::move(flib), rewrite_metadata.feed_types, + rewrite_metadata.fetch_types)); CopyGraph(*ng, &dense_copy->graph); // TODO(vrv): We should check invariants of the graph here. 
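One invariant worth calling out: the rewrite records exactly one dtype per feed and per fetch, so the recorded counts must always match the requested endpoints. That is what the new Internal-error checks in CreateGraphs and the CHECK_EQs above enforce. A toy illustration of the same check follows; TensorFlow returns a Status or CHECK-fails rather than throwing, and the struct name here is hypothetical.

#include <cstddef>
#include <stdexcept>
#include <vector>

// One entry per fed/fetched tensor, in the same order as the endpoint
// names passed to the rewrite.
struct ToyRewriteMetadata {
  std::vector<int> feed_types;
  std::vector<int> fetch_types;
};

void ValidateMetadata(const ToyRewriteMetadata& m, std::size_t num_feeds,
                      std::size_t num_fetches) {
  if (m.feed_types.size() != num_feeds ||
      m.fetch_types.size() != num_fetches) {
    throw std::logic_error("graph pruning lost a feed or fetch endpoint");
  }
}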
diff --git a/tensorflow/core/common_runtime/simple_graph_execution_state.h b/tensorflow/core/common_runtime/simple_graph_execution_state.h index 3b6ce23c754..00b5509fd78 100644 --- a/tensorflow/core/common_runtime/simple_graph_execution_state.h +++ b/tensorflow/core/common_runtime/simple_graph_execution_state.h @@ -39,6 +39,10 @@ struct SessionOptions; class StepStats; class Timeline; +namespace subgraph { +struct RewriteGraphMetadata; +} + struct SimpleGraphExecutionStateOptions { const DeviceSet* device_set = nullptr; const SessionOptions* session_options = nullptr; @@ -50,13 +54,19 @@ struct SimpleGraphExecutionStateOptions { // A SimpleClientGraph is simply a sub-graph of the full graph as induced by // BuildGraphOptions. struct SimpleClientGraph { - explicit SimpleClientGraph(std::unique_ptr flib) - : flib_def(std::move(flib)), graph(flib_def.get()) {} + explicit SimpleClientGraph(std::unique_ptr flib, + DataTypeVector feed_types, + DataTypeVector fetch_types) + : flib_def(std::move(flib)), + graph(flib_def.get()), + feed_types(std::move(feed_types)), + fetch_types(std::move(fetch_types)) {} // Each client-graph gets its own function library since optimization passes // post rewrite for execution might want to introduce new functions. std::unique_ptr flib_def; Graph graph; - int32 placement_version; + DataTypeVector feed_types; + DataTypeVector fetch_types; }; // SimpleGraphExecutionState is responsible for generating an @@ -190,6 +200,10 @@ class SimpleGraphExecutionState { // and may be updated by a graph optimization pass. std::unique_ptr flib_def_; + // `rewrite_metadata_` is only set for SimpleGraphExecutionState + // objects created by `MakeForPrunedGraph()`. + std::unique_ptr rewrite_metadata_; + // The dataflow graph owned by this object. Graph* graph_; diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc index edb52737d94..8a7d96c38a9 100644 --- a/tensorflow/core/framework/function.cc +++ b/tensorflow/core/framework/function.cc @@ -789,7 +789,7 @@ Status FunctionCallFrame::GetRetvals(std::vector* rets) const { rets->clear(); rets->reserve(rets_.size()); for (size_t i = 0; i < rets_.size(); ++i) { - auto item = rets_[i]; + const auto& item = rets_[i]; if (item.has_val) { rets->push_back(item.val); } else { @@ -799,6 +799,19 @@ Status FunctionCallFrame::GetRetvals(std::vector* rets) const { return Status::OK(); } +Status FunctionCallFrame::ConsumeRetvals(std::vector* rets) { + rets->clear(); + rets->reserve(rets_.size()); + for (size_t i = 0; i < rets_.size(); ++i) { + if (rets_[i].has_val) { + rets->emplace_back(std::move(rets_[i].val)); + } else { + return errors::Internal("Retval[", i, "] does not have value"); + } + } + return Status::OK(); +} + Status FunctionCallFrame::GetArg(int index, Tensor* val) const { if (index < 0 || static_cast(index) >= args_.size()) { return errors::InvalidArgument("GetArg ", index, " is not within [0, ", diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index 63c868ac9b8..210e5b949a5 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -259,6 +259,7 @@ class FunctionCallFrame { // Caller methods. Status SetArgs(gtl::ArraySlice args); Status GetRetvals(std::vector* rets) const; + Status ConsumeRetvals(std::vector* rets); // Callee methods. 
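The "Callee methods" this comment introduces are what the _Arg and _Retval kernels use while the step runs: an _Arg kernel reads its argument slot through its `index` attr, and a _Retval kernel writes its result slot the same way. A hypothetical sketch of that interaction, with std::string standing in for Tensor and free functions standing in for the kernels' Compute methods:

#include <cstddef>
#include <string>
#include <utility>
#include <vector>

struct ToyFrame {
  std::vector<std::string> args;  // filled by the caller before the step
  std::vector<std::string> rets;  // pre-sized to the fetch count; read by
                                  // the caller after the step
};

// What an _Arg kernel with attr index=i conceptually does when it runs.
std::string ArgKernelCompute(const ToyFrame& frame, std::size_t i) {
  return frame.args.at(i);
}

// What a _Retval kernel with attr index=i conceptually does when it runs.
void RetvalKernelCompute(ToyFrame& frame, std::size_t i, std::string value) {
  frame.rets.at(i) = std::move(value);
}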
Status GetArg(int index, Tensor* val) const; diff --git a/tensorflow/core/graph/subgraph.cc b/tensorflow/core/graph/subgraph.cc index 91292500e1e..9849d9a1596 100644 --- a/tensorflow/core/graph/subgraph.cc +++ b/tensorflow/core/graph/subgraph.cc @@ -55,8 +55,13 @@ namespace { // state). static Status FeedInputs(Graph* g, const DeviceAttributes& device_info, const gtl::ArraySlice& fed_outputs, - subgraph::NameIndex* name_index) { - for (const string& t : fed_outputs) { + bool use_function_convention, + subgraph::NameIndex* name_index, + DataTypeVector* out_feed_types) { + out_feed_types->clear(); + out_feed_types->reserve(fed_outputs.size()); + for (size_t i = 0; i < fed_outputs.size(); ++i) { + const string& t = fed_outputs[i]; TensorId id(ParseTensorName(t)); auto iter = name_index->find(id.first); @@ -71,17 +76,31 @@ static Status FeedInputs(Graph* g, const DeviceAttributes& device_info, } Node* recv_node; - TF_RETURN_IF_ERROR( - NodeBuilder(strings::StrCat("_recv_", id.first, "_", id.second), - "_Recv") - .Attr("tensor_type", BaseType(n->output_type(id.second))) - .Attr("tensor_name", t) - .Attr("send_device", device_info.name()) - .Attr("recv_device", device_info.name()) - .Attr("send_device_incarnation", - static_cast(device_info.incarnation())) - .Attr("client_terminated", true) - .Finalize(g, &recv_node)); + + if (!use_function_convention) { + TF_RETURN_IF_ERROR( + NodeBuilder(strings::StrCat("_recv_", id.first, "_", id.second), + "_Recv") + .Attr("tensor_type", BaseType(n->output_type(id.second))) + .Attr("tensor_name", t) + .Attr("send_device", device_info.name()) + .Attr("recv_device", device_info.name()) + .Attr("send_device_incarnation", + static_cast(device_info.incarnation())) + .Attr("client_terminated", true) + .Finalize(g, &recv_node)); + } else { + // NOTE(mrry): We must include the index as part of the node + // name, because _Arg is a "stateful" kernel and therefore + // its name must uniquely identify a kernel instance across all + // graphs in the same session. + TF_RETURN_IF_ERROR(NodeBuilder(strings::StrCat("_arg_", id.first, "_", + id.second, "_", i), + "_Arg") + .Attr("T", BaseType(n->output_type(id.second))) + .Attr("index", static_cast(i)) + .Finalize(g, &recv_node)); + } recv_node->set_assigned_device_name(device_info.name()); // Copy the _output_shapes from the original node to the feed node, @@ -130,6 +149,7 @@ static Status FeedInputs(Graph* g, const DeviceAttributes& device_info, } g->RemoveEdge(e); } + out_feed_types->push_back(BaseType(n->output_type(id.second))); } return Status::OK(); } @@ -181,9 +201,14 @@ namespace subgraph { Status FetchOutputs(Graph* g, const DeviceAttributes& device_info, const gtl::ArraySlice& fetch_outputs, - NameIndex* name_index, std::vector* fetch_nodes) { - fetch_nodes->clear(); - for (const string& t : fetch_outputs) { + bool use_function_convention, NameIndex* name_index, + std::vector* out_fetch_nodes, + DataTypeVector* out_fetch_types) { + out_fetch_nodes->clear(); + out_fetch_nodes->reserve(fetch_outputs.size()); + for (size_t i = 0; i < fetch_outputs.size(); ++i) { + const string& t = fetch_outputs[i]; + // Parse t into node_name and output_index. 
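The tensor names being parsed here follow the usual "node:output_index" convention, with a bare node name meaning output 0. A toy parser for illustration; TensorFlow's ParseTensorName also understands control-edge names such as "^node", which this hypothetical version ignores.

#include <cstddef>
#include <cstdlib>
#include <string>
#include <utility>

std::pair<std::string, int> ParseToyTensorName(const std::string& t) {
  const std::size_t colon = t.rfind(':');
  if (colon == std::string::npos) {
    return {t, 0};  // "W1" means output 0 of node "W1".
  }
  // "input:1" means output 1 of node "input".
  return {t.substr(0, colon), std::atoi(t.c_str() + colon + 1)};
}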
TensorId id(ParseTensorName(t)); @@ -213,25 +238,39 @@ Status FetchOutputs(Graph* g, const DeviceAttributes& device_info, // Create the fetch Node and connect it up Node* send_node; - TF_RETURN_IF_ERROR( - NodeBuilder(strings::StrCat("_send_", id.first, "_", id.second), - "_Send") - .Input(n, id.second) - .Attr("tensor_name", t) - .Attr("send_device", device_info.name()) - .Attr("recv_device", device_info.name()) - .Attr("send_device_incarnation", - static_cast(device_info.incarnation())) - .Attr("client_terminated", true) - .Finalize(g, &send_node)); + if (!use_function_convention) { + TF_RETURN_IF_ERROR( + NodeBuilder(strings::StrCat("_send_", id.first, "_", id.second), + "_Send") + .Input(n, id.second) + .Attr("tensor_name", t) + .Attr("send_device", device_info.name()) + .Attr("recv_device", device_info.name()) + .Attr("send_device_incarnation", + static_cast(device_info.incarnation())) + .Attr("client_terminated", true) + .Finalize(g, &send_node)); + } else { + // NOTE(mrry): We must include the index as part of the node + // name, because _Retval is a "stateful" kernel and therefore + // its name must uniquely identify a kernel instance across all + // graphs in the same session. + TF_RETURN_IF_ERROR(NodeBuilder(strings::StrCat("_retval_", id.first, "_", + id.second, "_", i), + "_Retval") + .Input(n, id.second) + .Attr("T", BaseType(n->output_type(id.second))) + .Attr("index", static_cast(i)) + .Finalize(g, &send_node)); + } send_node->set_assigned_device_name(device_info.name()); - VLOG(1) << "Created fetch node: " << SummarizeNodeDef(send_node->def()); // Update the index. (*name_index)[send_node->name()] = send_node; g->AddControlEdge(send_node, g->sink_node()); - fetch_nodes->push_back(send_node); + out_fetch_nodes->push_back(send_node); + out_fetch_types->push_back(BaseType(n->output_type(id.second))); } return Status::OK(); @@ -241,7 +280,8 @@ Status RewriteGraphForExecution( Graph* g, const gtl::ArraySlice& fed_outputs, const gtl::ArraySlice& fetch_outputs, const gtl::ArraySlice& target_node_names, - const DeviceAttributes& device_info) { + const DeviceAttributes& device_info, bool use_function_convention, + RewriteGraphMetadata* out_metadata) { if (fetch_outputs.empty() && target_node_names.empty()) { return errors::InvalidArgument( "Must specify at least one target to fetch or execute."); @@ -274,18 +314,21 @@ Status RewriteGraphForExecution( // currently listed in "fetch_nodes". We pass "name_index" so the index is // kept up to date. if (!fed_outputs.empty()) { - TF_RETURN_IF_ERROR(FeedInputs(g, device_info, fed_outputs, &name_index)); + TF_RETURN_IF_ERROR(FeedInputs(g, device_info, fed_outputs, + use_function_convention, &name_index, + &out_metadata->feed_types)); } // Add the fetch nodes, also updating "name_index". std::vector fetch_nodes; if (!fetch_outputs.empty()) { - TF_RETURN_IF_ERROR( - FetchOutputs(g, device_info, fetch_outputs, &name_index, &fetch_nodes)); + TF_RETURN_IF_ERROR(FetchOutputs(g, device_info, fetch_outputs, + use_function_convention, &name_index, + &fetch_nodes, &out_metadata->fetch_types)); } // Prune the graph to only compute what is needed for the fetch nodes and the - // targets nodes. + // target nodes. 
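By this point both rewrite paths have inserted their synthetic nodes, and their names show up in the tests further below: the rendezvous path creates "_recv_<node>_<output>" and "_send_<node>_<output>" nodes, while the function path creates "_arg_<node>_<output>_<ordinal>" and "_retval_<node>_<output>_<ordinal>" nodes, with the feed/fetch ordinal appended because _Arg and _Retval are stateful kernels whose names must be unique across every graph in a session. A small hypothetical helper reproducing the feed-node naming:

#include <string>

std::string FeedNodeName(const std::string& node, int output, int ordinal,
                         bool use_function_convention) {
  if (use_function_convention) {
    return "_arg_" + node + "_" + std::to_string(output) + "_" +
           std::to_string(ordinal);
  }
  return "_recv_" + node + "_" + std::to_string(output);
}

// Example: FeedNodeName("input", 1, 0, true) == "_arg_input_1_0", the node
// name checked by SubgraphTest.FedOutputs1_FunctionConvention below.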
if (!fetch_nodes.empty() || !target_node_names.empty()) { TF_RETURN_IF_ERROR( PruneForTargets(g, name_index, fetch_nodes, target_node_names)); diff --git a/tensorflow/core/graph/subgraph.h b/tensorflow/core/graph/subgraph.h index d94d983d000..8ccc27914bc 100644 --- a/tensorflow/core/graph/subgraph.h +++ b/tensorflow/core/graph/subgraph.h @@ -26,6 +26,18 @@ limitations under the License. namespace tensorflow { namespace subgraph { +// Information about a graph rewritten by `RewriteGraphForExecution()`. +struct RewriteGraphMetadata { + // The element type of each tensor fed to this subgraph. The order + // of types corresponds to the order of tensor names in + // `fed_outputs` when calling `RewriteGraphForExecution()`. + DataTypeVector feed_types; + // The element type of each tensor fetched from this subgraph. The + // order of types corresponds to the order of tensor names in + // `fetch_outputs` when calling `RewriteGraphForExecution()`. + DataTypeVector fetch_types; +}; + // Rewrite the graph structure of "*g" to deal with feeding node // outputs, fetching node outputs, and only running a subset of the // graph. "fed_outputs" and "fetch_outputs" are both lists of @@ -56,7 +68,8 @@ Status RewriteGraphForExecution( Graph* g, const gtl::ArraySlice& fed_outputs, const gtl::ArraySlice& fetch_outputs, const gtl::ArraySlice& target_node_names, - const DeviceAttributes& device_info); + const DeviceAttributes& device_info, bool use_function_convention, + RewriteGraphMetadata* out_metadata); typedef std::unordered_map NameIndex; diff --git a/tensorflow/core/graph/subgraph_test.cc b/tensorflow/core/graph/subgraph_test.cc index ee4960121f5..3dc11b7a166 100644 --- a/tensorflow/core/graph/subgraph_test.cc +++ b/tensorflow/core/graph/subgraph_test.cc @@ -104,7 +104,8 @@ class SubgraphTest : public ::testing::Test { } string Subgraph(const string& fed_str, const string& fetch_str, - const string& targets_str) { + const string& targets_str, + bool use_function_convention = false) { Graph* subgraph = new Graph(OpRegistry::Global()); CopyGraph(*g_, subgraph); std::vector fed = @@ -114,13 +115,18 @@ class SubgraphTest : public ::testing::Test { std::vector targets = str_util::Split(targets_str, ',', str_util::SkipEmpty()); - Status s = subgraph::RewriteGraphForExecution(subgraph, fed, fetch, targets, - device_info_); + subgraph::RewriteGraphMetadata metadata; + Status s = subgraph::RewriteGraphForExecution( + subgraph, fed, fetch, targets, device_info_, use_function_convention, + &metadata); if (!s.ok()) { delete subgraph; return s.ToString(); } + EXPECT_EQ(fed.size(), metadata.feed_types.size()); + EXPECT_EQ(fetch.size(), metadata.fetch_types.size()); + // Replace the graph with the subgraph for the rest of the display program g_.reset(subgraph); return "OK"; @@ -178,6 +184,20 @@ TEST_F(SubgraphTest, FedOutputs1) { ExpectNodes("W1,W2,_recv_input_1,t1,t2"); } +TEST_F(SubgraphTest, FedOutputs1_FunctionConvention) { + ExpectOK( + "node { name: 'W1' op: 'TestParams' }" + "node { name: 'W2' op: 'TestParams' }" + "node { name: 'input' op: 'TestInput' }" + "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }" + "node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }" + "node { name: 't3_a' op: 'TestRelu' input: 't2' }" + "node { name: 't3_b' op: 'TestRelu' input: 't2' }"); + EXPECT_EQ("OK", + Subgraph("input:1", "", "t2", true /* use_function_convention */)); + ExpectNodes("W1,W2,_arg_input_1_0,t1,t2"); +} + TEST_F(SubgraphTest, FedRefNode) { ExpectOK( "node { name: 'W1' op: 'TestParams' }" @@ -189,7 
+209,19 @@ TEST_F(SubgraphTest, FedRefNode) { EXPECT_FALSE(IsRefType(CHECK_NOTNULL(n)->output_type(0))); } -TEST_F(SubgraphTest, FedOutputs2) { +TEST_F(SubgraphTest, FedRefNode_FunctionConvention) { + ExpectOK( + "node { name: 'W1' op: 'TestParams' }" + "node { name: 'W2' op: 'TestParams' }" + "node { name: 't1' op: 'TestMul' input: [ 'W2', 'W1' ] }"); + EXPECT_EQ("OK", + Subgraph("W1:0", "", "t1", true /* use_function_convention */)); + ExpectNodes("_arg_W1_0_0,W2,t1"); + Node* n = FindNode("_arg_W1_0_0"); + EXPECT_FALSE(IsRefType(CHECK_NOTNULL(n)->output_type(0))); +} + +TEST_F(SubgraphTest, FedOutputs2_FunctionConvention) { ExpectOK( "node { name: 'W1' op: 'TestParams' }" "node { name: 'W2' op: 'TestParams' }" @@ -200,8 +232,9 @@ TEST_F(SubgraphTest, FedOutputs2) { "node { name: 't3_b' op: 'TestRelu' input: 't2' }"); // We feed input:1, but nothing connects to it, so the _recv(input:1) // node also disappears. - EXPECT_EQ("OK", Subgraph("input:1,t1,W2", "", "t2")); - ExpectNodes("_recv_t1_0,_recv_W2_0,t2"); + EXPECT_EQ("OK", Subgraph("input:1,t1,W2", "", "t2", + true /* use_function_convention */)); + ExpectNodes("_arg_t1_0_1,_arg_W2_0_2,t2"); } TEST_F(SubgraphTest, FetchOutputs1) { @@ -218,6 +251,22 @@ TEST_F(SubgraphTest, FetchOutputs1) { "W1,W2,input,t1,t2,_send_W2_0,_send_input_1,_send_t1_0,_send_t2_0"); } +TEST_F(SubgraphTest, FetchOutputs1_FunctionConvention) { + ExpectOK( + "node { name: 'W1' op: 'TestParams' }" + "node { name: 'W2' op: 'TestParams' }" + "node { name: 'input' op: 'TestInput' }" + "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }" + "node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }" + "node { name: 't3_a' op: 'TestRelu' input: 't2' }" + "node { name: 't3_b' op: 'TestRelu' input: 't2' }"); + EXPECT_EQ("OK", Subgraph("", "W2,input:1,t1,t2", "t2", + true /* use_function_convention */)); + ExpectNodes( + "W1,W2,input,t1,t2,_retval_W2_0_0,_retval_input_1_1,_retval_t1_0_2,_" + "retval_t2_0_3"); +} + TEST_F(SubgraphTest, FetchOutputs2) { ExpectOK( "node { name: 'W1' op: 'TestParams' }" @@ -231,6 +280,20 @@ TEST_F(SubgraphTest, FetchOutputs2) { ExpectNodes("W1,W2,input,t1,t2,t3_a,_send_t3_a_0"); } +TEST_F(SubgraphTest, FetchOutputs2_FunctionConvention) { + ExpectOK( + "node { name: 'W1' op: 'TestParams' }" + "node { name: 'W2' op: 'TestParams' }" + "node { name: 'input' op: 'TestInput' }" + "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }" + "node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }" + "node { name: 't3_a' op: 'TestRelu' input: 't2' }" + "node { name: 't3_b' op: 'TestRelu' input: 't2' }"); + EXPECT_EQ("OK", + Subgraph("", "t3_a", "t2", true /* use_function_convention */)); + ExpectNodes("W1,W2,input,t1,t2,t3_a,_retval_t3_a_0_0"); +} + TEST_F(SubgraphTest, ChainOfFools) { ExpectOK( "node { name: 'a' op: 'TestParams' }" @@ -315,7 +378,8 @@ TEST_F(SubgraphTest, FedOutputsPreservesOutputShapes) { REGISTER_OP("In").Output("o: float"); REGISTER_OP("Op").Input("i: float").Output("o: float"); -static void BM_Subgraph(int iters, int num_nodes) { +static void BM_SubgraphHelper(int iters, int num_nodes, + bool use_function_convention) { DeviceAttributes device_info; device_info.set_name("/job:a/replica:0/task:0/cpu:0"); device_info.set_device_type(DeviceType(DEVICE_CPU).type()); @@ -347,12 +411,26 @@ static void BM_Subgraph(int iters, int num_nodes) { while (--iters > 0) { Graph* subgraph = new Graph(OpRegistry::Global()); CopyGraph(g, subgraph); - TF_CHECK_OK(subgraph::RewriteGraphForExecution(subgraph, fed, fetch, - targets, 
device_info)); + subgraph::RewriteGraphMetadata metadata; + TF_CHECK_OK(subgraph::RewriteGraphForExecution( + subgraph, fed, fetch, targets, device_info, use_function_convention, + &metadata)); delete subgraph; } } + +static void BM_Subgraph(int iters, int num_nodes) { + BM_SubgraphHelper(iters, num_nodes, false /* use_function_convention */); +} +static void BM_SubgraphFunctionConvention(int iters, int num_nodes) { + BM_SubgraphHelper(iters, num_nodes, true /* use_function_convention */); +} BENCHMARK(BM_Subgraph)->Arg(100)->Arg(1000)->Arg(10000)->Arg(100000); +BENCHMARK(BM_SubgraphFunctionConvention) + ->Arg(100) + ->Arg(1000) + ->Arg(10000) + ->Arg(100000); } // namespace } // namespace tensorflow diff --git a/tensorflow/python/debug/lib/debug_data.py b/tensorflow/python/debug/lib/debug_data.py index a76dd4f6d60..bb457a01b23 100644 --- a/tensorflow/python/debug/lib/debug_data.py +++ b/tensorflow/python/debug/lib/debug_data.py @@ -820,7 +820,7 @@ class DebugDumpDir(object): self._node_op_types[node.name] = node.op for inp in node.input: - if is_copy_node(inp) and node.op == "_Send": + if is_copy_node(inp) and (node.op == "_Send" or node.op == "_Retval"): self._copy_send_nodes.append(node.name) if inp.startswith("^"): diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index 6c7cbbff9cb..00f6cc0d6d9 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -196,7 +196,7 @@ class ControlFlowTest(test.TestCase): with self.assertRaisesWithPredicateMatch( errors_impl.InvalidArgumentError, - lambda e: "The tensor returned for" in str(e)): + lambda e: "Retval[0] does not have value" in str(e)): dead_branch.eval() def testSwitchMergeLess(self): diff --git a/tensorflow/tools/graph_transforms/fold_constants_lib.cc b/tensorflow/tools/graph_transforms/fold_constants_lib.cc index 8d1f19bf30b..466e61b42dc 100644 --- a/tensorflow/tools/graph_transforms/fold_constants_lib.cc +++ b/tensorflow/tools/graph_transforms/fold_constants_lib.cc @@ -147,9 +147,10 @@ Status FoldConstants(const GraphDef& input_graph_def, TF_RETURN_IF_ERROR( ImportGraphDef(import_opts, cleaned_graph_def, &input_graph, nullptr)); DeviceAttributes device_attributes; + subgraph::RewriteGraphMetadata metadata; TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution( &input_graph, context.input_names, context.output_names, {}, - device_attributes)); + device_attributes, false /* use_function_convention */, &metadata)); bool was_mutated; TF_RETURN_IF_ERROR(DoConstantFoldingWithStatus( ConstantFoldingOptions(), nullptr, Env::Default(), nullptr, &input_graph,