Add WhileContext class and plumbing for creating WhileContexts.

This change introduces WhileContext, which stores information about a
while loop and will be used in future changes to generate while loop
gradient graphs. Exit nodes in a while loop now have a pointer to
their associated WhileContext. This will be used to retrieve the
context for a given loop.
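As an illustration only, a minimal sketch of how a later gradient pass might recover the context from an Exit node; GetWhileContextForExit is a hypothetical helper, not part of this change:

#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/while_context.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

// Hypothetical helper: given an Exit node of a user-defined while loop,
// return the WhileContext that was recorded on it.
Status GetWhileContextForExit(Node* exit_node, WhileContext** result) {
  if (!exit_node->IsExit()) {
    return errors::InvalidArgument("Node ", exit_node->name(),
                                   " is not an Exit node");
  }
  WhileContext* ctx = exit_node->while_ctx();
  if (ctx == nullptr) {
    return errors::Internal("Exit node ", exit_node->name(),
                            " has no WhileContext");
  }
  *result = ctx;
  return Status::OK();
}

}  // namespace tensorflow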

This change adds an optional parameter to BuildWhileLoop() that controls
whether a WhileContext is created for the while loop (currently this is
always enabled, but gradient construction will generate while loops
without associated contexts). This change also adds an as-yet-unused
option to BuildWhileLoop() to return the predicate output.
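A rough usage sketch of the extended signature, assuming the usual cc client includes (cc/framework/scope.h, cc/ops/standard_ops.h, cc/ops/while_loop.h); the frame name and initial value here are placeholders:

// Builds "while (i < 10) i += 1", creating a WhileContext (the default) and
// also capturing the LoopCond output.
Scope root = Scope::NewRootScope();
std::vector<Output> inputs = {ops::Const(root, 1)};
std::vector<Output> outputs;
Output cond_output;
Status s = ops::BuildWhileLoop(
    root, inputs,
    [](const Scope& sc, const std::vector<Output>& in, Output* out) {
      *out = ops::Less(sc, in[0], 10);
      return sc.status();
    },
    [](const Scope& sc, const std::vector<Output>& in,
       std::vector<Output>* out) {
      out->push_back(ops::Add(sc, in[0], 1));
      return sc.status();
    },
    "example_frame", &outputs, /*create_while_ctx=*/true, &cond_output);
// outputs[0].node() is now an Exit node carrying the new WhileContext.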

PiperOrigin-RevId: 168562303
Skye Wanderman-Milne 2017-09-13 10:49:45 -07:00 committed by TensorFlower Gardener
parent a4f6e7c1af
commit 92362d0f05
11 changed files with 227 additions and 6 deletions

View File

@ -258,6 +258,7 @@ tf_cc_test(
":client_session",
":testutil",
":while_loop",
"//tensorflow/core:core_cpu",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",

View File

@ -172,7 +172,8 @@ Status CreateBody(const Scope& scope, const BodyGraphBuilderFn& body,
Status BuildWhileLoop(const Scope& scope, const std::vector<Output>& inputs,
const CondGraphBuilderFn& cond,
const BodyGraphBuilderFn& body, const string& frame_name,
OutputList* outputs) {
OutputList* outputs, bool create_while_ctx,
Output* cond_output) {
DCHECK(!inputs.empty());
DCHECK(outputs != nullptr);
DCHECK(outputs->empty());
@ -194,6 +195,7 @@ Status BuildWhileLoop(const Scope& scope, const std::vector<Output>& inputs,
Output cond_out;
TF_RETURN_IF_ERROR(CreateCond(scope, cond, merge_outputs, &cond_out));
if (cond_output != nullptr) *cond_output = cond_out;
std::vector<Output> switch_trues(num_loop_vars);
std::vector<Output> switch_falses(num_loop_vars);
@ -226,7 +228,22 @@ Status BuildWhileLoop(const Scope& scope, const std::vector<Output>& inputs,
for (int i = 0; i < num_loop_vars; ++i) {
(*outputs)[i] = internal::Exit(scope, switch_falses[i]);
}
return scope.status();
TF_RETURN_IF_ERROR(scope.status());
if (create_while_ctx) {
WhileContext* while_ctx;
TF_RETURN_IF_ERROR(scope.graph()->AddWhileContext(
frame_name, ToNodes(enter_outputs), ToNodes(*outputs),
ToOutputTensor(cond_out), ToOutputTensors(switch_trues),
ToOutputTensors(body_outputs), &while_ctx));
// Set while_ctx for all exit nodes. We currently don't require knowing the
// while_ctx for any other nodes.
for (int i = 0; i < num_loop_vars; ++i) {
(*outputs)[i].node()->set_while_ctx(while_ctx);
}
}
return Status::OK();
}
} // namespace ops

View File

@ -48,6 +48,10 @@ typedef std::function<Status(const Scope&, const std::vector<Output>& inputs,
// unique name. This will be used as a prefix for created operations.
// * outputs: output param that returns final loop variable outputs in non-error
// case. Must be non-null and empty.
// * create_while_ctx: if true, a WhileContext is created and populated for this
// loop. See core/graph/while_context.h for more details.
// * cond_output: if non-null, the output of the predicate is returned. This
// will always be a LoopCond node.
//
// Returns an error if the while loop could not be fully constructed.
//
@ -56,7 +60,8 @@ typedef std::function<Status(const Scope&, const std::vector<Output>& inputs,
Status BuildWhileLoop(const Scope& scope, const std::vector<Output>& inputs,
const CondGraphBuilderFn& cond,
const BodyGraphBuilderFn& body, const string& frame_name,
OutputList* outputs);
OutputList* outputs, bool create_while_ctx = true,
Output* cond_output = nullptr);
} // namespace ops
} // namespace tensorflow
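For contrast with the default call above, a hedged sketch of how the planned gradient construction might use the non-default arguments, skipping WhileContext creation while still capturing the predicate; grad_scope, grad_inputs, grad_cond, and grad_body are placeholders:

std::vector<Output> grad_outputs;
Output grad_cond_output;
// No WhileContext is recorded for the gradient loop; the LoopCond output is
// still returned through grad_cond_output.
Status s = ops::BuildWhileLoop(grad_scope, grad_inputs, grad_cond, grad_body,
                               "gradients/example_frame", &grad_outputs,
                               /*create_while_ctx=*/false, &grad_cond_output);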

View File

@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/graph/while_context.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
@ -38,8 +39,8 @@ class WhileLoopTest : public ::testing::Test {
const ops::BodyGraphBuilderFn& body,
error::Code error_code = error::OK,
const string& error_msg = "") {
Status s = ops::BuildWhileLoop(scope_, inputs_, cond, body, "test_loop",
&outputs_);
Status s =
ops::BuildWhileLoop(scope_, inputs_, cond, body, kFrameName, &outputs_);
EXPECT_EQ(s.code(), error_code);
EXPECT_EQ(s.error_message(), error_msg);
}
@ -69,8 +70,12 @@ class WhileLoopTest : public ::testing::Test {
Scope scope_;
std::vector<Output> inputs_;
std::vector<Output> outputs_;
static const char* const kFrameName;
};
const char* const WhileLoopTest::kFrameName = "test_loop";
Status LessThanTenCond(const Scope& s, const std::vector<Output>& inputs,
Output* output) {
*output = ops::Less(s, inputs[0], 10);
@ -87,6 +92,23 @@ TEST_F(WhileLoopTest, Basic) {
// Create loop: while (i < 10) i += 1
Init(1);
CreateLoop(LessThanTenCond, AddOneBody);
// Verify some output invariants
WhileContext* while_ctx;
for (int i = 0; i < outputs_.size(); ++i) {
Node* node = outputs_[i].node();
ASSERT_TRUE(node->IsExit()) << "Output node " << i << ":\n"
<< node->DebugString();
ASSERT_TRUE(node->while_ctx() != nullptr) << i;
if (i == 0) {
while_ctx = node->while_ctx();
EXPECT_EQ(while_ctx->frame_name(), kFrameName);
} else {
EXPECT_EQ(node->while_ctx(), while_ctx) << i;
}
}
// Run the loop and test we get the expected results
Run<int>({1}, {10});
Run<int>({11}, {11});
}

View File

@ -50,6 +50,8 @@ file(GLOB_RECURSE tf_core_cpu_exclude_srcs
"${tensorflow_source_dir}/tensorflow/core/graph/edgeset.cc"
"${tensorflow_source_dir}/tensorflow/core/graph/graph.h"
"${tensorflow_source_dir}/tensorflow/core/graph/graph.cc"
"${tensorflow_source_dir}/tensorflow/core/graph/while_context.h"
"${tensorflow_source_dir}/tensorflow/core/graph/while_context.cc"
"${tensorflow_source_dir}/tensorflow/core/grappler/clusters/single_machine.h"
"${tensorflow_source_dir}/tensorflow/core/grappler/clusters/single_machine.cc"
"${tensorflow_source_dir}/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"

View File

@ -241,6 +241,8 @@ file(GLOB_RECURSE tf_core_framework_srcs
"${tensorflow_source_dir}/tensorflow/core/graph/edgeset.cc"
"${tensorflow_source_dir}/tensorflow/core/graph/graph.h"
"${tensorflow_source_dir}/tensorflow/core/graph/graph.cc"
"${tensorflow_source_dir}/tensorflow/core/graph/while_context.h"
"${tensorflow_source_dir}/tensorflow/core/graph/while_context.cc"
"${tensorflow_source_dir}/tensorflow/core/util/*.h"
"${tensorflow_source_dir}/tensorflow/core/util/*.cc"
"${tensorflow_source_dir}/tensorflow/core/common_runtime/session.cc"

View File

@ -721,6 +721,7 @@ tf_cuda_library(
"graph/graph_def_builder.h",
"graph/node_builder.h",
"graph/validate.h",
"graph/while_context.h",
"public/session.h",
"public/session_options.h",
],
@ -1581,6 +1582,8 @@ tf_cuda_library(
"graph/edgeset.cc",
"graph/graph.h",
"graph/graph.cc",
"graph/while_context.h",
"graph/while_context.cc",
],
exclude = [
"**/*test*",
@ -1717,6 +1720,7 @@ CORE_CPU_BASE_HDRS = [
"graph/testlib.h",
"graph/types.h",
"graph/validate.h",
"graph/while_context.h",
]
tf_cuda_library(

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/while_context.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
@ -110,7 +111,8 @@ Node::Node()
cost_id_(-1),
class_(NC_UNINITIALIZED),
props_(nullptr),
assigned_device_name_index_(0) {}
assigned_device_name_index_(0),
while_ctx_(nullptr) {}
void Node::Initialize(int id, int cost_id,
std::shared_ptr<NodeProperties> props) {
@ -582,6 +584,27 @@ int Graph::InternDeviceName(const string& device_name) {
return index;
}
Status Graph::AddWhileContext(StringPiece frame_name,
std::vector<Node*> enter_nodes,
std::vector<Node*> exit_nodes,
OutputTensor cond_output,
std::vector<OutputTensor> body_inputs,
std::vector<OutputTensor> body_outputs,
WhileContext** result) {
auto pair = while_ctxs_.insert(std::pair<string, WhileContext>(
frame_name.ToString(),
WhileContext(frame_name, std::move(enter_nodes), std::move(exit_nodes),
cond_output, std::move(body_inputs),
std::move(body_outputs))));
if (!pair.second) {
*result = nullptr;
return errors::InvalidArgument("WhileContext with frame name '", frame_name,
"' already exists");
}
*result = &pair.first->second;
return Status::OK();
}
string Edge::DebugString() const {
return strings::Printf("[id=%d %s:%d -> %s:%d]", id_, src_->name().c_str(),
src_output_, dst_->name().c_str(), dst_input_);
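A small sketch of the frame-name uniqueness check in AddWhileContext() above; graph, enter_nodes, exit_nodes, cond_out, body_inputs, and body_outputs are placeholders matching the parameter types:

WhileContext* while_ctx = nullptr;
TF_CHECK_OK(graph->AddWhileContext("example_frame", enter_nodes, exit_nodes,
                                   cond_out, body_inputs, body_outputs,
                                   &while_ctx));
// Reusing the frame name fails with InvalidArgument and clears the result.
Status s = graph->AddWhileContext("example_frame", enter_nodes, exit_nodes,
                                  cond_out, body_inputs, body_outputs,
                                  &while_ctx);
CHECK(!s.ok());
CHECK(while_ctx == nullptr);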

View File

@ -60,6 +60,7 @@ class Graph;
class GraphDef;
class Node;
class VersionDef;
class WhileContext;
class NeighborIter; // Declared below
class NodeIter; // Declared below
@ -182,6 +183,13 @@ class Node {
Status input_node(int idx, const Node** n) const;
Status input_node(int idx, Node** n) const;
WhileContext* while_ctx() const { return while_ctx_; }
void set_while_ctx(WhileContext* while_ctx) {
DCHECK(IsExit());
DCHECK(while_ctx_ == nullptr);
while_ctx_ = while_ctx;
}
private:
friend class Graph;
Node();
@ -254,6 +262,13 @@ class Node {
// field and reclaim that memory.
Graph* graph_;
// Set if this is an exit node of a while loop with an associated
// WhileContext. Otherwise null. (This is only set for exit nodes because
// they're the first nodes of a loop encountered while creating the gradient
// graph. Exit nodes that are part of while loop gradient graphs will not have
// this set.)
WhileContext* while_ctx_;
TF_DISALLOW_COPY_AND_ASSIGN(Node);
};
@ -530,6 +545,16 @@ class Graph {
// node->num_outputs()
Status IsValidOutputTensor(const Node* node, int idx) const;
// Create and return a new WhileContext owned by this graph. This is called
// when a new while loop is created. `frame_name` must be unique among
// WhileContexts in this graph.
Status AddWhileContext(StringPiece frame_name, std::vector<Node*> enter_nodes,
std::vector<Node*> exit_nodes,
OutputTensor cond_output,
std::vector<OutputTensor> body_inputs,
std::vector<OutputTensor> body_outputs,
WhileContext** result);
// TODO(josh11b): uint64 hash() const;
private:
@ -596,6 +621,12 @@ class Graph {
// Maps unique device names to indices within device_names_[i].
std::unordered_map<string, int> device_names_map_;
// All the while contexts owned by this graph, keyed by frame name,
// corresponding to all the while loops contained in this graph (including
// nested loops). The stored contexts are usually accessed via
// AddWhileContext() or Node::while_ctx(), but this manages the lifetime.
std::map<string, WhileContext> while_ctxs_;
TF_DISALLOW_COPY_AND_ASSIGN(Graph);
};

View File

@ -0,0 +1,38 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/graph/while_context.h"
namespace tensorflow {
WhileContext::WhileContext(StringPiece frame_name,
std::vector<Node*> enter_nodes,
std::vector<Node*> exit_nodes,
OutputTensor cond_output,
std::vector<OutputTensor> body_inputs,
std::vector<OutputTensor> body_outputs)
: frame_name_(frame_name.ToString()),
enter_nodes_(std::move(enter_nodes)),
exit_nodes_(std::move(exit_nodes)),
cond_output_(cond_output),
body_inputs_(std::move(body_inputs)),
body_outputs_(std::move(body_outputs)) {
const size_t num_loop_vars = enter_nodes_.size();
DCHECK_EQ(exit_nodes_.size(), num_loop_vars);
DCHECK_EQ(body_inputs_.size(), num_loop_vars);
DCHECK_EQ(body_outputs_.size(), num_loop_vars);
}
} // namespace tensorflow

View File

@ -0,0 +1,76 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_GRAPH_WHILE_CONTEXT_H_
#define TENSORFLOW_GRAPH_WHILE_CONTEXT_H_
#include "tensorflow/core/graph/graph.h"
namespace tensorflow {
// Information about a while loop. Every user-defined while loop has an
// associated WhileContext, i.e., there is a WhileContext for every execution
// frame. Created with the while loop and used during gradient
// construction. Note that the gradient graph of a while loop itself contains
// while loops, but these do not generate separate WhileContexts.
//
// TODO(skyewm): this is currently insufficient to handle nested loops and
// conditionals (and possibly other requirements). This may change a lot in the
// future to support these features.
//
// TODO(skyewm): de/serialize in MetaGraphDef so imported while loops will be
// differentiable. Figure out the backwards compatibility story.
class WhileContext {
public:
WhileContext(StringPiece frame_name, std::vector<Node*> enter_nodes,
std::vector<Node*> exit_nodes, OutputTensor cond_output,
std::vector<OutputTensor> body_inputs,
std::vector<OutputTensor> body_outputs);
const string& frame_name() const { return frame_name_; }
const std::vector<Node*>& enter_nodes() const { return enter_nodes_; }
const std::vector<Node*>& exit_nodes() const { return exit_nodes_; }
const OutputTensor& cond_output() const { return cond_output_; }
const std::vector<OutputTensor>& body_inputs() const { return body_inputs_; }
const std::vector<OutputTensor>& body_outputs() const {
return body_outputs_;
}
private:
// Each user-defined while loop defines a new execution frame, which is
// uniquely identified by its frame name. Frames are used by the executor to
// manage the iterations of a loop. See the FrameState comment in
// core/common_runtime/executor.cc for more details.
const string frame_name_;
// The enter nodes defining the input loop variables to the while loop. This
// vector defines the order of the loop variables.
const std::vector<Node*> enter_nodes_;
// The exit nodes defining the outputs of the while loop. These are in loop
// variable order.
const std::vector<Node*> exit_nodes_;
// The boolean output of the loop predicate.
const OutputTensor cond_output_;
// The inputs and outputs to the loop body.
const std::vector<OutputTensor> body_inputs_;
const std::vector<OutputTensor> body_outputs_;
};
} // namespace tensorflow
#endif // TENSORFLOW_GRAPH_WHILE_CONTEXT_H_
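Finally, a hedged sketch of reading these accessors from a context obtained via an Exit node (exit_node is a placeholder); the logging is illustrative only:

const WhileContext* while_ctx = exit_node->while_ctx();
VLOG(2) << "Loop frame: " << while_ctx->frame_name();
// enter_nodes(), exit_nodes(), body_inputs() and body_outputs() are all in
// loop-variable order, so index i refers to the same loop variable in each.
for (size_t i = 0; i < while_ctx->enter_nodes().size(); ++i) {
  VLOG(2) << "Loop variable " << i << " enters via "
          << while_ctx->enter_nodes()[i]->name() << " and exits via "
          << while_ctx->exit_nodes()[i]->name();
}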