Add a constant folding pass to grappler.

Change: 151668925
This commit is contained in:
Yao Zhang 2017-03-29 22:23:14 -08:00 committed by TensorFlower Gardener
parent 649ac98246
commit 3db6b68b2f
5 changed files with 571 additions and 1 deletions

View File

@ -24,6 +24,42 @@ filegroup(
visibility = ["//tensorflow:__subpackages__"],
)
cc_library(
name = "constant_folding",
srcs = ["constant_folding.cc"],
hdrs = [
"constant_folding.h",
],
visibility = ["//visibility:public"],
deps = [
":graph_optimizer",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/clusters:cluster",
],
)
cc_test(
name = "constant_folding_test",
srcs = ["constant_folding_test.cc"],
deps = [
":constant_folding",
"//tensorflow/cc:cc_ops",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
],
)
cc_library(
name = "graph_rewriter",
srcs = ["graph_rewriter.cc"],

View File

@ -0,0 +1,313 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/public/version.h"
namespace tensorflow {
namespace grappler {
using TensorVector = gtl::InlinedVector<TensorValue, 4>;
namespace {
class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface {
public:
explicit EigenThreadPoolWrapper(thread::ThreadPool* pool) : pool_(pool) {}
~EigenThreadPoolWrapper() override {}
void Schedule(std::function<void()> fn) override {
pool_->Schedule(std::move(fn));
}
int NumThreads() const override { return pool_->NumThreads(); }
int CurrentThreadId() const override { return pool_->CurrentThreadId(); }
private:
thread::ThreadPool* pool_ = nullptr;
};
class DeviceSimple : public DeviceBase {
public:
DeviceSimple() : DeviceBase(nullptr) {
eigen_worker_threads_.num_threads = 1;
eigen_worker_threads_.workers = new thread::ThreadPool(
Env::Default(), "constant_folding", eigen_worker_threads_.num_threads);
eigen_threadpool_wrapper_.reset(
new EigenThreadPoolWrapper(eigen_worker_threads_.workers));
eigen_device_.reset(new Eigen::ThreadPoolDevice(
eigen_threadpool_wrapper_.get(), eigen_worker_threads_.num_threads));
set_tensorflow_cpu_worker_threads(&eigen_worker_threads_);
set_eigen_cpu_device(eigen_device_.get());
}
~DeviceSimple() override {
eigen_threadpool_wrapper_.reset();
eigen_device_.reset();
delete eigen_worker_threads_.workers;
}
Status MakeTensorFromProto(const TensorProto& tensor_proto,
const AllocatorAttributes alloc_attrs,
Tensor* tensor) override {
Tensor parsed(tensor_proto.dtype());
if (!parsed.FromProto(cpu_allocator(), tensor_proto)) {
return errors::InvalidArgument("Cannot parse tensor from proto: ",
tensor_proto.DebugString());
}
*tensor = parsed;
return Status::OK();
}
Allocator* GetAllocator(AllocatorAttributes attr) override {
return cpu_allocator();
}
private:
DeviceBase::CpuWorkerThreads eigen_worker_threads_;
std::unique_ptr<Eigen::ThreadPoolInterface> eigen_threadpool_wrapper_;
std::unique_ptr<Eigen::ThreadPoolDevice> eigen_device_;
};
Status NumOutputs(const NodeDef& node, int* num_outputs) {
const OpDef* op_def = nullptr;
TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node.op(), &op_def));
if (node.op() == "ConcatOffset") {
*num_outputs = node.attr().at("N").i();
} else {
*num_outputs = op_def->output_arg_size();
}
return Status::OK();
}
} // namespace
bool ConstantFolding::IsConst(const NodeDef& node) const {
return node.op() == "Const";
}
bool ConstantFolding::IsFoldable(const NodeDef& node) const {
DeviceTypeVector device_types;
auto status = SupportedDeviceTypesForNode({DeviceType(DEVICE_CPU)}, node,
&device_types);
if (!status.ok()) {
return false;
}
// Only fold ops with a CPU implementation available.
if (device_types[0] != DeviceType(DEVICE_CPU)) {
return false;
}
if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) {
return false;
}
if (ops_to_preserve_.find(node.op()) != ops_to_preserve_.end()) {
return false;
}
// Don't fold stateful ops such as TruncatedNormal.
const OpDef* op_def = nullptr;
status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
if (!status.ok()) {
return false;
}
if (op_def->is_stateful()) {
return false;
}
if (op_def->output_arg_size() == 0) {
return false;
}
// Folding not applicable to ops with no inputs.
if (node.input().empty()) {
return false;
}
for (const auto& input : node.input()) {
bool is_const = IsConst(*node_map_->GetNode(input));
if (!is_const) {
return false;
}
}
return true;
}
NodeDef ConstantFolding::CreateNodeDef(const string& name,
const TensorValue& tensor) {
NodeDef node;
node.set_name(name);
node.set_op("Const");
AttrValue attr_output_shape;
auto output_shape = attr_output_shape.mutable_list()->add_shape();
TensorShapeProto shape;
tensor->shape().AsProto(&shape);
*output_shape = shape;
node.mutable_attr()->insert({"_output_shapes", attr_output_shape});
AttrValue attr_type;
attr_type.set_type(tensor->dtype());
node.mutable_attr()->insert({"dtype", attr_type});
AttrValue attr_tensor;
tensor->AsProtoTensorContent(attr_tensor.mutable_tensor());
node.mutable_attr()->insert({"value", attr_tensor});
return node;
}
Status ConstantFolding::EvaluateNode(const NodeDef& node,
const TensorVector& inputs,
TensorVector* output) {
Status status;
auto op_kernel =
CreateOpKernel("CPU", device_.get(), device_->GetAllocator({}), node,
TF_GRAPH_DEF_VERSION, &status);
TF_RETURN_IF_ERROR(status);
OpKernelContext::Params params;
params.device = device_.get();
params.frame_iter = FrameAndIter(0, 0);
params.inputs = &inputs;
params.op_kernel = op_kernel.get();
int num_outputs;
TF_RETURN_IF_ERROR(NumOutputs(node, &num_outputs));
gtl::InlinedVector<AllocatorAttributes, 4> output_attrs;
for (int i = 0; i < num_outputs; i++) {
AllocatorAttributes attr;
attr.set_on_host(true);
output_attrs.push_back(attr);
}
params.output_attr_array = output_attrs.data();
OpKernelContext op_context(&params);
op_kernel->Compute(&op_context);
for (int i = 0; i < num_outputs; i++) {
output->push_back(op_context.release_output(i));
}
return Status::OK();
}
Status ConstantFolding::EvaluateOneFoldable(const NodeDef& node,
std::vector<NodeDef>* outputs) {
TensorVector inputs;
for (const auto& input : node.input()) {
TensorVector output;
TF_RETURN_IF_ERROR(
EvaluateNode(*node_map_->GetNode(input), TensorVector(), &output));
inputs.push_back(output[0]);
}
TensorVector output_tensors;
TF_RETURN_IF_ERROR(EvaluateNode(node, inputs, &output_tensors));
for (const auto& input : inputs) {
delete input.tensor;
}
if (output_tensors.empty()) {
Status(error::INVALID_ARGUMENT, "Expected at least one output.");
}
for (int i = 0; i < output_tensors.size(); i++) {
string node_name = strings::StrCat(
AddPrefixToNodeName(node.name(), kConstantFoldingConst));
if (output_tensors.size() > 1) {
node_name = strings::StrCat(node_name, "-", i);
}
outputs->push_back(CreateNodeDef(node_name, output_tensors[i]));
delete output_tensors[i].tensor;
}
return Status::OK();
}
Status ConstantFolding::FoldNode(const NodeDef& node, GraphDef* output) {
std::vector<NodeDef> const_nodes;
TF_RETURN_IF_ERROR(EvaluateOneFoldable(node, &const_nodes));
auto outputs = node_map_->GetOutputs(node.name());
for (const auto& const_node : const_nodes) {
NodeDef* added_node = output->add_node();
*added_node = const_node;
node_map_->AddNode(added_node->name(), added_node);
}
for (const auto& output : outputs) {
for (int i = 0; i < output->input_size(); i++) {
int position;
string node_name = ParseNodeName(output->input(i), &position);
if (node_name == node.name()) {
if (position < 0) {
*output->mutable_input(i) =
strings::StrCat("^", const_nodes[0].name());
} else {
*output->mutable_input(i) = const_nodes[position].name();
}
}
}
}
return Status::OK();
}
Status ConstantFolding::FoldGraph(GraphDef* output) {
std::set<string> processed_nodes;
while (1) {
int previous_processed = processed_nodes.size();
for (const auto& node : graph_.node()) {
if (IsFoldable(node) &&
processed_nodes.find(node.name()) == processed_nodes.end()) {
TF_RETURN_IF_ERROR(FoldNode(node, output));
processed_nodes.insert(node.name());
}
}
int current_processed = processed_nodes.size();
LOG(INFO) << "Previous number of processed nodes: " << previous_processed
<< "; Current number of processed nodes: " << current_processed;
if (current_processed == previous_processed) {
break;
}
}
// Build the graph after constant folding. Note that we keep all processed
// nodes in the graph in case users need to fetch their values.
for (const auto& node : graph_.node()) {
auto added_node = output->add_node();
*added_node = node;
}
return Status::OK();
}
Status ConstantFolding::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* output) {
graph_ = item.graph;
LOG(INFO) << "Initial graph size: " << item.graph.node_size();
node_map_.reset(new NodeMap(&graph_));
for (const auto& node : item.fetch) {
nodes_to_preserve_.insert(NodeName(node));
}
device_.reset(new DeviceSimple());
TF_RETURN_IF_ERROR(FoldGraph(output));
LOG(INFO) << "Optimized graph size: " << output->node_size();
return Status::OK();
}
void ConstantFolding::Feedback(Cluster* cluster, const GrapplerItem& item,
const GraphDef& optimize_output, double result) {
// Nothing to do for ConstantFolding.
}
} // end namespace grappler
} // end namespace tensorflow

