Use the zero-copy implementation of GraphConstructor in more places.

Many uses of GraphConstructor take a `const GraphDef&` to a locally-defined GraphDef that is subsequently destroyed. We can move the GraphDef into GraphConstructor to avoid copying the graph nodes repeatedly. In some cases with large GraphDefs (e.g. with large embedded constant tensors) this optimization will reduce peak memory consumption.

PiperOrigin-RevId: 259809688
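
For reference, a minimal sketch (not taken from this change) of the calling pattern the updated call sites rely on: an overload of ConvertGraphDefToGraph that accepts the GraphDef by rvalue reference, so a locally built GraphDef that is about to go out of scope is moved rather than copied. The helper name BuildGraphConsumingDef below is hypothetical, used only for illustration.

    #include <utility>

    #include "tensorflow/core/framework/graph.pb.h"
    #include "tensorflow/core/graph/graph.h"
    #include "tensorflow/core/graph/graph_constructor.h"
    #include "tensorflow/core/lib/core/status.h"

    namespace tensorflow {

    // Hypothetical helper: takes the GraphDef by value so the caller can copy
    // or move into it; the body then hands it to the zero-copy overload, which
    // lets the GraphConstructor take ownership of the NodeDefs (including any
    // large embedded constant tensors) instead of copying them.
    Status BuildGraphConsumingDef(GraphDef graph_def, Graph* graph) {
      GraphConstructorOptions opts;
      opts.allow_internal_ops = true;
      // After this call `graph_def` is in a moved-from state and must not be
      // reused by the caller.
      return ConvertGraphDefToGraph(opts, std::move(graph_def), graph);
    }

    }  // namespace tensorflow

Call sites that still need the GraphDef afterwards can keep passing it by const reference, which retains the copying behaviour; the diffs below only touch places where the GraphDef would have been destroyed immediately after the call.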
Author: Derek Murray, 2019-07-24 13:50:36 -07:00; committed by TensorFlower Gardener
Commit: 5d37c2b785 (parent: 3a72de3a1b)
11 changed files with 27 additions and 26 deletions

@@ -318,7 +318,7 @@ Status Scope::ToGraph(Graph* g, GraphConstructorOptions opts) const {
   if (ok()) {
     GraphDef graph_def;
     graph()->ToGraphDef(&graph_def);
-    UpdateStatus(ConvertGraphDefToGraph(opts, graph_def, g));
+    UpdateStatus(ConvertGraphDefToGraph(opts, std::move(graph_def), g));
   }
   return *impl()->status_;
 }

@@ -300,8 +300,8 @@ Status Importer::RemoveBackedges(const Graph& graph) {
   graph_ = absl::make_unique<Graph>(graph.flib_def());
   GraphConstructorOptions opts;
   opts.allow_internal_ops = true;
-  TF_RETURN_IF_ERROR(
-      ::tensorflow::ConvertGraphDefToGraph(opts, graph_def, graph_.get()));
+  TF_RETURN_IF_ERROR(::tensorflow::ConvertGraphDefToGraph(
+      opts, std::move(graph_def), graph_.get()));

   // Remove all the backedges. So the nodes can be added to the shape refiner.
   TF_RETURN_IF_ERROR(back_edge_helper_.Remove(graph_.get()));
@@ -1394,8 +1394,8 @@ StatusOr<mlir::OwningModuleRef> ConvertGraphdefToMlir(
   if (add_default_attributes) {
     TF_RETURN_IF_ERROR(AddDefaultsToNodeDef(&preprocessed_graphdef));
   }
-  TF_RETURN_IF_ERROR(
-      ConvertGraphDefToGraph(options, preprocessed_graphdef, &graph));
+  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
+      options, std::move(preprocessed_graphdef), &graph));
   return ConvertGraphToMlir(graph, debug_info, graph.flib_def(), specs,
                             context);

@@ -384,8 +384,8 @@ Status InitGraph(const GraphDef& graph_def, const tf2xla::Config& config,
   TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(
       &second_copy_def, *g->op_registry(), /*node_offset=*/0));
-  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(GraphConstructorOptions(),
-                                            second_copy_def, g.get()));
+  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
+      GraphConstructorOptions(), std::move(second_copy_def), g.get()));
   TF_RETURN_IF_ERROR(RewriteAndPruneGraph(g.get(), config, feed_remapping));

   // Functionalize control flow.

@@ -1614,15 +1614,15 @@ Status DirectSession::CreateGraphs(
     }
   }

-  for (const auto& partition : partitions) {
+  for (auto& partition : partitions) {
     std::unique_ptr<Graph> device_graph(
         new Graph(client_graph->flib_def.get()));
     GraphConstructorOptions device_opts;
     // There are internal operations (e.g., send/recv) that we now allow.
     device_opts.allow_internal_ops = true;
     device_opts.expect_device_spec = true;
-    TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(device_opts, partition.second,
-                                              device_graph.get()));
+    TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
+        device_opts, std::move(partition.second), device_graph.get()));
     outputs->emplace(partition.first, std::move(device_graph));
   }

@@ -757,8 +757,8 @@ Status GraphExecutionState::OptimizeGraph(
     GraphConstructorOptions opts;
     opts.allow_internal_ops = true;
-    TF_RETURN_IF_ERROR(
-        ConvertGraphDefToGraph(opts, new_graph, optimized_graph->get()));
+    TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, std::move(new_graph),
+                                              optimized_graph->get()));

     // The graph conversion sets the requested device names but not the
     // assigned device names. However, since at this point the graph is placed
     // TF expects an assigned device name for every node. Therefore we copy

@@ -179,14 +179,14 @@ Status GraphMgr::InitItem(const string& handle, const GraphDef& gdef,
   }

   std::unordered_map<string, std::unique_ptr<Graph>> partition_graphs;
-  for (const auto& partition : partitions) {
+  for (auto& partition : partitions) {
     std::unique_ptr<Graph> device_graph(new Graph(OpRegistry::Global()));
     GraphConstructorOptions device_opts;
     // There are internal operations (e.g., send/recv) that we now allow.
     device_opts.allow_internal_ops = true;
     device_opts.expect_device_spec = true;
-    TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(device_opts, partition.second,
-                                              device_graph.get()));
+    TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
+        device_opts, std::move(partition.second), device_graph.get()));
     partition_graphs.emplace(partition.first, std::move(device_graph));
   }

@@ -22,7 +22,7 @@ Status GraphDefBuilderToGraph(const GraphDefBuilder& builder, Graph* graph) {
   GraphDef graph_def;
   TF_RETURN_IF_ERROR(builder.ToGraphDef(&graph_def));
   GraphConstructorOptions opts;
-  return ConvertGraphDefToGraph(opts, graph_def, graph);
+  return ConvertGraphDefToGraph(opts, std::move(graph_def), graph);
 }

 }  // namespace tensorflow

@@ -267,8 +267,8 @@ Status RuntimeGraphOptimizer(const GraphDef& graph_def_arg,
   graph_ctor_opts.expect_device_spec = false;
   std::unique_ptr<Graph> graphptr(new Graph(function_library));
-  TF_RETURN_IF_ERROR(
-      ConvertGraphDefToGraph(graph_ctor_opts, graph_def, graphptr.get()));
+  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
+      graph_ctor_opts, std::move(graph_def), graphptr.get()));

   // Optimize the graph.
   ::tensorflow::GraphOptimizer optimizer(*optimizer_opts);

@@ -784,7 +784,7 @@ constexpr const char* const kLowerAsMultiDeviceFunctionAttr =
 using KeepCallerNode = InlineFunctionBodyOptions::KeepCallerNode;
 using OutputControlSource = InlineFunctionBodyOptions::OutputControlSource;

-// Checks if boolean attribute is defined and it's value is 'true'.
+// Checks if boolean attribute is defined and its value is 'true'.
 bool CheckBoolAttr(const Node* n, absl::string_view attr_name) {
   bool match;
   Status s = GetNodeAttr(n->attrs(), attr_name, &match);

@@ -802,8 +802,6 @@ Status OptimizeGraph(
   std::unique_ptr<tensorflow::Graph> optimized_graph(
       new tensorflow::Graph(OpRegistry::Global()));
-  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(GraphConstructorOptions(),
-                                            out_graph, optimized_graph.get()));

   // Copy optimized functions back to the overlay lib.
   if (flib) {
@@ -817,25 +815,28 @@ Status OptimizeGraph(
     }
   }

-  *g = std::move(optimized_graph);
+  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
+      GraphConstructorOptions(), std::move(out_graph), optimized_graph.get()));

   // The graph conversion sets the requested device names but not the
   // assigned device names. However, since at this point the graph is
   // placed TF expects an assigned device name for every node. Therefore
   // we copy the requested device into the assigned device field.
-  for (Node* node : (*g)->nodes()) {
+  for (Node* node : optimized_graph->nodes()) {
     if (node->IsOp() && node->assigned_device_name().empty()) {
       if (node->requested_device().empty()) {
         return errors::Internal(
             "Either placer did not place the node or Grappler did not "
             "copy the assigned device. Contact Grappler team since latter "
             "is more likely. Node=",
-            node->name(), " Graph: ", (*g)->ToGraphDefDebug().DebugString());
+            node->name(),
+            " Graph: ", optimized_graph->ToGraphDefDebug().DebugString());
       }
       node->set_assigned_device_name(node->requested_device());
     }
   }
+  *g = std::move(optimized_graph);

   return Status::OK();
 }

@@ -111,8 +111,8 @@ Status OptimizationPassRunner::Run(absl::string_view pass_to_run,
   GraphConstructorOptions graph_opts;
   graph_opts.expect_device_spec = true;
   graph_opts.allow_internal_ops = true;
-  TF_RETURN_IF_ERROR(
-      ConvertGraphDefToGraph(graph_opts, input, options.graph->get()));
+  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(graph_opts, std::move(input),
+                                            options.graph->get()));

   // Add all devices that were previously configured with AddDevice.
   DeviceSet device_set;