From 2b5011625bcab6c50c51b948e68063393711bd30 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Thu, 7 Sep 2017 19:03:55 -0700
Subject: [PATCH] Pass in the CPU device to grappler, instead of making a new one, when possible.

PiperOrigin-RevId: 167944850
---
 .../simple_graph_execution_state.cc           | 15 ++++++++--
 tensorflow/core/grappler/optimizers/BUILD     |  1 +
 .../grappler/optimizers/constant_folding.cc   | 13 +++++----
 .../grappler/optimizers/constant_folding.h    |  7 +++--
 .../optimizers/constant_folding_test.cc       | 28 +++++++++----------
 .../grappler/optimizers/meta_optimizer.cc     |  9 +++---
 .../core/grappler/optimizers/meta_optimizer.h | 14 ++++++++--
 tensorflow/python/grappler/tf_optimizer.i     |  4 ++-
 8 files changed, 61 insertions(+), 30 deletions(-)

diff --git a/tensorflow/core/common_runtime/simple_graph_execution_state.cc b/tensorflow/core/common_runtime/simple_graph_execution_state.cc
index 363d3a0c9d3..c66dc568f63 100644
--- a/tensorflow/core/common_runtime/simple_graph_execution_state.cc
+++ b/tensorflow/core/common_runtime/simple_graph_execution_state.cc
@@ -37,6 +37,7 @@ limitations under the License.
 #include "tensorflow/core/lib/strings/stringprintf.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/device_name_utils.h"
 #include "tensorflow/core/util/util.h"

 #ifndef IS_MOBILE_PLATFORM
@@ -338,14 +339,24 @@ Status SimpleGraphExecutionState::OptimizeGraph(
     }

     std::unordered_map<string, DeviceProperties> device_map;
+    Device* cpu_device = nullptr;
     for (const auto& device : device_set_->devices()) {
       device_map[device->name()] =
           grappler::GetDeviceInfo(device->parsed_name());
+      if (device->parsed_name().id == 0 &&
+          StringPiece(device->parsed_name().type) == "CPU" &&
+          device->GetAllocator(AllocatorAttributes()) != nullptr) {
+        cpu_device = device;
+      }
+    }
+    if (cpu_device == nullptr) {
+      return errors::Internal(
+          "Unable to find CPU device needed for constant folding");
     }
     grappler::VirtualCluster cluster(device_map);
     GraphDef new_graph;
-    TF_RETURN_IF_ERROR(grappler::RunMetaOptimizer(item, rewrite_options,
-                                                  &cluster, &new_graph));
+    TF_RETURN_IF_ERROR(grappler::RunMetaOptimizer(
+        item, rewrite_options, cpu_device, &cluster, &new_graph));
     GraphConstructorOptions opts;
     opts.allow_internal_ops = true;
     optimized_graph->reset(new Graph(OpRegistry::Global()));
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index 659451e9913..c16ca0d9c4c 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -310,6 +310,7 @@ cc_library(
         ":layout_optimizer",
         ":memory_optimizer",
         ":model_pruner",
+        "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/grappler:grappler_item",
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index 443c0b72abc..1078711e6b9 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -95,8 +95,8 @@ class DeviceSimple : public DeviceBase {
 };

 }  // namespace
-
-ConstantFolding::ConstantFolding() {
+ConstantFolding::ConstantFolding(DeviceBase* cpu_device)
+    : cpu_device_(cpu_device) {
   resource_mgr_.reset(new ResourceMgr());
 }

@@ -381,11 +381,11 @@ Status ConstantFolding::EvaluateNode(const NodeDef& node,
                                      TensorVector* output) const {
   Status status;
   auto op_kernel =
-      CreateOpKernel("CPU", device_.get(), device_->GetAllocator({}), node,
+      CreateOpKernel("CPU", cpu_device_, cpu_device_->GetAllocator({}), node,
                      TF_GRAPH_DEF_VERSION, &status);
   TF_RETURN_IF_ERROR(status);
   OpKernelContext::Params params;
-  params.device = device_.get();
+  params.device = cpu_device_;
   params.frame_iter = FrameAndIter(0, 0);
   params.inputs = &inputs;
   params.op_kernel = op_kernel.get();
@@ -845,8 +845,11 @@ Status ConstantFolding::Optimize(Cluster* cluster, const GrapplerItem& item,
   graph_ = item.graph;
   node_map_.reset(new NodeMap(&graph_));
   nodes_to_preserve_ = item.NodesToPreserve();
-  device_.reset(new DeviceSimple());
   *output = GraphDef();
+  if (cpu_device_ == nullptr) {
+    owned_device_.reset(new DeviceSimple());
+    cpu_device_ = owned_device_.get();
+  }

   bool has_feed = !item.feed.empty();
   has_fetch_ = !item.fetch.empty();
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h
index 0c1c40dfd34..0203ff42963 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.h
+++ b/tensorflow/core/grappler/optimizers/constant_folding.h
@@ -32,7 +32,7 @@ const char kConstantFoldingCtrl[] = "ConstantFoldingCtrl";
 // Constant folding optimization for a graph.
 class ConstantFolding : public GraphOptimizer {
  public:
-  ConstantFolding();
+  ConstantFolding(DeviceBase* cpu_device);

   ~ConstantFolding() override {}

@@ -69,7 +69,10 @@ class ConstantFolding : public GraphOptimizer {
                             const GraphProperties& properties) const;
   Status SimplifyGraph(GraphDef* output, const GraphProperties& properties);

-  std::unique_ptr<DeviceSimple> device_;
+  // Points to an externally provided device or to owned_device_;
+  DeviceBase* cpu_device_;
+  std::unique_ptr<DeviceSimple> owned_device_;
+
   std::unique_ptr<ResourceMgr> resource_mgr_;
   GraphDef graph_;
   std::unique_ptr<NodeMap> node_map_;
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index 0f7e7f1d494..a64267cc384 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -57,7 +57,7 @@ TEST_F(ConstantFoldingTest, SimpleFolding) {
   item.fetch.push_back("d");
   TF_CHECK_OK(s.ToGraphDef(&item.graph));

-  ConstantFolding fold;
+  ConstantFolding fold(nullptr /* cpu_device */);
   GraphDef output;
   Status status = fold.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);
@@ -103,7 +103,7 @@ TEST_F(ConstantFoldingTest, FoldingNodeWithTwoOutputs) {
   item.fetch.push_back("f");
   TF_CHECK_OK(s.ToGraphDef(&item.graph));

-  ConstantFolding fold;
+  ConstantFolding fold(nullptr /* cpu_device */);
   GraphDef output;
   Status status = fold.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);
@@ -152,7 +152,7 @@ TEST_F(ConstantFoldingTest, ControlDependencies) {
   item.fetch.push_back("e");
   TF_CHECK_OK(scope.ToGraphDef(&item.graph));

-  ConstantFolding fold;
+  ConstantFolding fold(nullptr /* cpu_device */);
   GraphDef output;
   Status status = fold.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);
@@ -195,7 +195,7 @@ TEST_F(ConstantFoldingTest, ControlDependenciesEmptyFetch) {
   GrapplerItem item;
   TF_CHECK_OK(scope.ToGraphDef(&item.graph));

-  ConstantFolding fold;
+  ConstantFolding fold(nullptr /* cpu_device */);
   GraphDef output;
   Status status = fold.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);
@@ -252,7 +252,7 @@ TEST_F(ConstantFoldingTest, ControlDependenciesDeduplicate) {
   item.fetch.push_back("i2");
   TF_CHECK_OK(scope.ToGraphDef(&item.graph));

-  ConstantFolding fold;
+  ConstantFolding fold(nullptr /* cpu_device */);
   GraphDef output;
   Status status = fold.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);
@@ -326,7 +326,7 @@ TEST_F(ConstantFoldingTest, VariableNumberOfOutputs) {
   }
   item.fetch = outputs;

-  ConstantFolding fold;
+  ConstantFolding fold(nullptr /* cpu_device */);
   GraphDef output;
   Status status = fold.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);
@@ -364,7 +364,7 @@ TEST_F(ConstantFoldingTest, ShapeMaterialization) {
   item.fetch.push_back("p2");
   TF_CHECK_OK(scope.ToGraphDef(&item.graph));

-  ConstantFolding fold;
+  ConstantFolding fold(nullptr /* cpu_device */);
   GraphDef output;
   Status status = fold.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);
@@ -399,7 +399,7 @@ TEST_F(ConstantFoldingTest, ShapeMaterializationEmptyFetch) {
   GrapplerItem item;
   TF_CHECK_OK(scope.ToGraphDef(&item.graph));

-  ConstantFolding fold;
+  ConstantFolding fold(nullptr /* cpu_device */);
   GraphDef output;
   Status status = fold.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);
@@ -461,7 +461,7 @@ TEST_F(ConstantFoldingTest, SwitchNodesEmptyFetch) {
   GrapplerItem item;
   TF_CHECK_OK(scope.ToGraphDef(&item.graph));

-  ConstantFolding fold;
+  ConstantFolding fold(nullptr /* cpu_device */);
   GraphDef output;
   Status status = fold.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);
@@ -537,7 +537,7 @@ TEST_F(ConstantFoldingTest, SwitchNodes) {

   TF_CHECK_OK(scope.ToGraphDef(&item.graph));

-  ConstantFolding fold;
+  ConstantFolding fold(nullptr /* cpu_device */);
   GraphDef output;
   Status status = fold.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);
@@ -605,7 +605,7 @@ TEST_F(ConstantFoldingTest, MergeNodes) {
   item.fetch = {"out1", "idx1", "out2", "idx2", "out3", "idx3"};
   TF_CHECK_OK(scope.ToGraphDef(&item.graph));

-  ConstantFolding fold;
+  ConstantFolding fold(nullptr /* cpu_device */);
   GraphDef output;
   Status status = fold.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);
@@ -679,7 +679,7 @@ TEST_F(ConstantFoldingTest, NoOpReduction) {
   item.fetch.push_back("s");
   TF_CHECK_OK(scope.ToGraphDef(&item.graph));

-  ConstantFolding fold;
+  ConstantFolding fold(nullptr /* cpu_device */);
   GraphDef output;
   Status status = fold.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);
@@ -738,7 +738,7 @@ TEST_F(ConstantFoldingTest, NoOpReshape) {
   item.fetch = {"s1", "s2", "s3", "s4"};
   TF_CHECK_OK(scope.ToGraphDef(&item.graph));

-  ConstantFolding fold;
+  ConstantFolding fold(nullptr /* cpu_device */);
   GraphDef output;
   Status status = fold.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);
@@ -785,7 +785,7 @@ TEST_F(ConstantFoldingTest, Packing) {
   GrapplerItem item;
   TF_CHECK_OK(scope.ToGraphDef(&item.graph));

-  ConstantFolding fold;
+  ConstantFolding fold(nullptr /* cpu_device */);
   GraphDef output;
   Status status = fold.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index 4ac985b41b4..6718d2d7392 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -37,7 +37,7 @@ std::unique_ptr<GraphOptimizer> MetaOptimizer::NewOptimizer(
     graph_optimizer.reset(new ModelPruner());
   }
   if (optimizer == "constfold") {
-    graph_optimizer.reset(new ConstantFolding());
+    graph_optimizer.reset(new ConstantFolding(cpu_device_));
   }
   if (optimizer == "layout") {
     graph_optimizer.reset(new LayoutOptimizer());
@@ -64,7 +64,7 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
   }
   if (cfg_.constant_folding() != RewriterConfig::OFF) {
     optimizers.push_back(
-        std::unique_ptr<GraphOptimizer>(new ConstantFolding()));
+        std::unique_ptr<GraphOptimizer>(new ConstantFolding(cpu_device_)));
   }
   if (cfg_.arithmetic_optimization() != RewriterConfig::OFF) {
     optimizers.push_back(
@@ -144,8 +144,9 @@ bool MetaOptimizerEnabled(const RewriterConfig& cfg) {
 }

 Status RunMetaOptimizer(const GrapplerItem& item, const RewriterConfig& cfg,
-                        Cluster* cluster, GraphDef* optimized_graph) {
-  MetaOptimizer optimizer(cfg);
+                        DeviceBase* cpu_device, Cluster* cluster,
+                        GraphDef* optimized_graph) {
+  MetaOptimizer optimizer(cpu_device, cfg);
   return optimizer.Optimize(cluster, item, optimized_graph);
 }

diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.h b/tensorflow/core/grappler/optimizers/meta_optimizer.h
index 6b950c973d9..b00886b964b 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.h
@@ -16,6 +16,7 @@ limitations under the License.
 #ifndef TENSORFLOW_GRAPPLER_OPTIMIZERS_META_OPTIMIZER_H_
 #define TENSORFLOW_GRAPPLER_OPTIMIZERS_META_OPTIMIZER_H_

+#include "tensorflow/core/framework/device_base.h"
 #include "tensorflow/core/grappler/grappler_item.h"
 #include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
 #include "tensorflow/core/lib/core/status.h"
@@ -27,7 +28,8 @@ namespace grappler {
 // Run the other grappler optimizers based on the specified rewriter config.
 class MetaOptimizer : public GraphOptimizer {
  public:
-  MetaOptimizer(const RewriterConfig& cfg) : cfg_(cfg) {}
+  MetaOptimizer(DeviceBase* cpu_device, const RewriterConfig& cfg)
+      : cpu_device_(cpu_device), cfg_(cfg) {}
   ~MetaOptimizer() override {}

   string name() const override { return "meta_optimizer"; };
@@ -40,13 +42,21 @@ class MetaOptimizer : public GraphOptimizer {

 private:
  std::unique_ptr<GraphOptimizer> NewOptimizer(const string& optimizer);
+  DeviceBase* const cpu_device_;  // may be NULL
   RewriterConfig cfg_;
 };

 bool MetaOptimizerEnabled(const RewriterConfig& cfg);

+// Run the meta optimizer.
+//
+// If <cpu_device> is non-null, it is the device to be used for executing ops
+// during constant folding; if NULL, a new device is created for doing constant
+// folding. For performance, it is recommended to pass in an existing cpu_device
+// when possible.
 Status RunMetaOptimizer(const GrapplerItem& item, const RewriterConfig& cfg,
-                        Cluster* cluster, GraphDef* optimized_graph);
+                        DeviceBase* cpu_device, Cluster* cluster,
+                        GraphDef* optimized_graph);

 }  // namespace grappler
 }  // namespace tensorflow
diff --git a/tensorflow/python/grappler/tf_optimizer.i b/tensorflow/python/grappler/tf_optimizer.i
index a8067467d91..12c5fce60f9 100644
--- a/tensorflow/python/grappler/tf_optimizer.i
+++ b/tensorflow/python/grappler/tf_optimizer.i
@@ -55,6 +55,7 @@ limitations under the License.
   #include <memory>
   #include "tensorflow/c/tf_status_helper.h"
   #include "tensorflow/core/lib/core/status.h"
+  #include "tensorflow/core/framework/device_base.h"
   #include "tensorflow/core/framework/graph.pb.h"
   #include "tensorflow/core/grappler/grappler_item.h"
   #include "tensorflow/core/grappler/grappler_item_builder.h"
@@ -73,10 +74,11 @@ PyObject* TF_OptimizeGraph(
   std::unique_ptr<tensorflow::grappler::GrapplerItem> grappler_item =
     tensorflow::grappler::GrapplerItemFromMetaGraphDef(graph_id, metagraph, item_config);
   std::unordered_map<string, tensorflow::DeviceProperties> device_map;
+  tensorflow::DeviceBase* cpu_device = nullptr;
   tensorflow::grappler::VirtualCluster cluster(device_map);
   tensorflow::GraphDef out_graph;
   tensorflow::Status status = tensorflow::grappler::RunMetaOptimizer(
-      *grappler_item, rewriter_config, &cluster, &out_graph);
+      *grappler_item, rewriter_config, cpu_device, &cluster, &out_graph);
   tensorflow::Set_TF_Status_from_Status(out_status, status);
   string out_graph_str = out_graph.SerializeAsString();
   PyObject* ret = PyBytes_FromStringAndSize(out_graph_str.data(),
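
For reference, the calling convention this patch establishes can be sketched as
follows. This is an illustrative example only: the helper name
OptimizeWithCpuDevice and the empty device map are assumptions made for the
sketch; only RunMetaOptimizer's new signature and the nullptr fallback inside
ConstantFolding come from the patch itself.

#include <string>
#include <unordered_map>

#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/device_properties.pb.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"

// Sketch: run grappler's meta optimizer, reusing an existing CPU device when
// the caller has one. `cpu_device` may be null; ConstantFolding then builds
// its own DeviceSimple, which is what the updated unit tests exercise via
// `ConstantFolding fold(nullptr /* cpu_device */);`.
tensorflow::Status OptimizeWithCpuDevice(
    const tensorflow::grappler::GrapplerItem& item,
    const tensorflow::RewriterConfig& cfg,
    tensorflow::DeviceBase* cpu_device,
    tensorflow::GraphDef* optimized_graph) {
  // A VirtualCluster over an empty device map is enough for this sketch; the
  // session path in simple_graph_execution_state.cc fills the map from the
  // real DeviceSet instead.
  std::unordered_map<std::string, tensorflow::DeviceProperties> device_map;
  tensorflow::grappler::VirtualCluster cluster(device_map);
  // New signature: the CPU device is threaded through MetaOptimizer down to
  // ConstantFolding, so constant folding can execute kernels on it.
  return tensorflow::grappler::RunMetaOptimizer(item, cfg, cpu_device,
                                                &cluster, optimized_graph);
}

When no suitable device is available, as in the Python path in tf_optimizer.i
above, passing nullptr keeps the old behaviour: ConstantFolding lazily creates
a DeviceSimple and uses it for the lifetime of the optimizer.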