View File

@ -0,0 +1,73 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_GRAPPLER_OPTIMIZERS_CONSTANT_FOLDING_H_
#define TENSORFLOW_GRAPPLER_OPTIMIZERS_CONSTANT_FOLDING_H_
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
#include "tensorflow/core/grappler/utils.h"
namespace tensorflow {
namespace grappler {
const char kConstantFoldingConst[] = "ConstantFolding";
// Contant folding optimization for a graph.
class ConstantFolding : public GraphOptimizer {
public:
ConstantFolding() {}
~ConstantFolding() override {}
string name() const override { return "constant folding"; };
Status Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* output) override;
void Feedback(Cluster* cluster, const GrapplerItem& item,
const GraphDef& optimize_output, double result) override;
private:
bool IsConst(const NodeDef& node) const;
bool IsFoldable(const NodeDef& node) const;
NodeDef CreateNodeDef(const string& name, const TensorValue& tensor);
Status EvaluateNode(const NodeDef& node,
const gtl::InlinedVector<TensorValue, 4>& inputs,
gtl::InlinedVector<TensorValue, 4>* output);
Status EvaluateOneFoldable(const NodeDef& node,
std::vector<NodeDef>* outputs);
Status FoldNode(const NodeDef& node, GraphDef* output);
Status FoldGraph(GraphDef* output);
std::unique_ptr<DeviceBase> device_;
GraphDef graph_;
std::unique_ptr<NodeMap> node_map_;
std::set<string> nodes_to_preserve_;
std::set<string> ops_to_preserve_ = {"Save", "SaveV2", "SaveSlices",
"Restore", "RestoreV2", "RestoreSlice"};
};
} // end namespace grappler
} // end namespace tensorflow
#endif // TENSORFLOW_GRAPPLER_OPTIMIZERS_CONSTANT_FOLDING_H_

