Add the option of including Shape, ShapeN, Size and Rank in the standard TensorFlow constant propagation pass, when the inputs to those Ops have sufficiently known static shape.

PiperOrigin-RevId: 163762750
This commit is contained in:
A. Unique TensorFlower 2017-07-31 16:06:05 -07:00 committed by TensorFlower Gardener
parent 8b1365bb40
commit 9e7875437f
10 changed files with 562 additions and 89 deletions

View File

@ -161,7 +161,7 @@ Status XlaCompiler::CompileFunction(
opts.set_do_constant_folding(true);
GraphOptimizer optimizer(opts);
optimizer.Optimize(flib_runtime_.get(), flib_runtime_->env(),
/*device=*/nullptr, &graph);
/*device=*/nullptr, &graph, /*shape_map=*/nullptr);
VLOG(1) << "====================================================";
TF_RETURN_IF_ERROR(

View File

@ -43,11 +43,182 @@ namespace tensorflow {
namespace {
bool IsConstantFoldable(const Node* n,
const std::function<bool(const Node*)>& consider) {
// Test to see if the Op is one that turns into a constant when its
// inputs' shapes are known.
bool IsShapeOp(const Node* n) {
const auto& ts = n->type_string();
return ts == "Shape" || ts == "ShapeN" || ts == "Rank" || ts == "Size";
}
// Reads the partially-known shape of each of n's inputs from shape_map, and
// stores it to input_shapes. Returns false if any input does not have a shape
// in shape_map.
bool ReadPartialShapesFromShapeMap(
const Node* n,
const std::unordered_map<const Node*, std::vector<PartialTensorShape>>*
shape_map,
std::vector<PartialTensorShape>* input_shapes) {
CHECK(shape_map != nullptr);
for (const Edge* in : n->in_edges()) {
// Don't need to check if incoming control edges have known shapes.
if (in->IsControlEdge()) continue;
if (shape_map->count(in->src()) == 0) {
// One of n's inputs doesn't have known shapes, so don't replace n.
return false;
}
const auto& known_shape = shape_map->at(in->src());
CHECK_GT(known_shape.size(), in->src_output());
input_shapes->push_back(known_shape[in->src_output()]);
}
return true;
}
// If all of n's inputs have fully-defined shapes, inserts those shapes as a
// vector of Tensors in the shape_replacement_map.
bool MaybeReplaceShapeOrShapeNOp(
const Node* n, const std::vector<PartialTensorShape>& input_shapes,
std::unordered_map<const Node*, std::vector<Tensor>>*
shape_replacement_map) {
std::vector<Tensor> defined_shape;
for (const auto& shape : input_shapes) {
if (!shape.IsFullyDefined()) {
return false;
}
const int rank = shape.dims();
DataType op_type = n->output_type(0);
Tensor t(op_type, TensorShape({rank}));
if (op_type == DT_INT64) {
auto vec = t.vec<int64>();
for (int i = 0; i < rank; ++i) {
vec(i) = shape.dim_size(i);
}
} else {
CHECK(op_type == DT_INT32);
auto vec = t.vec<int32>();
for (int i = 0; i < rank; ++i) {
if (shape.dim_size(i) > INT_MAX) {
VLOG(1) << "Node " << n->name() << " has input shape dimension " << i
<< " of " << shape.dim_size(i) << " but type INT32 "
<< " so not replacing as constant: this will trigger a "
"runtime error later.";
return false;
}
vec(i) = static_cast<int32>(shape.dim_size(i));
}
}
defined_shape.push_back(t);
}
// All the inputs had known shapes so we can replace the node by constants
// later in the rewrite.
shape_replacement_map->insert({n, defined_shape});
return true;
}
// If n's input has defined rank, inserts that rank as a Tensor in the
// shape_replacement_map.
bool MaybeReplaceRankOp(const Node* n,
const std::vector<PartialTensorShape>& input_shapes,
std::unordered_map<const Node*, std::vector<Tensor>>*
shape_replacement_map) {
CHECK_EQ(input_shapes.size(), 1);
if (input_shapes[0].unknown_rank()) {
return false;
}
Tensor t(DT_INT32, TensorShape({}));
t.scalar<int32>()() = input_shapes[0].dims();
shape_replacement_map->insert({n, {t}});
return true;
}
// If n's input has defined size, inserts that size as a Tensor in the
// shape_replacement_map.
bool MaybeReplaceSizeOp(const Node* n,
const std::vector<PartialTensorShape>& input_shapes,
std::unordered_map<const Node*, std::vector<Tensor>>*
shape_replacement_map) {
CHECK_EQ(input_shapes.size(), 1);
if (!input_shapes[0].IsFullyDefined()) {
return false;
}
DataType op_type = n->output_type(0);
Tensor t(op_type, TensorShape({}));
int64 size = input_shapes[0].num_elements();
if (op_type == DT_INT64) {
t.scalar<int64>()() = size;
} else {
CHECK(op_type == DT_INT32);
if (size > INT_MAX) {
VLOG(1) << "Node " << n->name() << " has input shape size " << size
<< " but type INT32 "
<< " so not replacing as constant: this will trigger a runtime "
"error later.";
return false;
}
t.scalar<int32>()() = static_cast<int32>(size);
}
shape_replacement_map->insert({n, {t}});
return true;
}
// If n is a shape Op (Shape, ShapeN, Rank, or Size) and its inputs have their
// shapes specified in shape_map, then adds to shape_replacement_map a mapping
// from n to a vector of Tensors, where Tensor k is the (statically known) value
// on n's kth output edge. shape_replacement_map has an entry for n iff
// MaybeReplaceShapeOp returns true, so it's valid to use
// shape_replacement_map->count(n) as a test to see if n is a shape op that can
// be replaced.
bool MaybeReplaceShapeOp(
const Node* n,
const std::unordered_map<const Node*, std::vector<PartialTensorShape>>*
shape_map,
std::unordered_map<const Node*, std::vector<Tensor>>*
shape_replacement_map) {
if (shape_map == nullptr || !IsShapeOp(n)) {
return false;
}
// input_shapes will contain the shapes of each of n's inputs.
std::vector<PartialTensorShape> input_shapes;
if (!ReadPartialShapesFromShapeMap(n, shape_map, &input_shapes)) {
return false;
}
const auto& ts = n->type_string();
if (ts == "Shape" || ts == "ShapeN") {
if (!MaybeReplaceShapeOrShapeNOp(n, input_shapes, shape_replacement_map)) {
return false;
}
} else if (ts == "Rank") {
if (!MaybeReplaceRankOp(n, input_shapes, shape_replacement_map)) {
return false;
}
} else {
CHECK_EQ(ts, "Size");
if (!MaybeReplaceSizeOp(n, input_shapes, shape_replacement_map)) {
return false;
}
}
return true;
}
// Returns true if n can be evaluated as constant. shape_map maps from
// nodes to the partially-known shapes of their outputs. consider if
// non-null returns a bool indicating whether a given (non-Const,
// non-Shape) node is eligible to be
// constant-propagated. shape_replacement_map is filled in with a
// vector of constant output tensors for constant-foldable shape nodes
// (Shape, ShapeN, Size, or Rank).
bool IsConstantFoldable(
const Node* n,
const std::unordered_map<const Node*, std::vector<PartialTensorShape>>*
shape_map,
const std::function<bool(const Node*)>& consider,
std::unordered_map<const Node*, std::vector<Tensor>>*
shape_replacement_map) {
if (n->IsConstant()) {
return true;
}
if (MaybeReplaceShapeOp(n, shape_map, shape_replacement_map)) {
return true;
}
if (n->op_def().is_stateful()) {
return false;
}
@ -82,56 +253,81 @@ bool IsConstantFoldable(const Node* n,
return true;
}
// If n is eligible for constant-folding, adds it to nodes, and places its
// control dependencies and those transitively of its constant-foldable inputs
// into constant_control_deps. If n is a constant-foldable shape node (Shape,
// ShapeN, Rank, or Size), also puts its outputs into shape_replacement_map.
void ConsiderConstantFoldableNode(
Node* n, const ConstantFoldingOptions& opts, std::vector<Node*>* nodes,
std::unordered_map<const Node*, gtl::FlatSet<Node*>>* constant_control_deps,
std::unordered_map<const Node*, std::vector<Tensor>>* shape_replacement_map,
bool* internal_node_inserted) {
if (IsConstantFoldable(n, opts.shape_map, opts.consider,
shape_replacement_map)) {
// A node is constant provided all of its non-control incoming Tensors come
// from constant nodes, or it's a shape Op with statically known inputs in
// which case it is placed in shape_replacement_map.
//
// We allow control dependencies from non-constant nodes to constant nodes,
// but to preserve the graph structure we must transfer the control
// dependency onto any constant replacement.
bool all_parents_constant = true;
for (const Edge* in : n->in_edges()) {
// Allows non-constant -> constant control edges.
if (!in->IsControlEdge() &&
constant_control_deps->count(in->src()) == 0) {
all_parents_constant = false;
break;
}
}
if (all_parents_constant || shape_replacement_map->count(n) != 0) {
gtl::FlatSet<Node*>& control_deps = (*constant_control_deps)[n];
for (const Edge* e : n->in_edges()) {
if (constant_control_deps->count(e->src()) == 0) {
// This branch is taken if the incoming edge is a control dependency,
// in which case we want to add it to the dependencies being
// accumulated for this node, or the incoming edge is not
// constant. The latter may happen when n is a shape node and the
// source has known shape. In that case add a control dependency from
// the source node, since there was previously a data dependency and
// we want to preserve sequencing constraints.
if (!e->src()->IsSource()) {
control_deps.insert(e->src());
}
} else {
// If the parent has been accumulating control dependencies, add all
// of its transitive control deps.
const gtl::FlatSet<Node*>& parent_deps =
(*constant_control_deps)[e->src()];
control_deps.insert(parent_deps.begin(), parent_deps.end());
}
}
nodes->push_back(n);
if (!n->IsConstant()) {
*internal_node_inserted = true;
}
}
}
}
// Returns the constant foldable nodes in `nodes` in topological order.
// Populates `constant_control_deps` with the non-constant control dependencies
// of each constant node.
void FindConstantFoldableNodes(
const Graph* graph, ConstantFoldingOptions opts, std::vector<Node*>* nodes,
std::unordered_map<const Node*, gtl::FlatSet<Node*>>*
constant_control_deps) {
const Graph* graph, const ConstantFoldingOptions& opts,
std::vector<Node*>* nodes,
std::unordered_map<const Node*, gtl::FlatSet<Node*>>* constant_control_deps,
std::unordered_map<const Node*, std::vector<Tensor>>*
shape_replacement_map) {
bool internal_node_inserted = false;
// Walk the nodes in data flow order
ReverseDFS(
*graph, nullptr,
[nodes, constant_control_deps, &internal_node_inserted, opts](Node* n) {
if (IsConstantFoldable(n, opts.consider)) {
// A node is constant provided all of its non-control
// incoming Tensors come from constant nodes.
//
// We allow control dependencies from non-constant nodes to constant
// nodes, but to preserve the graph structure we must transfer the
// control dependency onto any constant replacement.
bool all_parents_constant = true;
for (const Edge* in : n->in_edges()) {
// Allows non-constant -> constant control edges.
if (!in->IsControlEdge() &&
constant_control_deps->count(in->src()) == 0) {
all_parents_constant = false;
break;
}
}
if (all_parents_constant) {
gtl::FlatSet<Node*>& control_deps = (*constant_control_deps)[n];
for (const Edge* e : n->in_edges()) {
if (constant_control_deps->count(e->src()) == 0) {
if (!e->src()->IsSource()) {
control_deps.insert(e->src());
}
} else {
// If the parent is constant, add all of its transitive control
// deps.
const gtl::FlatSet<Node*>& parent_deps =
(*constant_control_deps)[e->src()];
control_deps.insert(parent_deps.begin(), parent_deps.end());
}
}
nodes->push_back(n);
if (!n->IsConstant()) {
internal_node_inserted = true;
}
}
}
});
// Walk the nodes in data flow order.
ReverseDFS(*graph, nullptr,
[nodes, constant_control_deps, shape_replacement_map,
&internal_node_inserted, &opts](Node* n) {
ConsiderConstantFoldableNode(
n, opts, nodes, constant_control_deps, shape_replacement_map,
&internal_node_inserted);
});
// If we have inserted just leaf level nodes, then there is nothing to fold.
if (!internal_node_inserted) {
nodes->clear();
@ -141,31 +337,93 @@ void FindConstantFoldableNodes(
typedef std::pair<Node*, int> NodeAndOutput;
int64 UniqueConstantId() {
static std::atomic_int_fast64_t id;
return id.fetch_add(1);
}
// Adds n to constant_graph which is being built up for subsequent evaluation of
// constant propagation. node_map is the mapping of nodes in the original graph
// to nodes in the constant graph. The value of an entry in node_map is a vector
// of nodes because a ShapeN node in the original graph is replaced by a vector
// of Constant nodes in the constant graph.
void AddNodeToConstantGraph(
Node* n, std::unordered_map<Node*, std::vector<Node*>>* node_map,
Graph* constant_graph) {
std::vector<Node*>& added = (*node_map)[n];
added.push_back(constant_graph->CopyNode(n));
for (const Edge* in_edge : n->in_edges()) {
// Don't copy control edges to the constant graph.
if (!in_edge->IsControlEdge()) {
Node* in = in_edge->src();
auto it = node_map->find(in);
CHECK(it != node_map->end())
<< n->DebugString() << " <-" << in->DebugString();
if (it->second.size() == 1) {
constant_graph->AddEdge(it->second[0], in_edge->src_output(), added[0],
in_edge->dst_input());
} else {
// The original source node had multiple outputs and was replaced by a
// vector of constants, so the edge comes from the 0th output of the kth
// added constant, rather than the kth output of the added node as in
// the standard case above.
constant_graph->AddEdge(it->second[in_edge->src_output()], 0, added[0],
in_edge->dst_input());
}
}
}
}
// Replaces constant-foldable shape node n by a vector of constants in
// constant_graph, which is being built up for subsequent evaluation of constant
// propagation. node_map is the mapping of nodes in the original graph to nodes
// in the constant graph. The value of an entry in node_map is a vector of nodes
// because a ShapeN node in the original graph is replaced by a vector of
// Constant nodes in the constant graph.
void AddShapeNodeToConstantGraph(
Node* n,
const std::unordered_map<const Node*, std::vector<Tensor>>&
shape_replacement_map,
std::unordered_map<Node*, std::vector<Node*>>* node_map,
Graph* constant_graph) {
std::vector<Node*>& added = (*node_map)[n];
const string& node_name = n->name();
for (const Tensor& t : shape_replacement_map.at(n)) {
auto builder =
NodeDefBuilder(strings::StrCat(constant_graph->NewName(node_name),
"__cf__", UniqueConstantId()),
"Const")
.Attr("dtype", t.dtype())
.Attr("value", t);
NodeDef def;
CHECK(builder.Finalize(&def).ok());
Node* constant_node;
CHECK(NodeBuilder(builder).Finalize(constant_graph, &constant_node).ok());
added.push_back(constant_node);
}
// Don't copy incoming edges to shape nodes that are being replaced.
}
// Given the constant foldable nodes in 'nodes', returns a new graph 'g'. 'g'
// will contain copies of the nodes in 'nodes'. In addition, if there is an edge
// going from a node 'n' in 'nodes' to another node in 'orig_graph' but not in
// 'nodes', then 'tensors_to_fetch' will contain the mapping from the
// corresponding copy of 'n' and the edge number in 'g' to 'n'.
Graph* GetConstantGraph(const Graph* orig_graph,
const std::vector<Node*>& nodes,
std::map<NodeAndOutput, Node*>* tensors_to_fetch) {
Graph* GetConstantGraph(
const Graph* orig_graph, const std::vector<Node*>& nodes,
const std::unordered_map<const Node*, std::vector<Tensor>>&
shape_replacement_map,
std::map<NodeAndOutput, Node*>* tensors_to_fetch) {
Graph* constant_graph = new Graph(orig_graph->op_registry());
std::unordered_map<Node*, Node*> node_map;
node_map[orig_graph->source_node()] = constant_graph->source_node();
node_map[orig_graph->sink_node()] = constant_graph->sink_node();
std::unordered_map<Node*, std::vector<Node*>> node_map;
node_map[orig_graph->source_node()] = {constant_graph->source_node()};
node_map[orig_graph->sink_node()] = {constant_graph->sink_node()};
for (Node* n : nodes) {
Node* added = constant_graph->CopyNode(n);
node_map[n] = added;
for (const Edge* in_edge : n->in_edges()) {
// Don't copy control edges to the constant graph.
if (!in_edge->IsControlEdge()) {
Node* in = in_edge->src();
auto it = node_map.find(in);
CHECK(it != node_map.end())
<< n->DebugString() << " <-" << in->DebugString();
constant_graph->AddEdge(it->second, in_edge->src_output(), added,
in_edge->dst_input());
}
if (shape_replacement_map.count(n) == 0) {
AddNodeToConstantGraph(n, &node_map, constant_graph);
} else {
AddShapeNodeToConstantGraph(n, shape_replacement_map, &node_map,
constant_graph);
}
}
@ -173,8 +431,19 @@ Graph* GetConstantGraph(const Graph* orig_graph,
for (const Edge* out_edge : added_nodes.first->out_edges()) {
if (node_map.count(out_edge->dst()) == 0) {
if (out_edge->IsControlEdge()) continue;
tensors_to_fetch->insert(
{{added_nodes.second, out_edge->src_output()}, added_nodes.first});
if (added_nodes.second.size() == 1) {
tensors_to_fetch->insert(
{{added_nodes.second[0], out_edge->src_output()},
added_nodes.first});
} else {
// The node had multiple outputs and was replaced by a
// vector of constants, so the NodeAndOutput is the 0th
// output of the kth added constant, rather than the kth
// output of the added node as in the standard case above.
tensors_to_fetch->insert(
{{added_nodes.second[out_edge->src_output()], 0},
added_nodes.first});
}
}
}
}
@ -182,11 +451,6 @@ Graph* GetConstantGraph(const Graph* orig_graph,
return constant_graph;
}
int64 UniqueConstantId() {
static std::atomic_int_fast64_t id;
return id.fetch_add(1);
}
// Replaces the identified Tensor in 'graph' by a 'Const' node with
// the value supplied in 'constant'. 'partition_device', if non-null
// is the device where the graph executes. Returns true if the
@ -291,8 +555,9 @@ Status ConstantFold(const ConstantFoldingOptions& opts,
std::vector<Node*> constant_foldable_nodes;
std::unordered_map<const Node*, gtl::FlatSet<Node*>> constant_control_deps;
std::unordered_map<const Node*, std::vector<Tensor>> shape_replacement_map;
FindConstantFoldableNodes(graph, opts, &constant_foldable_nodes,
&constant_control_deps);
&constant_control_deps, &shape_replacement_map);
if (constant_foldable_nodes.empty()) {
VLOG(1) << "No constant foldable nodes found";
*was_mutated = false;
@ -302,7 +567,8 @@ Status ConstantFold(const ConstantFoldingOptions& opts,
std::map<NodeAndOutput, Node*> tensors_to_fetch;
std::unique_ptr<Graph> constant_graph(
GetConstantGraph(graph, constant_foldable_nodes, &tensors_to_fetch));
GetConstantGraph(graph, constant_foldable_nodes, shape_replacement_map,
&tensors_to_fetch));
DumpGraph("Constant graph", constant_graph.get());
if (tensors_to_fetch.empty()) {
@ -337,7 +603,6 @@ Status ConstantFold(const ConstantFoldingOptions& opts,
if (!s.ok()) {
VLOG(1) << "Could not fetch constants: " << s;
*was_mutated = false;
// This is not an error, so return the status as OK.
return s;
}

View File

@ -29,6 +29,11 @@ struct ConstantFoldingOptions {
// If "consider" is not a nullptr, then only constant fold a node "n" if
// consider(n) returns true.
std::function<bool(const Node*)> consider = nullptr;
// If shape_map is not a nullptr, it is a map from node n to a
// vector of the (potentially partially-known) shapes of its
// outputs.
const std::unordered_map<const Node*, std::vector<PartialTensorShape>>*
shape_map; // not owned
};
// Perform constant folding optimization on "graph".

View File

@ -363,6 +363,197 @@ TEST_F(ConstantFoldingTest, ControlDependencies) {
}
}
TEST_F(ConstantFoldingTest, SimpleShapeKnown) {
Graph g(OpRegistry::Global());
{
Scope s = Scope::NewRootScope();
Output recv0 = ops::_Recv(s.WithOpName("recv0"), DT_FLOAT, "recv0",
"sender", 0, "receiver");
auto shape = ops::Shape(s.WithOpName("shape"), recv0);
Output recv1 = ops::_Recv(s.WithOpName("recv1"), DT_FLOAT, "recv1",
"sender", 0, "receiver");
auto shape_n = ops::ShapeN(s.WithOpName("shape_n"), {recv0, recv1});
auto rank = ops::Rank(s.WithOpName("rank"), recv0);
auto size = ops::Size(s.WithOpName("size"), recv1);
auto recv2 = ops::_Recv(s.WithOpName("recv2"), DT_FLOAT, "recv2", "sender",
0, "receiver");
auto c = ops::Const<int>(s.WithControlDependencies(recv2), 3);
auto add0 = ops::Add(s.WithControlDependencies(c), rank, size);
auto add1 = ops::Add(s, shape, shape_n[0]);
auto add2 = ops::Add(s, shape_n[1], shape_n[1]);
auto send0 = ops::_Send(s.WithOpName("send0"), add0, "send0", "sender", 0,
"receiver");
auto send1 = ops::_Send(s.WithOpName("send1"), add1, "send1", "sender", 0,
"receiver");
auto send2 = ops::_Send(s.WithOpName("send2"), add2, "send2", "sender", 0,
"receiver");
TF_ASSERT_OK(s.ToGraph(&g));
}
std::unordered_map<string, Node*> orig_index = NodeNameIndex(g);
Node* recv0 = orig_index.at("recv0");
Node* recv1 = orig_index.at("recv1");
PartialTensorShape ps0;
int r0_dims[] = {1, 2};
TF_EXPECT_OK(PartialTensorShape::MakePartialShape(r0_dims, 2, &ps0));
PartialTensorShape ps1;
int r1_dims[] = {2, 3, 4};
TF_EXPECT_OK(PartialTensorShape::MakePartialShape<int>(r1_dims, 3, &ps1));
std::unordered_map<const Node*, std::vector<PartialTensorShape>> map;
map[recv0].push_back(ps0);
map[recv1].push_back(ps1);
ConstantFoldingOptions opts;
opts.shape_map = &map;
bool was_mutated;
TF_EXPECT_OK(
ConstantFold(opts, nullptr, Env::Default(), nullptr, &g, &was_mutated));
EXPECT_TRUE(was_mutated);
std::unordered_map<string, Node*> index = NodeNameIndex(g);
Node* recv2 = index.at("recv2");
Node* send0 = index.at("send0");
Node* send1 = index.at("send1");
Node* send2 = index.at("send2");
ASSERT_EQ(1, send0->num_inputs());
Node* cf0 = *(send0->in_nodes().begin());
ExpectNodeEqual<int>(cf0, {26}, {});
ASSERT_EQ(1, send1->num_inputs());
Node* cf1 = *(send1->in_nodes().begin());
ExpectNodeEqual<int>(cf1, {2, 4}, {2});
ASSERT_EQ(1, send2->num_inputs());
Node* cf2 = *(send2->in_nodes().begin());
ExpectNodeEqual<int>(cf2, {4, 6, 8}, {3});
ASSERT_EQ(3, cf0->in_edges().size());
for (const Edge* e : cf0->in_edges()) {
EXPECT_TRUE(e->IsControlEdge());
EXPECT_TRUE(e->src() == recv0 || e->src() == recv1 || e->src() == recv2)
<< e->src()->name();
}
ASSERT_EQ(2, cf1->in_edges().size());
for (const Edge* e : cf1->in_edges()) {
EXPECT_TRUE(e->IsControlEdge());
EXPECT_TRUE(e->src() == recv0 || e->src() == recv1) << e->src()->name();
}
ASSERT_EQ(2, cf2->in_edges().size());
for (const Edge* e : cf2->in_edges()) {
EXPECT_TRUE(e->IsControlEdge());
EXPECT_TRUE(e->src() == recv0 || e->src() == recv1) << e->src()->name();
}
}
TEST_F(ConstantFoldingTest, PartialShape) {
Graph g(OpRegistry::Global());
{
Scope s = Scope::NewRootScope();
Output recv0 = ops::_Recv(s.WithOpName("recv0"), DT_FLOAT, "recv0",
"sender", 0, "receiver");
Output recv1 = ops::_Recv(s.WithOpName("recv1"), DT_FLOAT, "recv1",
"sender", 0, "receiver");
auto shape = ops::Shape(s.WithOpName("shape"), recv0);
auto rank0 = ops::Rank(s.WithOpName("rank0"), recv0);
auto rank1 = ops::Rank(s.WithOpName("rank1"), recv1);
auto size = ops::Size(s.WithOpName("size"), recv0);
auto send0 = ops::_Send(s.WithOpName("send0"), rank0, "send0", "sender", 0,
"receiver");
auto send1 = ops::_Send(s.WithOpName("send1"), shape, "send1", "sender", 0,
"receiver");
auto send2 = ops::_Send(s.WithOpName("send2"), size, "send2", "sender", 0,
"receiver");
auto send3 = ops::_Send(s.WithOpName("send3"), rank1, "send3", "sender", 0,
"receiver");
TF_ASSERT_OK(s.ToGraph(&g));
}
std::unordered_map<string, Node*> orig_index = NodeNameIndex(g);
Node* recv0 = orig_index.at("recv0");
Node* recv1 = orig_index.at("recv1");
PartialTensorShape ps0;
int r0_dims[] = {-1, -1};
TF_EXPECT_OK(PartialTensorShape::MakePartialShape(r0_dims, 2, &ps0));
PartialTensorShape ps1;
std::unordered_map<const Node*, std::vector<PartialTensorShape>> map;
map[recv0].push_back(ps0);
map[recv1].push_back(ps1);
ConstantFoldingOptions opts;
opts.shape_map = &map;
bool was_mutated;
TF_EXPECT_OK(
ConstantFold(opts, nullptr, Env::Default(), nullptr, &g, &was_mutated));
EXPECT_TRUE(was_mutated);
std::unordered_map<string, Node*> index = NodeNameIndex(g);
Node* shape = index.at("shape");
Node* size = index.at("size");
Node* rank1 = index.at("rank1");
Node* send0 = index.at("send0");
Node* send1 = index.at("send1");
Node* send2 = index.at("send2");
Node* send3 = index.at("send3");
ASSERT_EQ(1, send0->num_inputs());
Node* cf0 = *(send0->in_nodes().begin());
ExpectNodeEqual<int>(cf0, {2}, {});
ASSERT_EQ(1, send1->num_inputs());
Node* ncf1 = *(send1->in_nodes().begin());
EXPECT_EQ(ncf1, shape);
ASSERT_EQ(1, send2->num_inputs());
Node* ncf2 = *(send2->in_nodes().begin());
EXPECT_EQ(ncf2, size);
ASSERT_EQ(1, send3->num_inputs());
Node* ncf3 = *(send3->in_nodes().begin());
EXPECT_EQ(ncf3, rank1);
}
TEST_F(ConstantFoldingTest, ConstShapeKnown) {
Graph g(OpRegistry::Global());
{
Scope s = Scope::NewRootScope();
auto recv0 = ops::_Recv(s.WithOpName("recv0"), DT_FLOAT, "recv0", "sender",
0, "receiver");
auto c0 =
ops::Const<int>(s.WithOpName("c0").WithControlDependencies(recv0), 1);
auto rank = ops::Rank(s.WithOpName("rank"), c0);
auto add0 = ops::Add(s, rank, rank);
auto send0 = ops::_Send(s.WithOpName("send0"), add0, "send0", "sender", 0,
"receiver");
TF_ASSERT_OK(s.ToGraph(&g));
}
std::unordered_map<string, Node*> orig_index = NodeNameIndex(g);
Node* c0 = orig_index.at("c0");
PartialTensorShape ps0;
int c0_dims[] = {};
TF_EXPECT_OK(PartialTensorShape::MakePartialShape(c0_dims, 0, &ps0));
std::unordered_map<const Node*, std::vector<PartialTensorShape>> map;
map[c0].push_back(ps0);
ConstantFoldingOptions opts;
opts.shape_map = &map;
bool was_mutated;
TF_EXPECT_OK(
ConstantFold(opts, nullptr, Env::Default(), nullptr, &g, &was_mutated));
EXPECT_TRUE(was_mutated);
std::unordered_map<string, Node*> index = NodeNameIndex(g);
Node* recv0 = index.at("recv0");
Node* send0 = index.at("send0");
ASSERT_EQ(1, send0->num_inputs());
Node* cf0 = *(send0->in_nodes().begin());
ExpectNodeEqual<int>(cf0, {0}, {});
ASSERT_EQ(1, cf0->in_edges().size());
for (const Edge* e : cf0->in_edges()) {
EXPECT_TRUE(e->IsControlEdge());
EXPECT_TRUE(e->src() == recv0) << e->src()->name();
}
}
namespace {
const char kTestMemRegionName[] = "test://test";

View File

@ -1194,7 +1194,8 @@ Status DirectSession::GetOrCreateExecutors(
};
params.node_outputs_cb = node_outputs_callback_;
optimizer.Optimize(lib, options_.env, device, &iter->second);
optimizer.Optimize(lib, options_.env, device, &iter->second,
/*shape_map=*/nullptr);
// EXPERIMENTAL: tfdbg inserts debug nodes in the graph.
if (!options.debug_options.debug_tensor_watch_opts().empty()) {

View File

@ -461,7 +461,7 @@ void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr<Graph>* g) {
opts.set_do_function_inlining(true);
opts.set_do_constant_folding(true);
GraphOptimizer optimizer(opts);
optimizer.Optimize(lib, lib->env(), lib->device(), g);
optimizer.Optimize(lib, lib->env(), lib->device(), g, /*shape_map=*/nullptr);
}
Status FunctionLibraryRuntimeImpl::CreateItem(Handle handle, Item** item) {
@ -470,7 +470,7 @@ Status FunctionLibraryRuntimeImpl::CreateItem(Handle handle, Item** item) {
std::unique_ptr<Graph> g(new Graph(lib_def_));
CopyGraph(*fbody->graph, g.get());
optimizer_.Optimize(this, env(), device(), &g);
optimizer_.Optimize(this, env(), device(), &g, /*shape_map=*/nullptr);
TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device()->device_type()),
device()->name(), g.get()));

View File

@ -33,8 +33,11 @@ GraphOptimizer::GraphOptimizer(const OptimizerOptions& opts) : opts_(opts) {
GraphOptimizer::~GraphOptimizer() {}
void GraphOptimizer::Optimize(FunctionLibraryRuntime* runtime, Env* env,
Device* device, std::unique_ptr<Graph>* graph) {
void GraphOptimizer::Optimize(
FunctionLibraryRuntime* runtime, Env* env, Device* device,
std::unique_ptr<Graph>* graph,
const std::unordered_map<const Node*, std::vector<PartialTensorShape>>*
shape_map) {
Graph* g = graph->get();
DumpGraph("Initial", g);
@ -57,6 +60,7 @@ void GraphOptimizer::Optimize(FunctionLibraryRuntime* runtime, Env* env,
if (opts_.do_constant_folding()) {
ConstantFoldingOptions cf_opts;
cf_opts.shape_map = shape_map;
bool was_mutated;
ConstantFold(cf_opts, runtime, env, device, g, &was_mutated)
.IgnoreError();

View File

@ -30,12 +30,17 @@ class GraphOptimizer {
~GraphOptimizer();
// Applies optimization passes specified in 'opts' to 'graph'.
// Maybe replace *graph with a new graph object.
// 'device' is device on which the 'graph' will execute. It's passed to the
// optimizers so that they can respect constraints if any, that should be
// respected.
void Optimize(FunctionLibraryRuntime* runtime, Env* env, Device* device,
std::unique_ptr<Graph>* graph);
// Maybe replace *graph with a new graph object. 'device' is device
// on which the 'graph' will execute. It's passed to the optimizers
// so that they can respect constraints if any, that should be
// respected. If shape_map is not null it maps from nodes in graph
// to partially-known shapes of their outputs, and may be used,
// e.g., in the constant folding pass.
void Optimize(
FunctionLibraryRuntime* runtime, Env* env, Device* device,
std::unique_ptr<Graph>* graph,
const std::unordered_map<const Node*, std::vector<PartialTensorShape>>*
shape_map);
private:
OptimizerOptions opts_;

View File

@ -244,7 +244,8 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
}
};
optimizer.Optimize(lib, worker_env_->env, params.device, &subgraph);
optimizer.Optimize(lib, worker_env_->env, params.device, &subgraph,
/*shape_map=*/nullptr);
// EXPERIMENTAL: tfdbg inserts debug nodes (i.e., probes) to the graph.
if (!debug_options.debug_tensor_watch_opts().empty()) {

View File

@ -127,7 +127,8 @@ Status OptimizeGraph(const GraphDef& graph_def, GraphDef* output_graph_def,
// Optimize the graph.
GraphOptimizer optimizer(*optimizer_opts);
optimizer.Optimize(flib.get(), env, devices[0], &graphptr);
optimizer.Optimize(flib.get(), env, devices[0], &graphptr,
/*shape_map=*/nullptr);
graphptr->ToGraphDef(output_graph_def);
return Status::OK();