Pass in the CPU device to grappler, instead of making a new one, when possible.

PiperOrigin-RevId: 167944850
This commit is contained in:
A. Unique TensorFlower 2017-09-07 19:03:55 -07:00 committed by TensorFlower Gardener
parent 03ebcfdcc8
commit 2b5011625b
8 changed files with 61 additions and 30 deletions

View File

@@ -37,6 +37,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/util.h"
#ifndef IS_MOBILE_PLATFORM
@@ -338,14 +339,24 @@ Status SimpleGraphExecutionState::OptimizeGraph(
}
std::unordered_map<string, DeviceProperties> device_map;
Device* cpu_device = nullptr;
for (const auto& device : device_set_->devices()) {
device_map[device->name()] =
grappler::GetDeviceInfo(device->parsed_name());
if (device->parsed_name().id == 0 &&
StringPiece(device->parsed_name().type) == "CPU" &&
device->GetAllocator(AllocatorAttributes()) != nullptr) {
cpu_device = device;
}
}
if (cpu_device == nullptr) {
return errors::Internal(
"Unable to find CPU device needed for constant folding");
}
grappler::VirtualCluster cluster(device_map);
GraphDef new_graph;
TF_RETURN_IF_ERROR(grappler::RunMetaOptimizer(item, rewrite_options,
&cluster, &new_graph));
TF_RETURN_IF_ERROR(grappler::RunMetaOptimizer(
item, rewrite_options, cpu_device, &cluster, &new_graph));
GraphConstructorOptions opts;
opts.allow_internal_ops = true;
optimized_graph->reset(new Graph(OpRegistry::Global()));

View File

