Add a constant folding pass to grappler.
Change: 151668925
This commit is contained in:
parent
649ac98246
commit
3db6b68b2f
@ -24,6 +24,42 @@ filegroup(
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
)
|
||||
|
||||
# Grappler optimizer that folds constant sub-expressions of a graph.
cc_library(
    name = "constant_folding",
    srcs = ["constant_folding.cc"],
    hdrs = ["constant_folding.h"],
    visibility = ["//visibility:public"],
    deps = [
        ":graph_optimizer",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/grappler:grappler_item",
        "//tensorflow/core/grappler:utils",
        "//tensorflow/core/grappler/clusters:cluster",
    ],
)
|
||||
|
||||
# Unit tests for the constant folding optimizer.
cc_test(
    name = "constant_folding_test",
    srcs = ["constant_folding_test.cc"],
    deps = [
        ":constant_folding",
        "//tensorflow/cc:cc_ops",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:core_cpu_internal",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core:testlib",
        "//tensorflow/core/grappler:grappler_item",
        "//tensorflow/core/grappler:utils",
        "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
    ],
)
|
||||
|
||||
cc_library(
|
||||
name = "graph_rewriter",
|
||||
srcs = ["graph_rewriter.cc"],
|
||||
|
313
tensorflow/core/grappler/optimizers/constant_folding.cc
Normal file
313
tensorflow/core/grappler/optimizers/constant_folding.cc
Normal file
@ -0,0 +1,313 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
|
||||
#include "tensorflow/core/framework/allocator.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_def.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/grappler/clusters/cluster.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||
#include "tensorflow/core/lib/strings/numbers.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/public/version.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
using TensorVector = gtl::InlinedVector<TensorValue, 4>;
|
||||
|
||||
namespace {
|
||||
class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface {
|
||||
public:
|
||||
explicit EigenThreadPoolWrapper(thread::ThreadPool* pool) : pool_(pool) {}
|
||||
~EigenThreadPoolWrapper() override {}
|
||||
void Schedule(std::function<void()> fn) override {
|
||||
pool_->Schedule(std::move(fn));
|
||||
}
|
||||
int NumThreads() const override { return pool_->NumThreads(); }
|
||||
int CurrentThreadId() const override { return pool_->CurrentThreadId(); }
|
||||
|
||||
private:
|
||||
thread::ThreadPool* pool_ = nullptr;
|
||||
};
|
||||
|
||||
class DeviceSimple : public DeviceBase {
|
||||
public:
|
||||
DeviceSimple() : DeviceBase(nullptr) {
|
||||
eigen_worker_threads_.num_threads = 1;
|
||||
eigen_worker_threads_.workers = new thread::ThreadPool(
|
||||
Env::Default(), "constant_folding", eigen_worker_threads_.num_threads);
|
||||
eigen_threadpool_wrapper_.reset(
|
||||
new EigenThreadPoolWrapper(eigen_worker_threads_.workers));
|
||||
eigen_device_.reset(new Eigen::ThreadPoolDevice(
|
||||
eigen_threadpool_wrapper_.get(), eigen_worker_threads_.num_threads));
|
||||
set_tensorflow_cpu_worker_threads(&eigen_worker_threads_);
|
||||
set_eigen_cpu_device(eigen_device_.get());
|
||||
}
|
||||
~DeviceSimple() override {
|
||||
eigen_threadpool_wrapper_.reset();
|
||||
eigen_device_.reset();
|
||||
delete eigen_worker_threads_.workers;
|
||||
}
|
||||
Status MakeTensorFromProto(const TensorProto& tensor_proto,
|
||||
const AllocatorAttributes alloc_attrs,
|
||||
Tensor* tensor) override {
|
||||
Tensor parsed(tensor_proto.dtype());
|
||||
if (!parsed.FromProto(cpu_allocator(), tensor_proto)) {
|
||||
return errors::InvalidArgument("Cannot parse tensor from proto: ",
|
||||
tensor_proto.DebugString());
|
||||
}
|
||||
*tensor = parsed;
|
||||
return Status::OK();
|
||||
}
|
||||
Allocator* GetAllocator(AllocatorAttributes attr) override {
|
||||
return cpu_allocator();
|
||||
}
|
||||
|
||||
private:
|
||||
DeviceBase::CpuWorkerThreads eigen_worker_threads_;
|
||||
std::unique_ptr<Eigen::ThreadPoolInterface> eigen_threadpool_wrapper_;
|
||||
std::unique_ptr<Eigen::ThreadPoolDevice> eigen_device_;
|
||||
};
|
||||
|
||||
// Stores in `*num_outputs` how many outputs `node` has.
// Fails with the registry's error when the op of `node` is not registered.
Status NumOutputs(const NodeDef& node, int* num_outputs) {
  const OpDef* op_def = nullptr;
  TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node.op(), &op_def));
  // For ConcatOffset the count is read from the node's "N" attribute instead
  // of the number of output args declared in the OpDef.
  const bool is_concat_offset = (node.op() == "ConcatOffset");
  *num_outputs = is_concat_offset ? node.attr().at("N").i()
                                  : op_def->output_arg_size();
  return Status::OK();
}
|
||||
} // namespace
|
||||
|
||||
bool ConstantFolding::IsConst(const NodeDef& node) const {
|
||||
return node.op() == "Const";
|
||||
}
|
||||
|
||||
bool ConstantFolding::IsFoldable(const NodeDef& node) const {
|
||||
DeviceTypeVector device_types;
|
||||
auto status = SupportedDeviceTypesForNode({DeviceType(DEVICE_CPU)}, node,
|
||||
&device_types);
|
||||
if (!status.ok()) {
|
||||
return false;
|
||||
}
|
||||
// Only fold ops with a CPU implementation available.
|
||||
if (device_types[0] != DeviceType(DEVICE_CPU)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (ops_to_preserve_.find(node.op()) != ops_to_preserve_.end()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Don't fold stateful ops such as TruncatedNormal.
|
||||
const OpDef* op_def = nullptr;
|
||||
status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
|
||||
if (!status.ok()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (op_def->is_stateful()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (op_def->output_arg_size() == 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Folding not applicable to ops with no inputs.
|
||||
if (node.input().empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (const auto& input : node.input()) {
|
||||
bool is_const = IsConst(*node_map_->GetNode(input));
|
||||
if (!is_const) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Builds a "Const" NodeDef called `name` that holds the value of `tensor`.
// Three attributes are set on the new node: "_output_shapes" (a one-element
// shape list), "dtype" and "value".
NodeDef ConstantFolding::CreateNodeDef(const string& name,
                                       const TensorValue& tensor) {
  NodeDef const_node;
  const_node.set_name(name);
  const_node.set_op("Const");
  auto* attrs = const_node.mutable_attr();

  // Shape of the tensor, written straight into the list entry.
  AttrValue shapes_attr;
  tensor->shape().AsProto(shapes_attr.mutable_list()->add_shape());
  attrs->insert({"_output_shapes", shapes_attr});

  // Element type of the tensor.
  AttrValue dtype_attr;
  dtype_attr.set_type(tensor->dtype());
  attrs->insert({"dtype", dtype_attr});

  // Serialized tensor content.
  AttrValue value_attr;
  tensor->AsProtoTensorContent(value_attr.mutable_tensor());
  attrs->insert({"value", value_attr});

  return const_node;
}
|
||||
|
||||
// Runs the CPU kernel of `node` on `inputs` and appends one TensorValue per
// node output to `output`. The appended tensors are released from the kernel
// context, so the caller becomes responsible for deleting them.
//
// Returns an error, and appends nothing, when the kernel cannot be created,
// the number of outputs cannot be determined, or the kernel itself reports a
// failure.
Status ConstantFolding::EvaluateNode(const NodeDef& node,
                                     const TensorVector& inputs,
                                     TensorVector* output) {
  Status status;
  auto op_kernel =
      CreateOpKernel("CPU", device_.get(), device_->GetAllocator({}), node,
                     TF_GRAPH_DEF_VERSION, &status);
  TF_RETURN_IF_ERROR(status);

  OpKernelContext::Params params;
  params.device = device_.get();
  params.frame_iter = FrameAndIter(0, 0);
  params.inputs = &inputs;
  params.op_kernel = op_kernel.get();

  int num_outputs = 0;
  TF_RETURN_IF_ERROR(NumOutputs(node, &num_outputs));
  // Every output is requested in host memory.
  gtl::InlinedVector<AllocatorAttributes, 4> output_attrs;
  for (int i = 0; i < num_outputs; i++) {
    AllocatorAttributes attr;
    attr.set_on_host(true);
    output_attrs.push_back(attr);
  }
  params.output_attr_array = output_attrs.data();

  OpKernelContext op_context(&params);
  op_kernel->Compute(&op_context);
  // Kernels report failures through the context. Bail out before releasing
  // any output so that nothing is handed to the caller on failure.
  TF_RETURN_IF_ERROR(op_context.status());

  for (int i = 0; i < num_outputs; i++) {
    output->push_back(op_context.release_output(i));
  }
  return Status::OK();
}
|
||||
|
||||
// Evaluates `node`, whose inputs are expected to be constant nodes, and
// appends one "Const" NodeDef per output of `node` to `outputs`.
//
// The new nodes are named by prefixing the name of `node` with
// kConstantFoldingConst; when there is more than one output, "-<index>" is
// appended to tell them apart.
//
// On error nothing is appended to `outputs`. Every tensor produced while
// evaluating is deleted before returning, on success and on failure alike.
Status ConstantFolding::EvaluateOneFoldable(const NodeDef& node,
                                            std::vector<NodeDef>* outputs) {
  TensorVector inputs;
  TensorVector output_tensors;
  std::vector<NodeDef> folded;
  Status status = Status::OK();

  // Materialize the value of every input node.
  for (const auto& input : node.input()) {
    const NodeDef* input_node = node_map_->GetNode(input);
    if (input_node == nullptr) {
      status = errors::InvalidArgument("Unknown input ", input, " of node ",
                                       node.name());
      break;
    }
    TensorVector produced;
    status = EvaluateNode(*input_node, TensorVector(), &produced);
    if (status.ok() && (produced.empty() || produced[0].tensor == nullptr)) {
      status = errors::InvalidArgument("Input ", input, " of node ",
                                       node.name(), " produced no value.");
    }
    if (!status.ok()) {
      for (const auto& value : produced) {
        delete value.tensor;
      }
      break;
    }
    // Only the first output is consumed; free any others right away.
    inputs.push_back(produced[0]);
    for (size_t i = 1; i < produced.size(); ++i) {
      delete produced[i].tensor;
    }
  }

  if (status.ok()) {
    status = EvaluateNode(node, inputs, &output_tensors);
  }
  if (status.ok() && output_tensors.empty()) {
    status = Status(error::INVALID_ARGUMENT, "Expected at least one output.");
  }

  if (status.ok()) {
    const int num_outputs = static_cast<int>(output_tensors.size());
    const string base_name =
        AddPrefixToNodeName(node.name(), kConstantFoldingConst);
    for (int i = 0; i < num_outputs; ++i) {
      if (output_tensors[i].tensor == nullptr) {
        status = errors::InvalidArgument("Node ", node.name(),
                                         " produced no value for output ", i);
        break;
      }
      const string node_name =
          num_outputs > 1 ? strings::StrCat(base_name, "-", i) : base_name;
      folded.push_back(CreateNodeDef(node_name, output_tensors[i]));
    }
  }

  // Single cleanup point: the tensors are owned here on every path.
  for (const auto& value : inputs) {
    delete value.tensor;
  }
  for (const auto& value : output_tensors) {
    delete value.tensor;
  }

  if (status.ok()) {
    outputs->insert(outputs->end(), folded.begin(), folded.end());
  }
  return status;
}
|
||||
|
||||
// Folds `node`: evaluates it, adds the resulting constant nodes to `output`
// (and to the node map), and rewires every consumer of `node` to read from
// the new constants instead. A control dependency on `node` becomes a control
// dependency on the first constant.
//
// Returns an error, without touching `output` or any consumer, when the
// evaluation fails, yields no constant, or a consumer refers to an output
// index for which no constant was produced.
Status ConstantFolding::FoldNode(const NodeDef& node, GraphDef* output) {
  std::vector<NodeDef> const_nodes;
  TF_RETURN_IF_ERROR(EvaluateOneFoldable(node, &const_nodes));

  const int num_consts = static_cast<int>(const_nodes.size());
  if (num_consts == 0) {
    return errors::InvalidArgument("Folding node ", node.name(),
                                   " produced no constant.");
  }

  auto consumers = node_map_->GetOutputs(node.name());

  // Validate all references first so that a bad index cannot leave the graph
  // half rewritten.
  for (const auto& consumer : consumers) {
    for (int i = 0; i < consumer->input_size(); i++) {
      int position;
      const string input_name = ParseNodeName(consumer->input(i), &position);
      if (input_name == node.name() && position >= num_consts) {
        return errors::InvalidArgument(
            "Node ", consumer->name(), " reads output ", position, " of ",
            node.name(), " but only ", num_consts, " were produced.");
      }
    }
  }

  for (const auto& const_node : const_nodes) {
    NodeDef* added_node = output->add_node();
    *added_node = const_node;
    node_map_->AddNode(added_node->name(), added_node);
  }

  for (const auto& consumer : consumers) {
    for (int i = 0; i < consumer->input_size(); i++) {
      int position;
      const string input_name = ParseNodeName(consumer->input(i), &position);
      if (input_name != node.name()) {
        continue;
      }
      if (position < 0) {
        // Negative position: the input is a control dependency.
        *consumer->mutable_input(i) =
            strings::StrCat("^", const_nodes[0].name());
      } else {
        *consumer->mutable_input(i) = const_nodes[position].name();
      }
    }
  }
  return Status::OK();
}
|
||||
|
||||
// Repeatedly sweeps over the graph, folding each foldable node once, until a
// sweep folds nothing new. Afterwards every node of the working graph is
// copied into `output`, behind the constants added while folding.
Status ConstantFolding::FoldGraph(GraphDef* output) {
  std::set<string> folded_names;
  bool made_progress = true;
  while (made_progress) {
    const int count_before = folded_names.size();
    for (const auto& candidate : graph_.node()) {
      if (!IsFoldable(candidate)) {
        continue;
      }
      if (folded_names.count(candidate.name()) > 0) {
        continue;  // Already folded in an earlier sweep.
      }
      TF_RETURN_IF_ERROR(FoldNode(candidate, output));
      folded_names.insert(candidate.name());
    }
    const int count_after = folded_names.size();
    LOG(INFO) << "Previous number of processed nodes: " << count_before
              << "; Current number of processed nodes: " << count_after;
    made_progress = (count_after != count_before);
  }

  // The folded nodes themselves stay in the result, so callers can still
  // fetch their values by name.
  for (const auto& original : graph_.node()) {
    *output->add_node() = original;
  }
  return Status::OK();
}
|
||||
|
||||
// Runs constant folding on a copy of `item.graph` and writes the result to
// `output`. Nodes listed in `item.fetch` are never folded. `cluster` is not
// used.
Status ConstantFolding::Optimize(Cluster* cluster, const GrapplerItem& item,
                                 GraphDef* output) {
  graph_ = item.graph;
  LOG(INFO) << "Initial graph size: " << item.graph.node_size();
  node_map_.reset(new NodeMap(&graph_));

  // Start from a clean slate: fetch names collected by an earlier call on
  // this optimizer must not block folding in the current item.
  nodes_to_preserve_.clear();
  for (const auto& node : item.fetch) {
    nodes_to_preserve_.insert(NodeName(node));
  }

  device_.reset(new DeviceSimple());
  TF_RETURN_IF_ERROR(FoldGraph(output));
  LOG(INFO) << "Optimized graph size: " << output->node_size();
  return Status::OK();
}
|
||||
|
||||
// Intentionally a no-op: all arguments are ignored.
void ConstantFolding::Feedback(Cluster* cluster, const GrapplerItem& item,
                               const GraphDef& optimize_output,
                               double result) {
  // Nothing to do for ConstantFolding.
}
|
||||
|
||||
} // end namespace grappler
|
||||
} // end namespace tensorflow
|
73
tensorflow/core/grappler/optimizers/constant_folding.h
Normal file
73
tensorflow/core/grappler/optimizers/constant_folding.h
Normal file
@ -0,0 +1,73 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_GRAPPLER_OPTIMIZERS_CONSTANT_FOLDING_H_
|
||||
#define TENSORFLOW_GRAPPLER_OPTIMIZERS_CONSTANT_FOLDING_H_
|
||||
|
||||
#include "tensorflow/core/framework/device_base.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
const char kConstantFoldingConst[] = "ConstantFolding";
|
||||
|
||||
// Constant folding optimization for a graph.
|
||||
class ConstantFolding : public GraphOptimizer {
|
||||
public:
|
||||
ConstantFolding() {}
|
||||
|
||||
~ConstantFolding() override {}
|
||||
|
||||
string name() const override { return "constant folding"; };
|
||||
|
||||
Status Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
GraphDef* output) override;
|
||||
|
||||
void Feedback(Cluster* cluster, const GrapplerItem& item,
|
||||
const GraphDef& optimize_output, double result) override;
|
||||
|
||||
private:
|
||||
bool IsConst(const NodeDef& node) const;
|
||||
|
||||
bool IsFoldable(const NodeDef& node) const;
|
||||
|
||||
NodeDef CreateNodeDef(const string& name, const TensorValue& tensor);
|
||||
|
||||
Status EvaluateNode(const NodeDef& node,
|
||||
const gtl::InlinedVector<TensorValue, 4>& inputs,
|
||||
gtl::InlinedVector<TensorValue, 4>* output);
|
||||
|
||||
Status EvaluateOneFoldable(const NodeDef& node,
|
||||
std::vector<NodeDef>* outputs);
|
||||
|
||||
Status FoldNode(const NodeDef& node, GraphDef* output);
|
||||
|
||||
Status FoldGraph(GraphDef* output);
|
||||
|
||||
std::unique_ptr<DeviceBase> device_;
|
||||
GraphDef graph_;
|
||||
std::unique_ptr<NodeMap> node_map_;
|
||||
std::set<string> nodes_to_preserve_;
|
||||
std::set<string> ops_to_preserve_ = {"Save", "SaveV2", "SaveSlices",
|
||||
"Restore", "RestoreV2", "RestoreSlice"};
|
||||
};
|
||||
|
||||
} // end namespace grappler
|
||||
} // end namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_GRAPPLER_OPTIMIZERS_CONSTANT_FOLDING_H_
|
148
tensorflow/core/grappler/optimizers/constant_folding_test.cc
Normal file
148
tensorflow/core/grappler/optimizers/constant_folding_test.cc
Normal file
@ -0,0 +1,148 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
namespace {
|
||||
|
||||
// Fixture providing a helper that runs a graph in a fresh session.
class ConstantFoldingTest : public ::testing::Test {
 protected:
  // Creates a session for `graph`, runs it once with `fetch` used both as the
  // list of outputs and as the list of target nodes, and returns the fetched
  // tensors. Any session failure aborts via TF_CHECK_OK.
  std::vector<Tensor> EvaluateNodes(const GraphDef& graph,
                                    const std::vector<string>& fetch) {
    SessionOptions session_options;
    std::unique_ptr<tensorflow::Session> session(NewSession(session_options));
    TF_CHECK_OK(session->Create(graph));

    std::vector<Tensor> fetched;
    RunOptions run_options;
    const Status run_status =
        session->Run(run_options, {}, fetch, fetch, &fetched, nullptr);
    TF_CHECK_OK(run_status);

    TF_CHECK_OK(session->Close());
    return fetched;
  }
};
|
||||
|
||||
// Folds a graph in which "c" depends only on constants and checks both the
// structure of the optimized graph and that all node values are unchanged.
TEST_F(ConstantFoldingTest, SimpleFolding) {
  // Build a small graph: c = a + b is foldable, d = b + c is the fetch node.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  Output a = ops::Const(s.WithOpName("a"), 1.0f, {1});
  Output b = ops::Const(s.WithOpName("b"), 2.0f, {1});
  Output c = ops::AddN(s.WithOpName("c"), {a, b});
  Output d = ops::AddN(s.WithOpName("d"), {b, c});

  GrapplerItem item;
  item.fetch.push_back("d");
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  ConstantFolding fold;
  GraphDef output;
  Status status = fold.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  // Assert rather than expect: the nodes are indexed unconditionally below.
  ASSERT_EQ(5, output.node_size());

  const NodeDef& new_c = output.node(0);
  EXPECT_EQ("ConstantFolding-c", new_c.name());
  EXPECT_EQ("Const", new_c.op());

  const NodeDef& new_a = output.node(1);
  EXPECT_EQ("a", new_a.name());

  const NodeDef& new_b = output.node(2);
  EXPECT_EQ("b", new_b.name());

  const NodeDef& old_c = output.node(3);
  EXPECT_EQ("c", old_c.name());

  const NodeDef& new_d = output.node(4);
  EXPECT_EQ("d", new_d.name());
  ASSERT_EQ(2, new_d.input_size());
  EXPECT_EQ("ConstantFolding-c", new_d.input(1));

  // The optimized graph must compute the same values as the original one.
  std::vector<string> fetch = {"a", "b", "c", "d"};
  auto tensors_expected = EvaluateNodes(item.graph, fetch);
  auto tensors = EvaluateNodes(output, fetch);
  ASSERT_EQ(4, tensors_expected.size());
  ASSERT_EQ(4, tensors.size());
  for (int i = 0; i < 4; i++) {
    test::ExpectTensorEqual<float>(tensors_expected[i], tensors[i]);
  }
}
|
||||
|
||||
// Folds a node with two outputs and checks that each consumer is rewired to
// the constant created for the output it reads.
TEST_F(ConstantFoldingTest, FoldingNodeWithTwoOutputs) {
  // Build a graph in which the two-output node "b" has a constant input.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  Output a = ops::Const(s.WithOpName("a"), 10, {3});
  auto b = ops::Unique(s.WithOpName("b"), {a});
  Output c = ops::Identity(s.WithOpName("c"), {b.y});
  Output d = ops::Identity(s.WithOpName("d"), {b.idx});

  GrapplerItem item;
  item.fetch.push_back("c");
  item.fetch.push_back("d");
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  ConstantFolding fold;
  GraphDef output;
  Status status = fold.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  // Assert rather than expect: the nodes are indexed unconditionally below.
  ASSERT_EQ(6, output.node_size());

  const NodeDef& new_b_0 = output.node(0);
  EXPECT_EQ("ConstantFolding-b-0", new_b_0.name());
  EXPECT_EQ("Const", new_b_0.op());

  const NodeDef& new_b_1 = output.node(1);
  EXPECT_EQ("ConstantFolding-b-1", new_b_1.name());
  EXPECT_EQ("Const", new_b_1.op());

  const NodeDef& new_a = output.node(2);
  EXPECT_EQ("a", new_a.name());

  const NodeDef& new_b = output.node(3);
  EXPECT_EQ("b", new_b.name());

  const NodeDef& new_c = output.node(4);
  EXPECT_EQ("c", new_c.name());
  ASSERT_EQ(1, new_c.input_size());
  EXPECT_EQ("ConstantFolding-b-0", new_c.input(0));

  const NodeDef& new_d = output.node(5);
  EXPECT_EQ("d", new_d.name());
  ASSERT_EQ(1, new_d.input_size());
  EXPECT_EQ("ConstantFolding-b-1", new_d.input(0));

  // The optimized graph must compute the same values as the original one.
  std::vector<string> fetch = {"a", "b", "c", "d"};
  auto tensors_expected = EvaluateNodes(item.graph, fetch);
  auto tensors = EvaluateNodes(output, fetch);
  ASSERT_EQ(4, tensors_expected.size());
  ASSERT_EQ(4, tensors.size());
  for (int i = 0; i < 4; i++) {
    test::ExpectTensorEqual<int>(tensors_expected[i], tensors[i]);
  }
}
|
||||
|
||||
} // namespace
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
@ -29,7 +29,7 @@ NodeMap::NodeMap(GraphDef* graph) : graph_(graph) {
|
||||
auto node = graph_->mutable_node(i);
|
||||
nodes_.insert(std::make_pair(node->name(), node));
|
||||
for (const auto& input : node->input()) {
|
||||
outputs_[input].insert(nodes_[node->name()]);
|
||||
outputs_[NodeName(input)].insert(nodes_[node->name()]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user