Clean up constant-folder API. No functional changes.
Cleanups: * Remove a deprecated entry point, update callers to use the Status-returning entry point. Rename DoConstantFoldingWithStatus to ConstantFold, now that we have removed the "without Status" API. * Hide an internal function from the header. * Move ConstantFoldingOptions into constant_folding.h Change: 154462306
This commit is contained in:
parent
0172ad0215
commit
163613d82e
@ -122,6 +122,8 @@ void FindConstantFoldableNodes(const Graph* graph,
|
||||
}
|
||||
}
|
||||
|
||||
typedef std::pair<Node*, int> NodeAndOutput;
|
||||
|
||||
// Given the constant foldable nodes in 'nodes', returns a new graph 'g'. 'g'
|
||||
// will contain copies of the nodes in 'nodes'. In addition, if there is an edge
|
||||
// going from a node 'n' in 'nodes' to another node in 'orig_graph' but not in
|
||||
@ -170,8 +172,10 @@ int64 UniqueConstantId() {
|
||||
return id.fetch_add(1);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Replaces the identified Tensor in 'graph' by a 'Const' node with
|
||||
// the value supplied in 'constant'. 'partition_device', if non-null
|
||||
// is the device where the graph executes. Returns true if the
|
||||
// replacement was successful, false otherwise.
|
||||
bool ReplaceTensorWithConstant(Graph* graph, Device* partition_device,
|
||||
NodeAndOutput tensor, const Tensor& constant) {
|
||||
// Be conservative when replacing a tensor with a constant, when not
|
||||
@ -254,19 +258,11 @@ bool ReplaceTensorWithConstant(Graph* graph, Device* partition_device,
|
||||
return true;
|
||||
}
|
||||
|
||||
bool DoConstantFolding(const ConstantFoldingOptions& opts,
|
||||
FunctionLibraryRuntime* function_library, Env* env,
|
||||
Device* partition_device, Graph* graph) {
|
||||
bool was_mutated;
|
||||
Status unused_status = DoConstantFoldingWithStatus(
|
||||
opts, function_library, env, partition_device, graph, &was_mutated);
|
||||
return was_mutated;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
Status DoConstantFoldingWithStatus(const ConstantFoldingOptions& opts,
|
||||
FunctionLibraryRuntime* function_library,
|
||||
Env* env, Device* partition_device,
|
||||
Graph* graph, bool* was_mutated) {
|
||||
Status ConstantFold(const ConstantFoldingOptions& opts,
|
||||
FunctionLibraryRuntime* function_library, Env* env,
|
||||
Device* partition_device, Graph* graph, bool* was_mutated) {
|
||||
DumpGraph("Before", graph);
|
||||
|
||||
const FunctionLibraryDefinition* flib_def = nullptr;
|
||||
|
@ -1,4 +1,4 @@
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -17,12 +17,20 @@ limitations under the License.
|
||||
#define TENSORFLOW_COMMON_RUNTIME_CONSTANT_FOLDING_H_
|
||||
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Options specific to constant folding optimizations.
|
||||
struct ConstantFoldingOptions {
|
||||
// If "consider" is not a nullptr, then only constant fold a node "n" if
|
||||
// consider(n) returns true.
|
||||
std::function<bool(const Node*)> consider = nullptr;
|
||||
};
|
||||
|
||||
// Perform constant folding optimization on "graph".
|
||||
// Looks for nodes in "graph" that can be completely evaluated statically, i.e.,
|
||||
// that are only dependent on constants. Evaluates those nodes on a CPU device
|
||||
@ -32,25 +40,9 @@ namespace tensorflow {
|
||||
// Sets `was_mutated` to true if and only if "graph" has been mutated.
|
||||
// The status is only set to a non-OK state if an unexpected error is hit
|
||||
// running the graph.
|
||||
Status DoConstantFoldingWithStatus(const ConstantFoldingOptions& opts,
|
||||
FunctionLibraryRuntime* function_library,
|
||||
Env* env, Device* partition_device,
|
||||
Graph* graph, bool* was_mutated);
|
||||
|
||||
// Version of the function that doesn't return a Status, for backwards
|
||||
// compatibility.
|
||||
bool DoConstantFolding(const ConstantFoldingOptions& opts,
|
||||
FunctionLibraryRuntime* function_library, Env* env,
|
||||
Device* partition_device, Graph* graph);
|
||||
|
||||
typedef std::pair<Node*, int> NodeAndOutput;
|
||||
|
||||
// Replaces the identified Tensor in 'graph' by a 'Const' node with
|
||||
// the value supplied in 'constant'. 'partition_device', if non-null
|
||||
// is the device where the graph executes. Returns true if the
|
||||
// replacement was successful, false otherwise.
|
||||
bool ReplaceTensorWithConstant(Graph* graph, Device* partition_device,
|
||||
NodeAndOutput tensor, const Tensor& constant);
|
||||
Status ConstantFold(const ConstantFoldingOptions& opts,
|
||||
FunctionLibraryRuntime* function_library, Env* env,
|
||||
Device* partition_device, Graph* graph, bool* was_mutated);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -104,8 +104,10 @@ TEST_F(ConstantFoldingTest, Basic) {
|
||||
Graph g(OpRegistry::Global());
|
||||
TF_ASSERT_OK(s.ToGraph(&g));
|
||||
|
||||
EXPECT_TRUE(DoConstantFolding(ConstantFoldingOptions{}, nullptr,
|
||||
Env::Default(), nullptr, &g));
|
||||
bool was_mutated;
|
||||
TF_ASSERT_OK(ConstantFold(ConstantFoldingOptions{}, nullptr, Env::Default(),
|
||||
nullptr, &g, &was_mutated));
|
||||
EXPECT_TRUE(was_mutated);
|
||||
|
||||
std::unordered_map<string, Node*> index = NodeNameIndex(g);
|
||||
Node* s1 = index.at("s1");
|
||||
@ -128,7 +130,10 @@ TEST_F(ConstantFoldingTest, ConsiderFunction) {
|
||||
ConstantFoldingOptions opts;
|
||||
// Do not allow constant folding of m2
|
||||
opts.consider = [](const Node* n) { return "m2" != n->name(); };
|
||||
EXPECT_TRUE(DoConstantFolding(opts, nullptr, Env::Default(), nullptr, &g));
|
||||
bool was_mutated;
|
||||
TF_ASSERT_OK(
|
||||
ConstantFold(opts, nullptr, Env::Default(), nullptr, &g, &was_mutated));
|
||||
EXPECT_TRUE(was_mutated);
|
||||
|
||||
std::unordered_map<string, Node*> index = NodeNameIndex(g);
|
||||
Node* s1 = index.at("s1");
|
||||
@ -154,8 +159,10 @@ TEST_F(ConstantFoldingTest, TestNoReplaceAnotherConstant) {
|
||||
TF_ASSERT_OK(s.ToGraph(&g));
|
||||
}
|
||||
|
||||
EXPECT_TRUE(DoConstantFolding(ConstantFoldingOptions{}, nullptr,
|
||||
Env::Default(), nullptr, &g));
|
||||
bool was_mutated;
|
||||
TF_ASSERT_OK(ConstantFold(ConstantFoldingOptions{}, nullptr, Env::Default(),
|
||||
nullptr, &g, &was_mutated));
|
||||
EXPECT_TRUE(was_mutated);
|
||||
|
||||
std::unordered_map<string, Node*> index = NodeNameIndex(g);
|
||||
Node* d = index.at("d");
|
||||
@ -180,8 +187,11 @@ TEST_F(ConstantFoldingTest, TwoOutputs) {
|
||||
TF_ASSERT_OK(s.ToGraph(&g));
|
||||
}
|
||||
|
||||
EXPECT_TRUE(DoConstantFolding(ConstantFoldingOptions{}, nullptr,
|
||||
Env::Default(), nullptr, &g));
|
||||
bool was_mutated;
|
||||
TF_ASSERT_OK(ConstantFold(ConstantFoldingOptions{}, nullptr, Env::Default(),
|
||||
nullptr, &g, &was_mutated));
|
||||
EXPECT_TRUE(was_mutated);
|
||||
|
||||
std::unordered_map<string, Node*> index = NodeNameIndex(g);
|
||||
Node* b0 = index.at("b0");
|
||||
Node* b1 = index.at("b1");
|
||||
@ -209,7 +219,10 @@ TEST_F(ConstantFoldingTest, TwoOutputsFoldOneOutput) {
|
||||
|
||||
ConstantFoldingOptions opts;
|
||||
opts.consider = [](const Node* n) { return "b1_ident" != n->name(); };
|
||||
EXPECT_TRUE(DoConstantFolding(opts, nullptr, Env::Default(), nullptr, &g));
|
||||
bool was_mutated;
|
||||
TF_ASSERT_OK(
|
||||
ConstantFold(opts, nullptr, Env::Default(), nullptr, &g, &was_mutated));
|
||||
EXPECT_TRUE(was_mutated);
|
||||
|
||||
std::unordered_map<string, Node*> index = NodeNameIndex(g);
|
||||
Node* b0 = index.at("b0");
|
||||
@ -243,11 +256,9 @@ TEST_F(ConstantFoldingTest, TestNoReplaceLargeConstant) {
|
||||
|
||||
// The above concat should not have been constant folded.
|
||||
bool was_mutated;
|
||||
Status status =
|
||||
DoConstantFoldingWithStatus(ConstantFoldingOptions{}, nullptr,
|
||||
Env::Default(), nullptr, &g, &was_mutated);
|
||||
TF_EXPECT_OK(ConstantFold(ConstantFoldingOptions{}, nullptr, Env::Default(),
|
||||
nullptr, &g, &was_mutated));
|
||||
EXPECT_FALSE(was_mutated);
|
||||
TF_EXPECT_OK(status);
|
||||
}
|
||||
|
||||
TEST_F(ConstantFoldingTest, TestNoReplaceFunctionCall) {
|
||||
@ -280,11 +291,9 @@ TEST_F(ConstantFoldingTest, TestNoReplaceFunctionCall) {
|
||||
|
||||
// The above function call should not have been constant folded.
|
||||
bool was_mutated;
|
||||
Status status =
|
||||
DoConstantFoldingWithStatus(ConstantFoldingOptions{}, nullptr,
|
||||
Env::Default(), nullptr, &g, &was_mutated);
|
||||
TF_EXPECT_OK(ConstantFold(ConstantFoldingOptions{}, nullptr, Env::Default(),
|
||||
nullptr, &g, &was_mutated));
|
||||
EXPECT_FALSE(was_mutated);
|
||||
EXPECT_TRUE(status.ok());
|
||||
}
|
||||
|
||||
REGISTER_OP("ConstantFoldingTestOp").Input("a: int64").Output("b: int64");
|
||||
@ -311,11 +320,9 @@ TEST_F(ConstantFoldingTest, TestNoReplaceNonCPUOp) {
|
||||
|
||||
// The non-CPU op should not have been constant folded.
|
||||
bool was_mutated;
|
||||
Status status =
|
||||
DoConstantFoldingWithStatus(ConstantFoldingOptions{}, nullptr,
|
||||
Env::Default(), nullptr, &g, &was_mutated);
|
||||
TF_EXPECT_OK(ConstantFold(ConstantFoldingOptions{}, nullptr, Env::Default(),
|
||||
nullptr, &g, &was_mutated));
|
||||
EXPECT_FALSE(was_mutated);
|
||||
TF_EXPECT_OK(status);
|
||||
}
|
||||
|
||||
namespace {
|
||||
@ -392,15 +399,13 @@ TEST_F(ConstantFoldingTest, TestImmutableConst) {
|
||||
TF_ASSERT_OK(root.ToGraph(&g));
|
||||
TestTFEnvironment test_env;
|
||||
bool was_mutated;
|
||||
Status status =
|
||||
DoConstantFoldingWithStatus(ConstantFoldingOptions{}, nullptr,
|
||||
Env::Default(), nullptr, &g, &was_mutated);
|
||||
Status status = ConstantFold(ConstantFoldingOptions{}, nullptr,
|
||||
Env::Default(), nullptr, &g, &was_mutated);
|
||||
EXPECT_FALSE(was_mutated);
|
||||
EXPECT_FALSE(status.ok());
|
||||
status = DoConstantFoldingWithStatus(ConstantFoldingOptions{}, nullptr,
|
||||
&test_env, nullptr, &g, &was_mutated);
|
||||
TF_EXPECT_OK(ConstantFold(ConstantFoldingOptions{}, nullptr, &test_env,
|
||||
nullptr, &g, &was_mutated));
|
||||
EXPECT_TRUE(was_mutated);
|
||||
TF_EXPECT_OK(status);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/constant_folding.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/graph/algorithm.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
#include "tensorflow/core/graph/optimizer_cse.h"
|
||||
|
||||
@ -56,7 +57,10 @@ void GraphOptimizer::Optimize(FunctionLibraryRuntime* runtime, Env* env,
|
||||
|
||||
if (opts_.do_constant_folding()) {
|
||||
ConstantFoldingOptions cf_opts;
|
||||
if (DoConstantFolding(cf_opts, runtime, env, device, g)) {
|
||||
bool was_mutated;
|
||||
ConstantFold(cf_opts, runtime, env, device, g, &was_mutated)
|
||||
.IgnoreError();
|
||||
if (was_mutated) {
|
||||
RemoveDeadNodes(g);
|
||||
DumpGraph("ConstFolding", g);
|
||||
changed = true;
|
||||
|
@ -24,15 +24,6 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
class ShapeRefiner;
|
||||
|
||||
// Options specific to constant folding optimizations.
|
||||
//
|
||||
// TODO(ashankar,vrv): This should move to where constant folding is done.
|
||||
struct ConstantFoldingOptions {
|
||||
// If "consider" is not a nullptr, then only constant fold a node "n" if
|
||||
// consider(n) returns true.
|
||||
std::function<bool(const Node*)> consider = nullptr;
|
||||
};
|
||||
|
||||
// Construct a Graph *g out of a GraphDef gdef. Returns non-OK on
|
||||
// error, in which case *g is left in an incomplete state.
|
||||
//
|
||||
|
@ -152,9 +152,9 @@ Status FoldConstants(const GraphDef& input_graph_def,
|
||||
&input_graph, context.input_names, context.output_names, {},
|
||||
device_attributes, false /* use_function_convention */, &metadata));
|
||||
bool was_mutated;
|
||||
TF_RETURN_IF_ERROR(DoConstantFoldingWithStatus(
|
||||
ConstantFoldingOptions(), nullptr, Env::Default(), nullptr, &input_graph,
|
||||
&was_mutated));
|
||||
TF_RETURN_IF_ERROR(ConstantFold(ConstantFoldingOptions(), nullptr,
|
||||
Env::Default(), nullptr, &input_graph,
|
||||
&was_mutated));
|
||||
GraphDef folded_graph_def;
|
||||
input_graph.ToGraphDef(&folded_graph_def);
|
||||
GraphDef send_recvs_replaced;
|
||||
|
Loading…
Reference in New Issue
Block a user