From 3db6b68b2f73caeaa71317926a0ea4d5f13688f8 Mon Sep 17 00:00:00 2001
From: Yao Zhang
Date: Wed, 29 Mar 2017 22:23:14 -0800
Subject: [PATCH] Add a constant folding pass to grappler. Change: 151668925

---
 tensorflow/core/grappler/optimizers/BUILD     |  36 ++
 .../grappler/optimizers/constant_folding.cc   | 356 ++++++++++++++++++
 .../grappler/optimizers/constant_folding.h    |  73 ++++
 .../optimizers/constant_folding_test.cc       | 148 +++++++++
 tensorflow/core/grappler/utils.cc             |   2 +-
 5 files changed, 614 insertions(+), 1 deletion(-)
 create mode 100644 tensorflow/core/grappler/optimizers/constant_folding.cc
 create mode 100644 tensorflow/core/grappler/optimizers/constant_folding.h
 create mode 100644 tensorflow/core/grappler/optimizers/constant_folding_test.cc

diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index 02716b3f781..d09a3c4e304 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -24,6 +24,42 @@ filegroup(
     visibility = ["//tensorflow:__subpackages__"],
 )
 
+cc_library(
+    name = "constant_folding",
+    srcs = ["constant_folding.cc"],
+    hdrs = [
+        "constant_folding.h",
+    ],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":graph_optimizer",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/grappler:grappler_item",
+        "//tensorflow/core/grappler:utils",
+        "//tensorflow/core/grappler/clusters:cluster",
+    ],
+)
+
+cc_test(
+    name = "constant_folding_test",
+    srcs = ["constant_folding_test.cc"],
+    deps = [
+        ":constant_folding",
+        "//tensorflow/cc:cc_ops",
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:core_cpu_internal",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+        "//tensorflow/core/grappler:grappler_item",
+        "//tensorflow/core/grappler:utils",
+        "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
+    ],
+)
+
 cc_library(
     name = "graph_rewriter",
     srcs = ["graph_rewriter.cc"],
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
new file mode 100644
index 00000000000..49891e2a780
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -0,0 +1,356 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/grappler/optimizers/constant_folding.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/grappler/clusters/cluster.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/lib/strings/numbers.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/public/version.h"
+
+namespace tensorflow {
+namespace grappler {
+using TensorVector = gtl::InlinedVector<TensorValue, 4>;
+
+namespace {
+// Adapts a thread::ThreadPool to the Eigen::ThreadPoolInterface API.
+class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface {
+ public:
+  explicit EigenThreadPoolWrapper(thread::ThreadPool* pool) : pool_(pool) {}
+  ~EigenThreadPoolWrapper() override {}
+  void Schedule(std::function<void()> fn) override {
+    pool_->Schedule(std::move(fn));
+  }
+  int NumThreads() const override { return pool_->NumThreads(); }
+  int CurrentThreadId() const override { return pool_->CurrentThreadId(); }
+
+ private:
+  thread::ThreadPool* pool_ = nullptr;
+};
+
+// Minimal CPU device used to run kernels while folding. It owns a
+// single-threaded pool and hands out cpu_allocator() for all allocations.
+class DeviceSimple : public DeviceBase {
+ public:
+  DeviceSimple() : DeviceBase(nullptr) {
+    eigen_worker_threads_.num_threads = 1;
+    eigen_worker_threads_.workers = new thread::ThreadPool(
+        Env::Default(), "constant_folding", eigen_worker_threads_.num_threads);
+    eigen_threadpool_wrapper_.reset(
+        new EigenThreadPoolWrapper(eigen_worker_threads_.workers));
+    eigen_device_.reset(new Eigen::ThreadPoolDevice(
+        eigen_threadpool_wrapper_.get(), eigen_worker_threads_.num_threads));
+    set_tensorflow_cpu_worker_threads(&eigen_worker_threads_);
+    set_eigen_cpu_device(eigen_device_.get());
+  }
+  ~DeviceSimple() override {
+    eigen_threadpool_wrapper_.reset();
+    eigen_device_.reset();
+    delete eigen_worker_threads_.workers;
+  }
+  Status MakeTensorFromProto(const TensorProto& tensor_proto,
+                             const AllocatorAttributes alloc_attrs,
+                             Tensor* tensor) override {
+    Tensor parsed(tensor_proto.dtype());
+    if (!parsed.FromProto(cpu_allocator(), tensor_proto)) {
+      return errors::InvalidArgument("Cannot parse tensor from proto: ",
+                                     tensor_proto.DebugString());
+    }
+    *tensor = parsed;
+    return Status::OK();
+  }
+  Allocator* GetAllocator(AllocatorAttributes attr) override {
+    return cpu_allocator();
+  }
+
+ private:
+  DeviceBase::CpuWorkerThreads eigen_worker_threads_;
+  std::unique_ptr<Eigen::ThreadPoolInterface> eigen_threadpool_wrapper_;
+  std::unique_ptr<Eigen::ThreadPoolDevice> eigen_device_;
+};
+
+// Stores in `*num_outputs` the number of outputs to allocate for `node`: the
+// "N" attribute for ConcatOffset, the OpDef output arg count otherwise.
+Status NumOutputs(const NodeDef& node, int* num_outputs) {
+  const OpDef* op_def = nullptr;
+  TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node.op(), &op_def));
+  if (node.op() == "ConcatOffset") {
+    *num_outputs = node.attr().at("N").i();
+  } else {
+    *num_outputs = op_def->output_arg_size();
+  }
+  return Status::OK();
+}
+}  // namespace
+
+// Returns true iff `node` is a Const op.
+bool ConstantFolding::IsConst(const NodeDef& node) const {
+  return node.op() == "Const";
+}
+
+// Returns true if `node` may be replaced by constants: it has a CPU kernel, is
+// not preserved or stateful, has outputs, and all of its inputs are Const.
+bool ConstantFolding::IsFoldable(const NodeDef& node) const {
+  DeviceTypeVector device_types;
+  auto status = SupportedDeviceTypesForNode({DeviceType(DEVICE_CPU)}, node,
+                                            &device_types);
+  if (!status.ok()) {
+    return false;
+  }
+  // Only fold ops with a CPU implementation available.
+  if (device_types.empty() || device_types[0] != DeviceType(DEVICE_CPU)) {
+    return false;
+  }
+
+  if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) {
+    return false;
+  }
+
+  if (ops_to_preserve_.find(node.op()) != ops_to_preserve_.end()) {
+    return false;
+  }
+
+  // Don't fold stateful ops such as TruncatedNormal.
+  const OpDef* op_def = nullptr;
+  status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
+  if (!status.ok()) {
+    return false;
+  }
+
+  if (op_def->is_stateful()) {
+    return false;
+  }
+
+  if (op_def->output_arg_size() == 0) {
+    return false;
+  }
+
+  // Folding not applicable to ops with no inputs.
+  if (node.input().empty()) {
+    return false;
+  }
+
+  for (const auto& input : node.input()) {
+    // A control input ("^name") carries no tensor to feed to the kernel.
+    if (!input.empty() && input[0] == '^') {
+      return false;
+    }
+    // Also reject inputs that the node map cannot resolve to a node.
+    const NodeDef* input_node = node_map_->GetNode(input);
+    if (input_node == nullptr || !IsConst(*input_node)) {
+      return false;
+    }
+  }
+  return true;
+}
+
+// Builds a Const NodeDef called `name` holding the value of `tensor`.
+NodeDef ConstantFolding::CreateNodeDef(const string& name,
+                                       const TensorValue& tensor) {
+  NodeDef node;
+  node.set_name(name);
+  node.set_op("Const");
+  AttrValue attr_output_shape;
+  auto output_shape = attr_output_shape.mutable_list()->add_shape();
+  TensorShapeProto shape;
+  tensor->shape().AsProto(&shape);
+  *output_shape = shape;
+  node.mutable_attr()->insert({"_output_shapes", attr_output_shape});
+
+  AttrValue attr_type;
+  attr_type.set_type(tensor->dtype());
+  node.mutable_attr()->insert({"dtype", attr_type});
+
+  AttrValue attr_tensor;
+  tensor->AsProtoTensorContent(attr_tensor.mutable_tensor());
+  node.mutable_attr()->insert({"value", attr_tensor});
+  return node;
+}
+
+// Runs the CPU kernel of `node` on `inputs` and appends its outputs to
+// `output`. The caller takes ownership of the returned tensors.
+Status ConstantFolding::EvaluateNode(const NodeDef& node,
+                                     const TensorVector& inputs,
+                                     TensorVector* output) {
+  Status status;
+  auto op_kernel =
+      CreateOpKernel("CPU", device_.get(), device_->GetAllocator({}), node,
+                     TF_GRAPH_DEF_VERSION, &status);
+  TF_RETURN_IF_ERROR(status);
+  OpKernelContext::Params params;
+  params.device = device_.get();
+  params.frame_iter = FrameAndIter(0, 0);
+  params.inputs = &inputs;
+  params.op_kernel = op_kernel.get();
+  int num_outputs;
+  TF_RETURN_IF_ERROR(NumOutputs(node, &num_outputs));
+  gtl::InlinedVector<AllocatorAttributes, 4> output_attrs;
+  for (int i = 0; i < num_outputs; i++) {
+    AllocatorAttributes attr;
+    attr.set_on_host(true);
+    output_attrs.push_back(attr);
+  }
+  params.output_attr_array = output_attrs.data();
+  OpKernelContext op_context(&params);
+  op_kernel->Compute(&op_context);
+  // Compute() reports failures through the context, not a return value.
+  TF_RETURN_IF_ERROR(op_context.status());
+  for (int i = 0; i < num_outputs; i++) {
+    output->push_back(op_context.release_output(i));
+  }
+  return Status::OK();
+}
+
+// Evaluates `node`, whose inputs must all be Const nodes, and appends one
+// Const NodeDef per output tensor to `outputs`.
+Status ConstantFolding::EvaluateOneFoldable(const NodeDef& node,
+                                            std::vector<NodeDef>* outputs) {
+  TensorVector inputs;
+  // Tensors released by EvaluateNode() are owned here, so they have to be
+  // freed on every exit path, including the error ones.
+  auto delete_inputs = [&inputs]() {
+    for (const auto& input : inputs) {
+      delete input.tensor;
+    }
+  };
+  for (const auto& input : node.input()) {
+    TensorVector output;
+    const Status input_status =
+        EvaluateNode(*node_map_->GetNode(input), TensorVector(), &output);
+    if (!input_status.ok()) {
+      delete_inputs();
+      return input_status;
+    }
+    inputs.push_back(output[0]);
+  }
+
+  TensorVector output_tensors;
+  const Status status = EvaluateNode(node, inputs, &output_tensors);
+  delete_inputs();
+  TF_RETURN_IF_ERROR(status);
+  if (output_tensors.empty()) {
+    return Status(error::INVALID_ARGUMENT, "Expected at least one output.");
+  }
+  for (int i = 0; i < output_tensors.size(); i++) {
+    string node_name = strings::StrCat(
+        AddPrefixToNodeName(node.name(), kConstantFoldingConst));
+    if (output_tensors.size() > 1) {
+      node_name = strings::StrCat(node_name, "-", i);
+    }
+    outputs->push_back(CreateNodeDef(node_name, output_tensors[i]));
+    delete output_tensors[i].tensor;
+  }
+  return Status::OK();
+}
+
+// Evaluates `node`, adds the resulting Const nodes to `output` and rewires the
+// consumers of `node` to read from those constants instead.
+Status ConstantFolding::FoldNode(const NodeDef& node, GraphDef* output) {
+  std::vector<NodeDef> const_nodes;
+  TF_RETURN_IF_ERROR(EvaluateOneFoldable(node, &const_nodes));
+
+  auto outputs = node_map_->GetOutputs(node.name());
+  for (const auto& const_node : const_nodes) {
+    NodeDef* added_node = output->add_node();
+    *added_node = const_node;
+    node_map_->AddNode(added_node->name(), added_node);
+  }
+  for (const auto& consumer : outputs) {
+    for (int i = 0; i < consumer->input_size(); i++) {
+      int position;
+      string node_name = ParseNodeName(consumer->input(i), &position);
+      if (node_name == node.name()) {
+        if (position < 0) {
+          *consumer->mutable_input(i) =
+              strings::StrCat("^", const_nodes[0].name());
+        } else if (position < static_cast<int>(const_nodes.size())) {
+          *consumer->mutable_input(i) = const_nodes[position].name();
+        } else {
+          // Reading past the evaluated outputs would be undefined behavior.
+          return errors::InvalidArgument(
+              "Input ", consumer->input(i), " of node ", consumer->name(),
+              " refers to an output of ", node.name(), " that was not folded.");
+        }
+      }
+    }
+  }
+  return Status::OK();
+}
+
+// Folds foldable nodes until a fixed point is reached, then copies the
+// (rewired) original nodes into `output` after the new constants.
+Status ConstantFolding::FoldGraph(GraphDef* output) {
+  std::set<string> processed_nodes;
+  while (true) {
+    int previous_processed = processed_nodes.size();
+    for (const auto& node : graph_.node()) {
+      if (IsFoldable(node) &&
+          processed_nodes.find(node.name()) == processed_nodes.end()) {
+        TF_RETURN_IF_ERROR(FoldNode(node, output));
+        processed_nodes.insert(node.name());
+      }
+    }
+    int current_processed = processed_nodes.size();
+    LOG(INFO) << "Previous number of processed nodes: " << previous_processed
+              << "; Current number of processed nodes: " << current_processed;
+    if (current_processed == previous_processed) {
+      break;
+    }
+  }
+
+  // Build the graph after constant folding. Note that we keep all processed
+  // nodes in the graph in case users need to fetch their values.
+  for (const auto& node : graph_.node()) {
+    auto added_node = output->add_node();
+    *added_node = node;
+  }
+  return Status::OK();
+}
+
+// Entry point: folds the constant subgraphs of `item.graph` into `output`.
+// Fetch nodes are never folded. `cluster` is unused.
+Status ConstantFolding::Optimize(Cluster* cluster, const GrapplerItem& item,
+                                 GraphDef* output) {
+  graph_ = item.graph;
+  LOG(INFO) << "Initial graph size: " << item.graph.node_size();
+  node_map_.reset(new NodeMap(&graph_));
+  // Forget the fetch nodes of any item previously optimized by this object.
+  nodes_to_preserve_.clear();
+  for (const auto& node : item.fetch) {
+    nodes_to_preserve_.insert(NodeName(node));
+  }
+  device_.reset(new DeviceSimple());
+  TF_RETURN_IF_ERROR(FoldGraph(output));
+  LOG(INFO) << "Optimized graph size: " << output->node_size();
+  return Status::OK();
+}
+
+void ConstantFolding::Feedback(Cluster* cluster, const GrapplerItem& item,
+                               const GraphDef& optimize_output, double result) {
+  // Nothing to do for ConstantFolding.
+}
+
+}  // end namespace grappler
+}  // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h
new file mode 100644
index 00000000000..201e36b853c
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/constant_folding.h
@@ -0,0 +1,73 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_GRAPPLER_OPTIMIZERS_CONSTANT_FOLDING_H_
+#define TENSORFLOW_GRAPPLER_OPTIMIZERS_CONSTANT_FOLDING_H_
+
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
+#include "tensorflow/core/grappler/utils.h"
+
+namespace tensorflow {
+namespace grappler {
+
+const char kConstantFoldingConst[] = "ConstantFolding";
+
+// Constant folding optimization for a graph.
+class ConstantFolding : public GraphOptimizer {
+ public:
+  ConstantFolding() {}
+
+  ~ConstantFolding() override {}
+
+  string name() const override { return "constant folding"; }
+  // Writes the constant-folded version of item.graph to `output`.
+  Status Optimize(Cluster* cluster, const GrapplerItem& item,
+                  GraphDef* output) override;
+  // No-op for this optimizer.
+  void Feedback(Cluster* cluster, const GrapplerItem& item,
+                const GraphDef& optimize_output, double result) override;
+
+ private:
+  bool IsConst(const NodeDef& node) const;
+  // True if `node` can be evaluated now and replaced by Const nodes.
+  bool IsFoldable(const NodeDef& node) const;
+  // Builds a Const node named `name` holding the value of `tensor`.
+  NodeDef CreateNodeDef(const string& name, const TensorValue& tensor);
+  // Runs the CPU kernel of `node`; the caller owns the tensors in `output`.
+  Status EvaluateNode(const NodeDef& node,
+                      const gtl::InlinedVector<TensorValue, 4>& inputs,
+                      gtl::InlinedVector<TensorValue, 4>* output);
+  // Evaluates `node` and appends one Const NodeDef per output to `outputs`.
+  Status EvaluateOneFoldable(const NodeDef& node,
+                             std::vector<NodeDef>* outputs);
+  // Adds the folded constants to `output` and rewires the consumers of `node`.
+  Status FoldNode(const NodeDef& node, GraphDef* output);
+  // Folds nodes until no more can be folded, then emits the final graph.
+  Status FoldGraph(GraphDef* output);
+
+  std::unique_ptr<DeviceBase> device_;
+  GraphDef graph_;
+  std::unique_ptr<NodeMap> node_map_;
+  std::set<string> nodes_to_preserve_;
+  std::set<string> ops_to_preserve_ = {"Save",    "SaveV2",    "SaveSlices",
+                                       "Restore", "RestoreV2", "RestoreSlice"};
+};
+
+}  // end namespace grappler
+}  // end namespace tensorflow
+
+#endif  // TENSORFLOW_GRAPPLER_OPTIMIZERS_CONSTANT_FOLDING_H_
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
new file mode 100644
index 00000000000..ab79e741031
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -0,0 +1,148 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/constant_folding.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/public/session.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+class ConstantFoldingTest : public ::testing::Test {
+ protected:
+  std::vector<Tensor> EvaluateNodes(const GraphDef& graph,
+                                    const std::vector<string>& fetch) {
+    SessionOptions options;
+    std::unique_ptr<tensorflow::Session> session(NewSession(options));
+    TF_CHECK_OK(session->Create(graph));
+    RunOptions run_options;
+    std::vector<Tensor> output_tensors;
+    TF_CHECK_OK(
+        session->Run(run_options, {}, fetch, fetch, &output_tensors, nullptr));
+    TF_CHECK_OK(session->Close());
+    return output_tensors;
+  }
+};
+
+TEST_F(ConstantFoldingTest, SimpleFolding) {
+  // Build a simple graph with a few trivially foldable ops.
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+  Output a = ops::Const(s.WithOpName("a"), 1.0f, {1});
+  Output b = ops::Const(s.WithOpName("b"), 2.0f, {1});
+  Output c = ops::AddN(s.WithOpName("c"), {a, b});
+  Output d = ops::AddN(s.WithOpName("d"), {b, c});
+
+  GrapplerItem item;
+  item.fetch.push_back("d");
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+  ConstantFolding fold;
+  GraphDef output;
+  Status status = fold.Optimize(nullptr, item, &output);
+  TF_EXPECT_OK(status);
+
+  EXPECT_EQ(5, output.node_size());
+
+  const NodeDef& new_c = output.node(0);
+  EXPECT_EQ("ConstantFolding-c", new_c.name());
+  EXPECT_EQ("Const", new_c.op());
+
+  const NodeDef& new_a = output.node(1);
+  EXPECT_EQ("a", new_a.name());
+
+  const NodeDef& new_b = output.node(2);
+  EXPECT_EQ("b", new_b.name());
+
+  const NodeDef& old_c = output.node(3);
+  EXPECT_EQ("c", old_c.name());
+
+  const NodeDef& new_d = output.node(4);
+  EXPECT_EQ("d", new_d.name());
+  EXPECT_EQ("ConstantFolding-c", new_d.input(1));
+
+  std::vector<string> fetch = {"a", "b", "c", "d"};
+  auto tensors_expected = EvaluateNodes(item.graph, fetch);
+  auto tensors = EvaluateNodes(output, fetch);
+  EXPECT_EQ(4, tensors_expected.size());
+  EXPECT_EQ(4, tensors.size());
+  for (int i = 0; i < 4; i++) {
+    test::ExpectTensorEqual<float>(tensors_expected[i], tensors[i]);
+  }
+}
+
+TEST_F(ConstantFoldingTest, FoldingNodeWithTwoOutputs) {
+  // Build a simple graph whose foldable node (Unique) has two outputs.
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+  Output a = ops::Const(s.WithOpName("a"), 10, {3});
+  auto b = ops::Unique(s.WithOpName("b"), {a});
+  Output c = ops::Identity(s.WithOpName("c"), {b.y});
+  Output d = ops::Identity(s.WithOpName("d"), {b.idx});
+
+  GrapplerItem item;
+  item.fetch.push_back("c");
+  item.fetch.push_back("d");
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+  ConstantFolding fold;
+  GraphDef output;
+  Status status = fold.Optimize(nullptr, item, &output);
+  TF_EXPECT_OK(status);
+
+  EXPECT_EQ(6, output.node_size());
+
+  const NodeDef& new_b_0 = output.node(0);
+  EXPECT_EQ("ConstantFolding-b-0", new_b_0.name());
+  EXPECT_EQ("Const", new_b_0.op());
+
+  const NodeDef& new_b_1 = output.node(1);
+  EXPECT_EQ("ConstantFolding-b-1", new_b_1.name());
+  EXPECT_EQ("Const", new_b_1.op());
+
+  const NodeDef& new_a = output.node(2);
+  EXPECT_EQ("a", new_a.name());
+
+  const NodeDef& new_b = output.node(3);
+  EXPECT_EQ("b", new_b.name());
+
+  const NodeDef& new_c = output.node(4);
+  EXPECT_EQ("c", new_c.name());
+  EXPECT_EQ("ConstantFolding-b-0", new_c.input(0));
+
+  const NodeDef& new_d = output.node(5);
+  EXPECT_EQ("d", new_d.name());
+  EXPECT_EQ("ConstantFolding-b-1", new_d.input(0));
+
+  std::vector<string> fetch = {"a", "b", "c", "d"};
+  auto tensors_expected = EvaluateNodes(item.graph, fetch);
+  auto tensors = EvaluateNodes(output, fetch);
+  EXPECT_EQ(4, tensors_expected.size());
+  EXPECT_EQ(4, tensors.size());
+  for (int i = 0; i < 4; i++) {
+    test::ExpectTensorEqual<int>(tensors_expected[i], tensors[i]);
+  }
+}
+
+}  // namespace
+}  // namespace grappler
+}  // namespace tensorflow
diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc
index b52a5671651..a56961cd954 100644
--- a/tensorflow/core/grappler/utils.cc
+++ b/tensorflow/core/grappler/utils.cc
@@ -29,7 +29,7 @@ NodeMap::NodeMap(GraphDef* graph) : graph_(graph) {
     auto node = graph_->mutable_node(i);
     nodes_.insert(std::make_pair(node->name(), node));
     for (const auto& input : node->input()) {
-      outputs_[input].insert(nodes_[node->name()]);
+      outputs_[NodeName(input)].insert(nodes_[node->name()]);
     }
   }
 }