Pass in the CPU device to grappler, instead of making a new one, when possible.
PiperOrigin-RevId: 167944850
This commit is contained in:
parent 03ebcfdcc8
commit 2b5011625b
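The change threads an existing CPU device from the session into grappler so that constant folding can execute kernels on it instead of building its own DeviceSimple. A rough caller-side sketch of the new flow (not part of the commit; the function name, local variable names, and include list are best-effort assumptions, while the device lookup and the RunMetaOptimizer call mirror the hunks below):

#include <unordered_map>

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"

namespace tensorflow {

// Sketch: pick the session's CPU:0 device (if it has an allocator) and hand it
// to RunMetaOptimizer so ConstantFolding can run kernels on it instead of
// allocating its own DeviceSimple.
Status OptimizeForSession(const grappler::GrapplerItem& item,
                          const RewriterConfig& rewrite_options,
                          const DeviceSet& device_set, GraphDef* out_graph) {
  Device* cpu_device = nullptr;
  for (Device* device : device_set.devices()) {
    if (device->parsed_name().id == 0 &&
        StringPiece(device->parsed_name().type) == "CPU" &&
        device->GetAllocator(AllocatorAttributes()) != nullptr) {
      cpu_device = device;
    }
  }
  std::unordered_map<string, DeviceProperties> device_map;
  grappler::VirtualCluster cluster(device_map);
  // New signature introduced by this commit: cpu_device may be nullptr, in
  // which case constant folding falls back to an internally owned device.
  return grappler::RunMetaOptimizer(item, rewrite_options, cpu_device, &cluster,
                                    out_graph);
}

}  // namespace tensorflow

Passing nullptr remains valid; ConstantFolding then creates and owns a DeviceSimple, which is what the Python wrapper at the end of this diff does.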
@@ -37,6 +37,7 @@ limitations under the License.
 #include "tensorflow/core/lib/strings/stringprintf.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/device_name_utils.h"
 #include "tensorflow/core/util/util.h"
 
 #ifndef IS_MOBILE_PLATFORM
@@ -338,14 +339,24 @@ Status SimpleGraphExecutionState::OptimizeGraph(
   }
 
   std::unordered_map<string, DeviceProperties> device_map;
+  Device* cpu_device = nullptr;
   for (const auto& device : device_set_->devices()) {
     device_map[device->name()] =
         grappler::GetDeviceInfo(device->parsed_name());
+    if (device->parsed_name().id == 0 &&
+        StringPiece(device->parsed_name().type) == "CPU" &&
+        device->GetAllocator(AllocatorAttributes()) != nullptr) {
+      cpu_device = device;
+    }
   }
+  if (cpu_device == nullptr) {
+    return errors::Internal(
+        "Unable to find CPU device needed for constant folding");
+  }
   grappler::VirtualCluster cluster(device_map);
   GraphDef new_graph;
-  TF_RETURN_IF_ERROR(grappler::RunMetaOptimizer(item, rewrite_options,
-                                                &cluster, &new_graph));
+  TF_RETURN_IF_ERROR(grappler::RunMetaOptimizer(
+      item, rewrite_options, cpu_device, &cluster, &new_graph));
   GraphConstructorOptions opts;
   opts.allow_internal_ops = true;
   optimized_graph->reset(new Graph(OpRegistry::Global()));
@@ -310,6 +310,7 @@ cc_library(
         ":layout_optimizer",
         ":memory_optimizer",
         ":model_pruner",
+        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/grappler:grappler_item",
@@ -95,8 +95,8 @@ class DeviceSimple : public DeviceBase {
 };
 
 }  // namespace
 
-ConstantFolding::ConstantFolding() {
+ConstantFolding::ConstantFolding(DeviceBase* cpu_device)
+    : cpu_device_(cpu_device) {
   resource_mgr_.reset(new ResourceMgr());
 }
@@ -381,11 +381,11 @@ Status ConstantFolding::EvaluateNode(const NodeDef& node,
                                      TensorVector* output) const {
   Status status;
   auto op_kernel =
-      CreateOpKernel("CPU", device_.get(), device_->GetAllocator({}), node,
+      CreateOpKernel("CPU", cpu_device_, cpu_device_->GetAllocator({}), node,
                      TF_GRAPH_DEF_VERSION, &status);
   TF_RETURN_IF_ERROR(status);
   OpKernelContext::Params params;
-  params.device = device_.get();
+  params.device = cpu_device_;
   params.frame_iter = FrameAndIter(0, 0);
   params.inputs = &inputs;
   params.op_kernel = op_kernel.get();
@@ -845,8 +845,11 @@ Status ConstantFolding::Optimize(Cluster* cluster, const GrapplerItem& item,
   graph_ = item.graph;
   node_map_.reset(new NodeMap(&graph_));
   nodes_to_preserve_ = item.NodesToPreserve();
-  device_.reset(new DeviceSimple());
   *output = GraphDef();
+  if (cpu_device_ == nullptr) {
+    owned_device_.reset(new DeviceSimple());
+    cpu_device_ = owned_device_.get();
+  }
 
   bool has_feed = !item.feed.empty();
   has_fetch_ = !item.fetch.empty();
@@ -32,7 +32,7 @@ const char kConstantFoldingCtrl[] = "ConstantFoldingCtrl";
 // Constant folding optimization for a graph.
 class ConstantFolding : public GraphOptimizer {
  public:
-  ConstantFolding();
+  ConstantFolding(DeviceBase* cpu_device);
 
   ~ConstantFolding() override {}
 
@@ -69,7 +69,10 @@ class ConstantFolding : public GraphOptimizer {
                      const GraphProperties& properties) const;
   Status SimplifyGraph(GraphDef* output, const GraphProperties& properties);
 
-  std::unique_ptr<DeviceBase> device_;
+  // Points to an externally provided device or to owned_device_;
+  DeviceBase* cpu_device_;
+  std::unique_ptr<DeviceBase> owned_device_;
+
   std::unique_ptr<ResourceMgr> resource_mgr_;
   GraphDef graph_;
   std::unique_ptr<NodeMap> node_map_;
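The comment above introduces a borrow-or-own pattern: cpu_device_ aliases either a caller-supplied device or the lazily created owned_device_ (see the Optimize() hunk earlier). A standalone sketch of the same idiom outside TensorFlow (all class and variable names here are illustrative, not from the commit):

#include <iostream>
#include <memory>

class Device {
 public:
  virtual ~Device() = default;
  virtual const char* name() const = 0;
};

class SimpleDevice : public Device {
 public:
  explicit SimpleDevice(const char* name) : name_(name) {}
  const char* name() const override { return name_; }

 private:
  const char* name_;
};

// Mirrors the new ConstantFolding members: a raw pointer that refers either to
// an externally provided device or to a device this object creates and owns.
class Folder {
 public:
  explicit Folder(Device* cpu_device) : cpu_device_(cpu_device) {}

  void Run() {
    if (cpu_device_ == nullptr) {
      owned_device_.reset(new SimpleDevice("lazily created fallback"));
      cpu_device_ = owned_device_.get();
    }
    std::cout << "folding on: " << cpu_device_->name() << "\n";
  }

 private:
  Device* cpu_device_;                    // borrowed, or == owned_device_.get()
  std::unique_ptr<Device> owned_device_;  // set only when nothing was passed in
};

int main() {
  SimpleDevice session_cpu("externally provided CPU");
  Folder shared(&session_cpu);  // reuses the caller's device
  Folder standalone(nullptr);   // falls back to its own device
  shared.Run();
  standalone.Run();
  return 0;
}

This keeps ConstantFolding usable with no device at all (the tests below pass nullptr) while letting the session share its real CPU device and allocator.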
@@ -57,7 +57,7 @@ TEST_F(ConstantFoldingTest, SimpleFolding) {
   item.fetch.push_back("d");
   TF_CHECK_OK(s.ToGraphDef(&item.graph));
 
-  ConstantFolding fold;
+  ConstantFolding fold(nullptr /* cpu_device */);
   GraphDef output;
   Status status = fold.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);

@@ -103,7 +103,7 @@ TEST_F(ConstantFoldingTest, FoldingNodeWithTwoOutputs) {
   item.fetch.push_back("f");
   TF_CHECK_OK(s.ToGraphDef(&item.graph));
 
-  ConstantFolding fold;
+  ConstantFolding fold(nullptr /* cpu_device */);
   GraphDef output;
   Status status = fold.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);

@@ -152,7 +152,7 @@ TEST_F(ConstantFoldingTest, ControlDependencies) {
   item.fetch.push_back("e");
   TF_CHECK_OK(scope.ToGraphDef(&item.graph));
 
-  ConstantFolding fold;
+  ConstantFolding fold(nullptr /* cpu_device */);
   GraphDef output;
   Status status = fold.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);

@@ -195,7 +195,7 @@ TEST_F(ConstantFoldingTest, ControlDependenciesEmptyFetch) {
   GrapplerItem item;
   TF_CHECK_OK(scope.ToGraphDef(&item.graph));
 
-  ConstantFolding fold;
+  ConstantFolding fold(nullptr /* cpu_device */);
   GraphDef output;
   Status status = fold.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);

@@ -252,7 +252,7 @@ TEST_F(ConstantFoldingTest, ControlDependenciesDeduplicate) {
   item.fetch.push_back("i2");
   TF_CHECK_OK(scope.ToGraphDef(&item.graph));
 
-  ConstantFolding fold;
+  ConstantFolding fold(nullptr /* cpu_device */);
   GraphDef output;
   Status status = fold.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);

@@ -326,7 +326,7 @@ TEST_F(ConstantFoldingTest, VariableNumberOfOutputs) {
   }
 
   item.fetch = outputs;
-  ConstantFolding fold;
+  ConstantFolding fold(nullptr /* cpu_device */);
   GraphDef output;
   Status status = fold.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);

@@ -364,7 +364,7 @@ TEST_F(ConstantFoldingTest, ShapeMaterialization) {
   item.fetch.push_back("p2");
   TF_CHECK_OK(scope.ToGraphDef(&item.graph));
 
-  ConstantFolding fold;
+  ConstantFolding fold(nullptr /* cpu_device */);
   GraphDef output;
   Status status = fold.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);

@@ -399,7 +399,7 @@ TEST_F(ConstantFoldingTest, ShapeMaterializationEmptyFetch) {
   GrapplerItem item;
   TF_CHECK_OK(scope.ToGraphDef(&item.graph));
 
-  ConstantFolding fold;
+  ConstantFolding fold(nullptr /* cpu_device */);
   GraphDef output;
   Status status = fold.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);

@@ -461,7 +461,7 @@ TEST_F(ConstantFoldingTest, SwitchNodesEmptyFetch) {
   GrapplerItem item;
   TF_CHECK_OK(scope.ToGraphDef(&item.graph));
 
-  ConstantFolding fold;
+  ConstantFolding fold(nullptr /* cpu_device */);
   GraphDef output;
   Status status = fold.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);

@@ -537,7 +537,7 @@ TEST_F(ConstantFoldingTest, SwitchNodes) {
 
   TF_CHECK_OK(scope.ToGraphDef(&item.graph));
 
-  ConstantFolding fold;
+  ConstantFolding fold(nullptr /* cpu_device */);
   GraphDef output;
   Status status = fold.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);

@@ -605,7 +605,7 @@ TEST_F(ConstantFoldingTest, MergeNodes) {
   item.fetch = {"out1", "idx1", "out2", "idx2", "out3", "idx3"};
   TF_CHECK_OK(scope.ToGraphDef(&item.graph));
 
-  ConstantFolding fold;
+  ConstantFolding fold(nullptr /* cpu_device */);
   GraphDef output;
   Status status = fold.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);

@@ -679,7 +679,7 @@ TEST_F(ConstantFoldingTest, NoOpReduction) {
   item.fetch.push_back("s");
   TF_CHECK_OK(scope.ToGraphDef(&item.graph));
 
-  ConstantFolding fold;
+  ConstantFolding fold(nullptr /* cpu_device */);
   GraphDef output;
   Status status = fold.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);

@@ -738,7 +738,7 @@ TEST_F(ConstantFoldingTest, NoOpReshape) {
   item.fetch = {"s1", "s2", "s3", "s4"};
   TF_CHECK_OK(scope.ToGraphDef(&item.graph));
 
-  ConstantFolding fold;
+  ConstantFolding fold(nullptr /* cpu_device */);
   GraphDef output;
   Status status = fold.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);

@@ -785,7 +785,7 @@ TEST_F(ConstantFoldingTest, Packing) {
   GrapplerItem item;
   TF_CHECK_OK(scope.ToGraphDef(&item.graph));
 
-  ConstantFolding fold;
+  ConstantFolding fold(nullptr /* cpu_device */);
   GraphDef output;
   Status status = fold.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);
@@ -37,7 +37,7 @@ std::unique_ptr<GraphOptimizer> MetaOptimizer::NewOptimizer(
     graph_optimizer.reset(new ModelPruner());
   }
   if (optimizer == "constfold") {
-    graph_optimizer.reset(new ConstantFolding());
+    graph_optimizer.reset(new ConstantFolding(cpu_device_));
   }
   if (optimizer == "layout") {
     graph_optimizer.reset(new LayoutOptimizer());
@@ -64,7 +64,7 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
   }
   if (cfg_.constant_folding() != RewriterConfig::OFF) {
     optimizers.push_back(
-        std::unique_ptr<GraphOptimizer>(new ConstantFolding()));
+        std::unique_ptr<GraphOptimizer>(new ConstantFolding(cpu_device_)));
   }
   if (cfg_.arithmetic_optimization() != RewriterConfig::OFF) {
     optimizers.push_back(
@@ -144,8 +144,9 @@ bool MetaOptimizerEnabled(const RewriterConfig& cfg) {
 }
 
 Status RunMetaOptimizer(const GrapplerItem& item, const RewriterConfig& cfg,
-                        Cluster* cluster, GraphDef* optimized_graph) {
-  MetaOptimizer optimizer(cfg);
+                        DeviceBase* cpu_device, Cluster* cluster,
+                        GraphDef* optimized_graph) {
+  MetaOptimizer optimizer(cpu_device, cfg);
   return optimizer.Optimize(cluster, item, optimized_graph);
 }
@@ -16,6 +16,7 @@ limitations under the License.
 #ifndef TENSORFLOW_GRAPPLER_OPTIMIZERS_META_OPTIMIZER_H_
 #define TENSORFLOW_GRAPPLER_OPTIMIZERS_META_OPTIMIZER_H_
 
+#include "tensorflow/core/framework/device_base.h"
 #include "tensorflow/core/grappler/grappler_item.h"
 #include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
 #include "tensorflow/core/lib/core/status.h"
@@ -27,7 +28,8 @@ namespace grappler {
 // Run the other grappler optimizers based on the specified rewriter config.
 class MetaOptimizer : public GraphOptimizer {
  public:
-  MetaOptimizer(const RewriterConfig& cfg) : cfg_(cfg) {}
+  MetaOptimizer(DeviceBase* cpu_device, const RewriterConfig& cfg)
+      : cpu_device_(cpu_device), cfg_(cfg) {}
   ~MetaOptimizer() override {}
 
   string name() const override { return "meta_optimizer"; };
@@ -40,13 +42,21 @@ class MetaOptimizer : public GraphOptimizer {
 
  private:
   std::unique_ptr<GraphOptimizer> NewOptimizer(const string& optimizer);
+  DeviceBase* const cpu_device_;  // may be NULL
   RewriterConfig cfg_;
 };
 
 bool MetaOptimizerEnabled(const RewriterConfig& cfg);
 
+// Run the meta optimizer.
+//
+// If <cpu_device> is non-null, it is the device to be used for executing ops
+// during constant folding; if NULL, a new device is created for doing constant
+// folding. For performance, it is recommended to pass in an existing cpu_device
+// when possible.
 Status RunMetaOptimizer(const GrapplerItem& item, const RewriterConfig& cfg,
-                        Cluster* cluster, GraphDef* optimized_graph);
+                        DeviceBase* cpu_device, Cluster* cluster,
+                        GraphDef* optimized_graph);
 
 }  // namespace grappler
 }  // namespace tensorflow
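The new header comment above spells out the contract for <cpu_device>. For the nullptr path in particular, a hedged sketch of driving the meta optimizer without an external device (the function name and locals are assumptions; the constructor and Optimize signatures match the header above):

#include <unordered_map>

#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"

namespace tensorflow {
namespace grappler {

// Sketch: no device is supplied, so ConstantFolding builds and owns its
// DeviceSimple internally, as before this commit.
Status OptimizeWithoutADevice(const GrapplerItem& item,
                              const RewriterConfig& cfg,
                              GraphDef* optimized_graph) {
  std::unordered_map<string, DeviceProperties> device_map;
  VirtualCluster cluster(device_map);
  MetaOptimizer optimizer(/*cpu_device=*/nullptr, cfg);
  return optimizer.Optimize(&cluster, item, optimized_graph);
}

}  // namespace grappler
}  // namespace tensorflow

The Python wrapper in the next hunk takes exactly this route, passing a null cpu_device through RunMetaOptimizer.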
@@ -55,6 +55,7 @@ limitations under the License.
 #include <memory>
 #include "tensorflow/c/tf_status_helper.h"
 #include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/framework/device_base.h"
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/grappler/grappler_item.h"
 #include "tensorflow/core/grappler/grappler_item_builder.h"
@@ -73,10 +74,11 @@ PyObject* TF_OptimizeGraph(
   std::unique_ptr<tensorflow::grappler::GrapplerItem> grappler_item =
       tensorflow::grappler::GrapplerItemFromMetaGraphDef(graph_id, metagraph, item_config);
   std::unordered_map<string, tensorflow::DeviceProperties> device_map;
+  tensorflow::DeviceBase* cpu_device = nullptr;
   tensorflow::grappler::VirtualCluster cluster(device_map);
   tensorflow::GraphDef out_graph;
   tensorflow::Status status = tensorflow::grappler::RunMetaOptimizer(
-      *grappler_item, rewriter_config, &cluster, &out_graph);
+      *grappler_item, rewriter_config, cpu_device, &cluster, &out_graph);
   tensorflow::Set_TF_Status_from_Status(out_status, status);
   string out_graph_str = out_graph.SerializeAsString();
   PyObject* ret = PyBytes_FromStringAndSize(out_graph_str.data(),