@@ -310,6 +310,7 @@ cc_library(
":layout_optimizer",
":memory_optimizer",
":model_pruner",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:grappler_item",

View File

@@ -95,8 +95,8 @@ class DeviceSimple : public DeviceBase {
};
} // namespace
ConstantFolding::ConstantFolding() {
ConstantFolding::ConstantFolding(DeviceBase* cpu_device)
: cpu_device_(cpu_device) {
resource_mgr_.reset(new ResourceMgr());
}
@@ -381,11 +381,11 @@ Status ConstantFolding::EvaluateNode(const NodeDef& node,
TensorVector* output) const {
Status status;
auto op_kernel =
CreateOpKernel("CPU", device_.get(), device_->GetAllocator({}), node,
CreateOpKernel("CPU", cpu_device_, cpu_device_->GetAllocator({}), node,
TF_GRAPH_DEF_VERSION, &status);
TF_RETURN_IF_ERROR(status);
OpKernelContext::Params params;
params.device = device_.get();
params.device = cpu_device_;
params.frame_iter = FrameAndIter(0, 0);
params.inputs = &inputs;
params.op_kernel = op_kernel.get();
@@ -845,8 +845,11 @@ Status ConstantFolding::Optimize(Cluster* cluster, const GrapplerItem& item,
graph_ = item.graph;
node_map_.reset(new NodeMap(&graph_));
nodes_to_preserve_ = item.NodesToPreserve();
device_.reset(new DeviceSimple());
*output = GraphDef();
if (cpu_device_ == nullptr) {
owned_device_.reset(new DeviceSimple());
cpu_device_ = owned_device_.get();
}
bool has_feed = !item.feed.empty();
has_fetch_ = !item.fetch.empty();

View File

@@ -32,7 +32,7 @@ const char kConstantFoldingCtrl[] = "ConstantFoldingCtrl";
// Constant folding optimization for a graph.
class ConstantFolding : public GraphOptimizer {
public:
ConstantFolding();
ConstantFolding(DeviceBase* cpu_device);
~ConstantFolding() override {}
@@ -69,7 +69,10 @@ class ConstantFolding : public GraphOptimizer {
const GraphProperties& properties) const;
Status SimplifyGraph(GraphDef* output, const GraphProperties& properties);
std::unique_ptr<DeviceBase> device_;
// Points to an externally provided device or to owned_device_.
DeviceBase* cpu_device_;
std::unique_ptr<DeviceBase> owned_device_;
std::unique_ptr<ResourceMgr> resource_mgr_;
GraphDef graph_;
std::unique_ptr<NodeMap> node_map_;

View File

@@ -57,7 +57,7 @@ TEST_F(ConstantFoldingTest, SimpleFolding) {
item.fetch.push_back("d");
TF_CHECK_OK(s.ToGraphDef(&item.graph));
ConstantFolding fold;
ConstantFolding fold(nullptr /* cpu_device */);
GraphDef output;
Status status = fold.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
@@ -103,7 +103,7 @@ TEST_F(ConstantFoldingTest, FoldingNodeWithTwoOutputs) {
item.fetch.push_back("f");
TF_CHECK_OK(s.ToGraphDef(&item.graph));
ConstantFolding fold;
ConstantFolding fold(nullptr /* cpu_device */);
GraphDef output;
Status status = fold.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
@@ -152,7 +152,7 @@ TEST_F(ConstantFoldingTest, ControlDependencies) {
item.fetch.push_back("e");
TF_CHECK_OK(scope.ToGraphDef(&item.graph));
ConstantFolding fold;
ConstantFolding fold(nullptr /* cpu_device */);
GraphDef output;
Status status = fold.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
@@ -195,7 +195,7 @@ TEST_F(ConstantFoldingTest, ControlDependenciesEmptyFetch) {
GrapplerItem item;
TF_CHECK_OK(scope.ToGraphDef(&item.graph));
ConstantFolding fold;
ConstantFolding fold(nullptr /* cpu_device */);
GraphDef output;
Status status = fold.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
@@ -252,7 +252,7 @@ TEST_F(ConstantFoldingTest, ControlDependenciesDeduplicate) {
item.fetch.push_back("i2");
TF_CHECK_OK(scope.ToGraphDef(&item.graph));
ConstantFolding fold;
ConstantFolding fold(nullptr /* cpu_device */);
GraphDef output;
Status status = fold.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
@@ -326,7 +326,7 @@ TEST_F(ConstantFoldingTest, VariableNumberOfOutputs) {
}
item.fetch = outputs;
ConstantFolding fold;
ConstantFolding fold(nullptr /* cpu_device */);
GraphDef output;
Status status = fold.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
@@ -364,7 +364,7 @@ TEST_F(ConstantFoldingTest, ShapeMaterialization) {
item.fetch.push_back("p2");
TF_CHECK_OK(scope.ToGraphDef(&item.graph));
ConstantFolding fold;
ConstantFolding fold(nullptr /* cpu_device */);
GraphDef output;
Status status = fold.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
@@ -399,7 +399,7 @@ TEST_F(ConstantFoldingTest, ShapeMaterializationEmptyFetch) {
GrapplerItem item;
TF_CHECK_OK(scope.ToGraphDef(&item.graph));
ConstantFolding fold;
ConstantFolding fold(nullptr /* cpu_device */);
GraphDef output;
Status status = fold.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
@@ -461,7 +461,7 @@ TEST_F(ConstantFoldingTest, SwitchNodesEmptyFetch) {
GrapplerItem item;
TF_CHECK_OK(scope.ToGraphDef(&item.graph));
ConstantFolding fold;
ConstantFolding fold(nullptr /* cpu_device */);
GraphDef output;
Status status = fold.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
@@ -537,7 +537,7 @@ TEST_F(ConstantFoldingTest, SwitchNodes) {
TF_CHECK_OK(scope.ToGraphDef(&item.graph));
ConstantFolding fold;
ConstantFolding fold(nullptr /* cpu_device */);
GraphDef output;
Status status = fold.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
@@ -605,7 +605,7 @@ TEST_F(ConstantFoldingTest, MergeNodes) {
item.fetch = {"out1", "idx1", "out2", "idx2", "out3", "idx3"};
TF_CHECK_OK(scope.ToGraphDef(&item.graph));
ConstantFolding fold;
ConstantFolding fold(nullptr /* cpu_device */);
GraphDef output;
Status status = fold.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
@@ -679,7 +679,7 @@ TEST_F(ConstantFoldingTest, NoOpReduction) {
item.fetch.push_back("s");
TF_CHECK_OK(scope.ToGraphDef(&item.graph));
ConstantFolding fold;
ConstantFolding fold(nullptr /* cpu_device */);
GraphDef output;
Status status = fold.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
@@ -738,7 +738,7 @@ TEST_F(ConstantFoldingTest, NoOpReshape) {
item.fetch = {"s1", "s2", "s3", "s4"};
TF_CHECK_OK(scope.ToGraphDef(&item.graph));
ConstantFolding fold;
ConstantFolding fold(nullptr /* cpu_device */);
GraphDef output;
Status status = fold.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
@@ -785,7 +785,7 @@ TEST_F(ConstantFoldingTest, Packing) {
GrapplerItem item;
TF_CHECK_OK(scope.ToGraphDef(&item.graph));
ConstantFolding fold;
ConstantFolding fold(nullptr /* cpu_device */);
GraphDef output;
Status status = fold.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);

View File

@@ -37,7 +37,7 @@ std::unique_ptr<GraphOptimizer> MetaOptimizer::NewOptimizer(
graph_optimizer.reset(new ModelPruner());
}
if (optimizer == "constfold") {
graph_optimizer.reset(new ConstantFolding());
graph_optimizer.reset(new ConstantFolding(cpu_device_));
}
if (optimizer == "layout") {
graph_optimizer.reset(new LayoutOptimizer());
@@ -64,7 +64,7 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
}
if (cfg_.constant_folding() != RewriterConfig::OFF) {
optimizers.push_back(
std::unique_ptr<GraphOptimizer>(new ConstantFolding()));
std::unique_ptr<GraphOptimizer>(new ConstantFolding(cpu_device_)));
}
if (cfg_.arithmetic_optimization() != RewriterConfig::OFF) {
optimizers.push_back(
@@ -144,8 +144,9 @@ bool MetaOptimizerEnabled(const RewriterConfig& cfg) {
}
Status RunMetaOptimizer(const GrapplerItem& item, const RewriterConfig& cfg,
Cluster* cluster, GraphDef* optimized_graph) {
MetaOptimizer optimizer(cfg);
DeviceBase* cpu_device, Cluster* cluster,
GraphDef* optimized_graph) {
MetaOptimizer optimizer(cpu_device, cfg);
return optimizer.Optimize(cluster, item, optimized_graph);
}

View File

@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_GRAPPLER_OPTIMIZERS_META_OPTIMIZER_H_
#define TENSORFLOW_GRAPPLER_OPTIMIZERS_META_OPTIMIZER_H_
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
#include "tensorflow/core/lib/core/status.h"
@@ -27,7 +28,8 @@ namespace grappler {
// Run the other grappler optimizers based on the specified rewriter config.
class MetaOptimizer : public GraphOptimizer {
public:
MetaOptimizer(const RewriterConfig& cfg) : cfg_(cfg) {}
MetaOptimizer(DeviceBase* cpu_device, const RewriterConfig& cfg)
: cpu_device_(cpu_device), cfg_(cfg) {}
~MetaOptimizer() override {}
string name() const override { return "meta_optimizer"; };
@@ -40,13 +42,21 @@ class MetaOptimizer : public GraphOptimizer {
private:
std::unique_ptr<GraphOptimizer> NewOptimizer(const string& optimizer);
DeviceBase* const cpu_device_; // may be NULL
RewriterConfig cfg_;
};
bool MetaOptimizerEnabled(const RewriterConfig& cfg);
// Run the meta optimizer.
//
// If <cpu_device> is non-null, it is the device to be used for executing ops
// during constant folding; if NULL, a new device is created for doing constant
// folding. For performance, it is recommended to pass in an existing cpu_device
// when possible.
Status RunMetaOptimizer(const GrapplerItem& item, const RewriterConfig& cfg,
Cluster* cluster, GraphDef* optimized_graph);
DeviceBase* cpu_device, Cluster* cluster,
GraphDef* optimized_graph);
} // namespace grappler
} // namespace tensorflow

View File

@@ -55,6 +55,7 @@ limitations under the License.
#include <memory>
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/grappler_item_builder.h"
@@ -73,10 +74,11 @@ PyObject* TF_OptimizeGraph(
std::unique_ptr<tensorflow::grappler::GrapplerItem> grappler_item =
tensorflow::grappler::GrapplerItemFromMetaGraphDef(graph_id, metagraph, item_config);
std::unordered_map<string, tensorflow::DeviceProperties> device_map;
tensorflow::DeviceBase* cpu_device = nullptr;
tensorflow::grappler::VirtualCluster cluster(device_map);
tensorflow::GraphDef out_graph;
tensorflow::Status status = tensorflow::grappler::RunMetaOptimizer(
*grappler_item, rewriter_config, &cluster, &out_graph);
*grappler_item, rewriter_config, cpu_device, &cluster, &out_graph);
tensorflow::Set_TF_Status_from_Status(out_status, status);
string out_graph_str = out_graph.SerializeAsString();
PyObject* ret = PyBytes_FromStringAndSize(out_graph_str.data(),