View File

@ -0,0 +1,148 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session.h"
namespace tensorflow {
namespace grappler {
namespace {
class ConstantFoldingTest : public ::testing::Test {
protected:
std::vector<Tensor> EvaluateNodes(const GraphDef& graph,
const std::vector<string>& fetch) {
SessionOptions options;
std::unique_ptr<tensorflow::Session> session(NewSession(options));
TF_CHECK_OK(session->Create(graph));
RunOptions run_options;
std::vector<Tensor> output_tensors;
TF_CHECK_OK(
session->Run(run_options, {}, fetch, fetch, &output_tensors, nullptr));
TF_CHECK_OK(session->Close());
return output_tensors;
}
};
TEST_F(ConstantFoldingTest, SimpleFolding) {
// Build a simple graph with a few trivially prunable ops.
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output a = ops::Const(s.WithOpName("a"), 1.0f, {1});
Output b = ops::Const(s.WithOpName("b"), 2.0f, {1});
Output c = ops::AddN(s.WithOpName("c"), {a, b});
Output d = ops::AddN(s.WithOpName("d"), {b, c});
GrapplerItem item;
item.fetch.push_back("d");
TF_CHECK_OK(s.ToGraphDef(&item.graph));
ConstantFolding fold;
GraphDef output;
Status status = fold.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
EXPECT_EQ(5, output.node_size());
const NodeDef& new_c = output.node(0);
EXPECT_EQ("ConstantFolding-c", new_c.name());
EXPECT_EQ("Const", new_c.op());
const NodeDef& new_a = output.node(1);
EXPECT_EQ("a", new_a.name());
const NodeDef& new_b = output.node(2);
EXPECT_EQ("b", new_b.name());
const NodeDef& old_c = output.node(3);
EXPECT_EQ("c", old_c.name());
const NodeDef& new_d = output.node(4);
EXPECT_EQ("d", new_d.name());
EXPECT_EQ("ConstantFolding-c", new_d.input(1));
std::vector<string> fetch = {"a", "b", "c", "d"};
auto tensors_expected = EvaluateNodes(item.graph, fetch);
auto tensors = EvaluateNodes(output, fetch);
EXPECT_EQ(4, tensors_expected.size());
EXPECT_EQ(4, tensors.size());
for (int i = 0; i < 4; i++) {
test::ExpectTensorEqual<float>(tensors_expected[i], tensors[i]);
}
}
TEST_F(ConstantFoldingTest, FoldingNodeWithTwoOutputs) {
// Build a simple graph with a few trivially prunable ops.
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output a = ops::Const(s.WithOpName("a"), 10, {3});
auto b = ops::Unique(s.WithOpName("b"), {a});
Output c = ops::Identity(s.WithOpName("c"), {b.y});
Output d = ops::Identity(s.WithOpName("d"), {b.idx});
GrapplerItem item;
item.fetch.push_back("c");
item.fetch.push_back("d");
TF_CHECK_OK(s.ToGraphDef(&item.graph));
ConstantFolding fold;
GraphDef output;
Status status = fold.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
EXPECT_EQ(6, output.node_size());
const NodeDef& new_b_0 = output.node(0);
EXPECT_EQ("ConstantFolding-b-0", new_b_0.name());
EXPECT_EQ("Const", new_b_0.op());
const NodeDef& new_b_1 = output.node(1);
EXPECT_EQ("ConstantFolding-b-1", new_b_1.name());
EXPECT_EQ("Const", new_b_1.op());
const NodeDef& new_a = output.node(2);
EXPECT_EQ("a", new_a.name());
const NodeDef& new_b = output.node(3);
EXPECT_EQ("b", new_b.name());
const NodeDef& new_c = output.node(4);
EXPECT_EQ("c", new_c.name());
EXPECT_EQ("ConstantFolding-b-0", new_c.input(0));
const NodeDef& new_d = output.node(5);
EXPECT_EQ("d", new_d.name());
EXPECT_EQ("ConstantFolding-b-1", new_d.input(0));
std::vector<string> fetch = {"a", "b", "c", "d"};
auto tensors_expected = EvaluateNodes(item.graph, fetch);
auto tensors = EvaluateNodes(output, fetch);
EXPECT_EQ(4, tensors_expected.size());
EXPECT_EQ(4, tensors.size());
for (int i = 0; i < 4; i++) {
test::ExpectTensorEqual<int>(tensors_expected[i], tensors[i]);
}
}
} // namespace
} // namespace grappler
} // namespace tensorflow

View File

@ -29,7 +29,7 @@ NodeMap::NodeMap(GraphDef* graph) : graph_(graph) {
auto node = graph_->mutable_node(i);
nodes_.insert(std::make_pair(node->name(), node));
for (const auto& input : node->input()) {
outputs_[input].insert(nodes_[node->name()]);
outputs_[NodeName(input)].insert(nodes_[node->name()]);
}
}
}