Clean up constant-folder API. No functional changes.

Cleanups:
* Remove a deprecated entry point, update callers to use the Status-returning entry point. Rename DoConstantFoldingWithStatus to ConstantFold, now that we have removed the "without Status" API.
* Hide an internal function from the header.
* Move ConstantFoldingOptions into constant_folding.h
Change: 154462306
This commit is contained in:
Peter Hawkins 2017-04-27 12:32:09 -08:00 committed by TensorFlower Gardener
parent 0172ad0215
commit 163613d82e
6 changed files with 62 additions and 74 deletions

View File

@ -122,6 +122,8 @@ void FindConstantFoldableNodes(const Graph* graph,
}
}
typedef std::pair<Node*, int> NodeAndOutput;
// Given the constant foldable nodes in 'nodes', returns a new graph 'g'. 'g'
// will contain copies of the nodes in 'nodes'. In addition, if there is an edge
// going from a node 'n' in 'nodes' to another node in 'orig_graph' but not in
@ -170,8 +172,10 @@ int64 UniqueConstantId() {
return id.fetch_add(1);
}
} // namespace
// Replaces the identified Tensor in 'graph' by a 'Const' node with
// the value supplied in 'constant'. 'partition_device', if non-null
// is the device where the graph executes. Returns true if the
// replacement was successful, false otherwise.
bool ReplaceTensorWithConstant(Graph* graph, Device* partition_device,
NodeAndOutput tensor, const Tensor& constant) {
// Be conservative when replacing a tensor with a constant, when not
@ -254,19 +258,11 @@ bool ReplaceTensorWithConstant(Graph* graph, Device* partition_device,
return true;
}
bool DoConstantFolding(const ConstantFoldingOptions& opts,
FunctionLibraryRuntime* function_library, Env* env,
Device* partition_device, Graph* graph) {
bool was_mutated;
Status unused_status = DoConstantFoldingWithStatus(
opts, function_library, env, partition_device, graph, &was_mutated);
return was_mutated;
}
} // namespace
Status DoConstantFoldingWithStatus(const ConstantFoldingOptions& opts,
FunctionLibraryRuntime* function_library,
Env* env, Device* partition_device,
Graph* graph, bool* was_mutated) {
Status ConstantFold(const ConstantFoldingOptions& opts,
FunctionLibraryRuntime* function_library, Env* env,
Device* partition_device, Graph* graph, bool* was_mutated) {
DumpGraph("Before", graph);
const FunctionLibraryDefinition* flib_def = nullptr;

View File

@ -1,4 +1,4 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -17,12 +17,20 @@ limitations under the License.
#define TENSORFLOW_COMMON_RUNTIME_CONSTANT_FOLDING_H_
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/platform/env.h"
namespace tensorflow {
// Options specific to constant folding optimizations.
struct ConstantFoldingOptions {
// If "consider" is not a nullptr, then only constant fold a node "n" if
// consider(n) returns true.
std::function<bool(const Node*)> consider = nullptr;
};
// Perform constant folding optimization on "graph".
// Looks for nodes in "graph" that can be completely evaluated statically, i.e.,
// that are only dependent on constants. Evaluates those nodes on a CPU device
@ -32,25 +40,9 @@ namespace tensorflow {
// Sets `was_mutated` to true if and only if "graph" has been mutated.
// The status is only set to a non-OK state if an unexpected error is hit
// running the graph.
Status DoConstantFoldingWithStatus(const ConstantFoldingOptions& opts,
FunctionLibraryRuntime* function_library,
Env* env, Device* partition_device,
Graph* graph, bool* was_mutated);
// Version of the function that doesn't return a Status, for backwards
// compatibility.
bool DoConstantFolding(const ConstantFoldingOptions& opts,
FunctionLibraryRuntime* function_library, Env* env,
Device* partition_device, Graph* graph);
typedef std::pair<Node*, int> NodeAndOutput;
// Replaces the identified Tensor in 'graph' by a 'Const' node with
// the value supplied in 'constant'. 'partition_device', if non-null
// is the device where the graph executes. Returns true if the
// replacement was successful, false otherwise.
bool ReplaceTensorWithConstant(Graph* graph, Device* partition_device,
NodeAndOutput tensor, const Tensor& constant);
Status ConstantFold(const ConstantFoldingOptions& opts,
FunctionLibraryRuntime* function_library, Env* env,
Device* partition_device, Graph* graph, bool* was_mutated);
} // namespace tensorflow

View File

@ -104,8 +104,10 @@ TEST_F(ConstantFoldingTest, Basic) {
Graph g(OpRegistry::Global());
TF_ASSERT_OK(s.ToGraph(&g));
EXPECT_TRUE(DoConstantFolding(ConstantFoldingOptions{}, nullptr,
Env::Default(), nullptr, &g));
bool was_mutated;
TF_ASSERT_OK(ConstantFold(ConstantFoldingOptions{}, nullptr, Env::Default(),
nullptr, &g, &was_mutated));
EXPECT_TRUE(was_mutated);
std::unordered_map<string, Node*> index = NodeNameIndex(g);
Node* s1 = index.at("s1");
@ -128,7 +130,10 @@ TEST_F(ConstantFoldingTest, ConsiderFunction) {
ConstantFoldingOptions opts;
// Do not allow constant folding of m2
opts.consider = [](const Node* n) { return "m2" != n->name(); };
EXPECT_TRUE(DoConstantFolding(opts, nullptr, Env::Default(), nullptr, &g));
bool was_mutated;
TF_ASSERT_OK(
ConstantFold(opts, nullptr, Env::Default(), nullptr, &g, &was_mutated));
EXPECT_TRUE(was_mutated);
std::unordered_map<string, Node*> index = NodeNameIndex(g);
Node* s1 = index.at("s1");
@ -154,8 +159,10 @@ TEST_F(ConstantFoldingTest, TestNoReplaceAnotherConstant) {
TF_ASSERT_OK(s.ToGraph(&g));
}
EXPECT_TRUE(DoConstantFolding(ConstantFoldingOptions{}, nullptr,
Env::Default(), nullptr, &g));
bool was_mutated;
TF_ASSERT_OK(ConstantFold(ConstantFoldingOptions{}, nullptr, Env::Default(),
nullptr, &g, &was_mutated));
EXPECT_TRUE(was_mutated);
std::unordered_map<string, Node*> index = NodeNameIndex(g);
Node* d = index.at("d");
@ -180,8 +187,11 @@ TEST_F(ConstantFoldingTest, TwoOutputs) {
TF_ASSERT_OK(s.ToGraph(&g));
}
EXPECT_TRUE(DoConstantFolding(ConstantFoldingOptions{}, nullptr,
Env::Default(), nullptr, &g));
bool was_mutated;
TF_ASSERT_OK(ConstantFold(ConstantFoldingOptions{}, nullptr, Env::Default(),
nullptr, &g, &was_mutated));
EXPECT_TRUE(was_mutated);
std::unordered_map<string, Node*> index = NodeNameIndex(g);
Node* b0 = index.at("b0");
Node* b1 = index.at("b1");
@ -209,7 +219,10 @@ TEST_F(ConstantFoldingTest, TwoOutputsFoldOneOutput) {
ConstantFoldingOptions opts;
opts.consider = [](const Node* n) { return "b1_ident" != n->name(); };
EXPECT_TRUE(DoConstantFolding(opts, nullptr, Env::Default(), nullptr, &g));
bool was_mutated;
TF_ASSERT_OK(
ConstantFold(opts, nullptr, Env::Default(), nullptr, &g, &was_mutated));
EXPECT_TRUE(was_mutated);
std::unordered_map<string, Node*> index = NodeNameIndex(g);
Node* b0 = index.at("b0");
@ -243,11 +256,9 @@ TEST_F(ConstantFoldingTest, TestNoReplaceLargeConstant) {
// The above concat should not have been constant folded.
bool was_mutated;
Status status =
DoConstantFoldingWithStatus(ConstantFoldingOptions{}, nullptr,
Env::Default(), nullptr, &g, &was_mutated);
TF_EXPECT_OK(ConstantFold(ConstantFoldingOptions{}, nullptr, Env::Default(),
nullptr, &g, &was_mutated));
EXPECT_FALSE(was_mutated);
TF_EXPECT_OK(status);
}
TEST_F(ConstantFoldingTest, TestNoReplaceFunctionCall) {
@ -280,11 +291,9 @@ TEST_F(ConstantFoldingTest, TestNoReplaceFunctionCall) {
// The above function call should not have been constant folded.
bool was_mutated;
Status status =
DoConstantFoldingWithStatus(ConstantFoldingOptions{}, nullptr,
Env::Default(), nullptr, &g, &was_mutated);
TF_EXPECT_OK(ConstantFold(ConstantFoldingOptions{}, nullptr, Env::Default(),
nullptr, &g, &was_mutated));
EXPECT_FALSE(was_mutated);
EXPECT_TRUE(status.ok());
}
REGISTER_OP("ConstantFoldingTestOp").Input("a: int64").Output("b: int64");
@ -311,11 +320,9 @@ TEST_F(ConstantFoldingTest, TestNoReplaceNonCPUOp) {
// The non-CPU op should not have been constant folded.
bool was_mutated;
Status status =
DoConstantFoldingWithStatus(ConstantFoldingOptions{}, nullptr,
Env::Default(), nullptr, &g, &was_mutated);
TF_EXPECT_OK(ConstantFold(ConstantFoldingOptions{}, nullptr, Env::Default(),
nullptr, &g, &was_mutated));
EXPECT_FALSE(was_mutated);
TF_EXPECT_OK(status);
}
namespace {
@ -392,15 +399,13 @@ TEST_F(ConstantFoldingTest, TestImmutableConst) {
TF_ASSERT_OK(root.ToGraph(&g));
TestTFEnvironment test_env;
bool was_mutated;
Status status =
DoConstantFoldingWithStatus(ConstantFoldingOptions{}, nullptr,
Env::Default(), nullptr, &g, &was_mutated);
Status status = ConstantFold(ConstantFoldingOptions{}, nullptr,
Env::Default(), nullptr, &g, &was_mutated);
EXPECT_FALSE(was_mutated);
EXPECT_FALSE(status.ok());
status = DoConstantFoldingWithStatus(ConstantFoldingOptions{}, nullptr,
&test_env, nullptr, &g, &was_mutated);
TF_EXPECT_OK(ConstantFold(ConstantFoldingOptions{}, nullptr, &test_env,
nullptr, &g, &was_mutated));
EXPECT_TRUE(was_mutated);
TF_EXPECT_OK(status);
}
} // namespace

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/constant_folding.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/optimizer_cse.h"
@ -56,7 +57,10 @@ void GraphOptimizer::Optimize(FunctionLibraryRuntime* runtime, Env* env,
if (opts_.do_constant_folding()) {
ConstantFoldingOptions cf_opts;
if (DoConstantFolding(cf_opts, runtime, env, device, g)) {
bool was_mutated;
ConstantFold(cf_opts, runtime, env, device, g, &was_mutated)
.IgnoreError();
if (was_mutated) {
RemoveDeadNodes(g);
DumpGraph("ConstFolding", g);
changed = true;

View File

@ -24,15 +24,6 @@ limitations under the License.
namespace tensorflow {
class ShapeRefiner;
// Options specific to constant folding optimizations.
//
// TODO(ashankar,vrv): This should move to where constant folding is done.
struct ConstantFoldingOptions {
// If "consider" is not a nullptr, then only constant fold a node "n" if
// consider(n) returns true.
std::function<bool(const Node*)> consider = nullptr;
};
// Construct a Graph *g out of a GraphDef gdef. Returns non-OK on
// error, in which case *g is left in an incomplete state.
//

View File

@ -152,9 +152,9 @@ Status FoldConstants(const GraphDef& input_graph_def,
&input_graph, context.input_names, context.output_names, {},
device_attributes, false /* use_function_convention */, &metadata));
bool was_mutated;
TF_RETURN_IF_ERROR(DoConstantFoldingWithStatus(
ConstantFoldingOptions(), nullptr, Env::Default(), nullptr, &input_graph,
&was_mutated));
TF_RETURN_IF_ERROR(ConstantFold(ConstantFoldingOptions(), nullptr,
Env::Default(), nullptr, &input_graph,
&was_mutated));
GraphDef folded_graph_def;
input_graph.ToGraphDef(&folded_graph_def);
GraphDef send_recvs_replaced;