Add WhileContext class and add plumbing for creating them.
This change introduces WhileContext, which stores information about a while loop and will be used in future changes to generate while loop gradient graphs. Exit nodes in a while loop now have a pointer to their associated WhileContext. This will be used to retrieve the context for a given loop. This change adds an optional parameter to BuildWhileLoop() to create a WhileContext for the while loop (currently this is always true, but gradients will generate while loops without associated contexts). This change also adds a as-yet-unused option to BuildWhileLoop() to return the predicate output. PiperOrigin-RevId: 168562303
This commit is contained in:
parent
a4f6e7c1af
commit
92362d0f05
@ -258,6 +258,7 @@ tf_cc_test(
|
||||
":client_session",
|
||||
":testutil",
|
||||
":while_loop",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
|
@ -172,7 +172,8 @@ Status CreateBody(const Scope& scope, const BodyGraphBuilderFn& body,
|
||||
Status BuildWhileLoop(const Scope& scope, const std::vector<Output>& inputs,
|
||||
const CondGraphBuilderFn& cond,
|
||||
const BodyGraphBuilderFn& body, const string& frame_name,
|
||||
OutputList* outputs) {
|
||||
OutputList* outputs, bool create_while_ctx,
|
||||
Output* cond_output) {
|
||||
DCHECK(!inputs.empty());
|
||||
DCHECK(outputs != nullptr);
|
||||
DCHECK(outputs->empty());
|
||||
@ -194,6 +195,7 @@ Status BuildWhileLoop(const Scope& scope, const std::vector<Output>& inputs,
|
||||
|
||||
Output cond_out;
|
||||
TF_RETURN_IF_ERROR(CreateCond(scope, cond, merge_outputs, &cond_out));
|
||||
if (cond_output != nullptr) *cond_output = cond_out;
|
||||
|
||||
std::vector<Output> switch_trues(num_loop_vars);
|
||||
std::vector<Output> switch_falses(num_loop_vars);
|
||||
@ -226,7 +228,22 @@ Status BuildWhileLoop(const Scope& scope, const std::vector<Output>& inputs,
|
||||
for (int i = 0; i < num_loop_vars; ++i) {
|
||||
(*outputs)[i] = internal::Exit(scope, switch_falses[i]);
|
||||
}
|
||||
return scope.status();
|
||||
TF_RETURN_IF_ERROR(scope.status());
|
||||
|
||||
if (create_while_ctx) {
|
||||
WhileContext* while_ctx;
|
||||
TF_RETURN_IF_ERROR(scope.graph()->AddWhileContext(
|
||||
frame_name, ToNodes(enter_outputs), ToNodes(*outputs),
|
||||
ToOutputTensor(cond_out), ToOutputTensors(switch_trues),
|
||||
ToOutputTensors(body_outputs), &while_ctx));
|
||||
|
||||
// Set while_ctx for all exit nodes. We currently don't require knowing the
|
||||
// while_ctx for any other nodes.
|
||||
for (int i = 0; i < num_loop_vars; ++i) {
|
||||
(*outputs)[i].node()->set_while_ctx(while_ctx);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace ops
|
||||
|
@ -48,6 +48,10 @@ typedef std::function<Status(const Scope&, const std::vector<Output>& inputs,
|
||||
// unique name. This will be used as a prefix for created operations.
|
||||
// * outputs: output param that returns final loop variable outputs in non-error
|
||||
// case. Must be non-null and empty.
|
||||
// * create_while_ctx: if true, a WhileContext is created and populated for this
|
||||
// loop. See core/graph/while_context.h for more details.
|
||||
// * cond_output: if non-null, the output of the predicate is returned. This
|
||||
// will always be a LoopCond node.
|
||||
//
|
||||
// Returns an error if the while loop could not be fully constructed.
|
||||
//
|
||||
@ -56,7 +60,8 @@ typedef std::function<Status(const Scope&, const std::vector<Output>& inputs,
|
||||
Status BuildWhileLoop(const Scope& scope, const std::vector<Output>& inputs,
|
||||
const CondGraphBuilderFn& cond,
|
||||
const BodyGraphBuilderFn& body, const string& frame_name,
|
||||
OutputList* outputs);
|
||||
OutputList* outputs, bool create_while_ctx = true,
|
||||
Output* cond_output = nullptr);
|
||||
|
||||
} // namespace ops
|
||||
} // namespace tensorflow
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
#include "tensorflow/cc/client/client_session.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/graph/while_context.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
@ -38,8 +39,8 @@ class WhileLoopTest : public ::testing::Test {
|
||||
const ops::BodyGraphBuilderFn& body,
|
||||
error::Code error_code = error::OK,
|
||||
const string& error_msg = "") {
|
||||
Status s = ops::BuildWhileLoop(scope_, inputs_, cond, body, "test_loop",
|
||||
&outputs_);
|
||||
Status s =
|
||||
ops::BuildWhileLoop(scope_, inputs_, cond, body, kFrameName, &outputs_);
|
||||
EXPECT_EQ(s.code(), error_code);
|
||||
EXPECT_EQ(s.error_message(), error_msg);
|
||||
}
|
||||
@ -69,8 +70,12 @@ class WhileLoopTest : public ::testing::Test {
|
||||
Scope scope_;
|
||||
std::vector<Output> inputs_;
|
||||
std::vector<Output> outputs_;
|
||||
|
||||
static const char* const kFrameName;
|
||||
};
|
||||
|
||||
const char* const WhileLoopTest::kFrameName = "test_loop";
|
||||
|
||||
Status LessThanTenCond(const Scope& s, const std::vector<Output>& inputs,
|
||||
Output* output) {
|
||||
*output = ops::Less(s, inputs[0], 10);
|
||||
@ -87,6 +92,23 @@ TEST_F(WhileLoopTest, Basic) {
|
||||
// Create loop: while (i < 10) i += 1
|
||||
Init(1);
|
||||
CreateLoop(LessThanTenCond, AddOneBody);
|
||||
|
||||
// Verify some output invariants
|
||||
WhileContext* while_ctx;
|
||||
for (int i = 0; i < outputs_.size(); ++i) {
|
||||
Node* node = outputs_[i].node();
|
||||
ASSERT_TRUE(node->IsExit()) << "Output node " << i << ":\n"
|
||||
<< node->DebugString();
|
||||
ASSERT_TRUE(node->while_ctx() != nullptr) << i;
|
||||
if (i == 0) {
|
||||
while_ctx = node->while_ctx();
|
||||
EXPECT_EQ(while_ctx->frame_name(), kFrameName);
|
||||
} else {
|
||||
EXPECT_EQ(node->while_ctx(), while_ctx) << i;
|
||||
}
|
||||
}
|
||||
|
||||
// Run the loop and test we get the expected results
|
||||
Run<int>({1}, {10});
|
||||
Run<int>({11}, {11});
|
||||
}
|
||||
|
@ -50,6 +50,8 @@ file(GLOB_RECURSE tf_core_cpu_exclude_srcs
|
||||
"${tensorflow_source_dir}/tensorflow/core/graph/edgeset.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/graph/graph.h"
|
||||
"${tensorflow_source_dir}/tensorflow/core/graph/graph.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/graph/while_context.h"
|
||||
"${tensorflow_source_dir}/tensorflow/core/graph/while_context.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/grappler/clusters/single_machine.h"
|
||||
"${tensorflow_source_dir}/tensorflow/core/grappler/clusters/single_machine.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
|
||||
|
@ -241,6 +241,8 @@ file(GLOB_RECURSE tf_core_framework_srcs
|
||||
"${tensorflow_source_dir}/tensorflow/core/graph/edgeset.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/graph/graph.h"
|
||||
"${tensorflow_source_dir}/tensorflow/core/graph/graph.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/graph/while_context.h"
|
||||
"${tensorflow_source_dir}/tensorflow/core/graph/while_context.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/util/*.h"
|
||||
"${tensorflow_source_dir}/tensorflow/core/util/*.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/common_runtime/session.cc"
|
||||
|
@ -721,6 +721,7 @@ tf_cuda_library(
|
||||
"graph/graph_def_builder.h",
|
||||
"graph/node_builder.h",
|
||||
"graph/validate.h",
|
||||
"graph/while_context.h",
|
||||
"public/session.h",
|
||||
"public/session_options.h",
|
||||
],
|
||||
@ -1581,6 +1582,8 @@ tf_cuda_library(
|
||||
"graph/edgeset.cc",
|
||||
"graph/graph.h",
|
||||
"graph/graph.cc",
|
||||
"graph/while_context.h",
|
||||
"graph/while_context.cc",
|
||||
],
|
||||
exclude = [
|
||||
"**/*test*",
|
||||
@ -1717,6 +1720,7 @@ CORE_CPU_BASE_HDRS = [
|
||||
"graph/testlib.h",
|
||||
"graph/types.h",
|
||||
"graph/validate.h",
|
||||
"graph/while_context.h",
|
||||
]
|
||||
|
||||
tf_cuda_library(
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/versions.pb.h"
|
||||
#include "tensorflow/core/graph/while_context.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
@ -110,7 +111,8 @@ Node::Node()
|
||||
cost_id_(-1),
|
||||
class_(NC_UNINITIALIZED),
|
||||
props_(nullptr),
|
||||
assigned_device_name_index_(0) {}
|
||||
assigned_device_name_index_(0),
|
||||
while_ctx_(nullptr) {}
|
||||
|
||||
void Node::Initialize(int id, int cost_id,
|
||||
std::shared_ptr<NodeProperties> props) {
|
||||
@ -582,6 +584,27 @@ int Graph::InternDeviceName(const string& device_name) {
|
||||
return index;
|
||||
}
|
||||
|
||||
Status Graph::AddWhileContext(StringPiece frame_name,
|
||||
std::vector<Node*> enter_nodes,
|
||||
std::vector<Node*> exit_nodes,
|
||||
OutputTensor cond_output,
|
||||
std::vector<OutputTensor> body_inputs,
|
||||
std::vector<OutputTensor> body_outputs,
|
||||
WhileContext** result) {
|
||||
auto pair = while_ctxs_.insert(std::pair<string, WhileContext>(
|
||||
frame_name.ToString(),
|
||||
WhileContext(frame_name, std::move(enter_nodes), std::move(exit_nodes),
|
||||
cond_output, std::move(body_inputs),
|
||||
std::move(body_outputs))));
|
||||
if (!pair.second) {
|
||||
*result = nullptr;
|
||||
return errors::InvalidArgument("WhileContext with frame name '", frame_name,
|
||||
"' already exists");
|
||||
}
|
||||
*result = &pair.first->second;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
string Edge::DebugString() const {
|
||||
return strings::Printf("[id=%d %s:%d -> %s:%d]", id_, src_->name().c_str(),
|
||||
src_output_, dst_->name().c_str(), dst_input_);
|
||||
|
@ -60,6 +60,7 @@ class Graph;
|
||||
class GraphDef;
|
||||
class Node;
|
||||
class VersionDef;
|
||||
class WhileContext;
|
||||
|
||||
class NeighborIter; // Declared below
|
||||
class NodeIter; // Declared below
|
||||
@ -182,6 +183,13 @@ class Node {
|
||||
Status input_node(int idx, const Node** n) const;
|
||||
Status input_node(int idx, Node** n) const;
|
||||
|
||||
WhileContext* while_ctx() const { return while_ctx_; }
|
||||
void set_while_ctx(WhileContext* while_ctx) {
|
||||
DCHECK(IsExit());
|
||||
DCHECK(while_ctx_ == nullptr);
|
||||
while_ctx_ = while_ctx;
|
||||
}
|
||||
|
||||
private:
|
||||
friend class Graph;
|
||||
Node();
|
||||
@ -254,6 +262,13 @@ class Node {
|
||||
// field and reclaim that memory.
|
||||
Graph* graph_;
|
||||
|
||||
// Set if this is an exit node of a while loop with an associated
|
||||
// WhileContext. Otherwise null. (This is only set for exit nodes because
|
||||
// they're the first nodes of a loop encountered while creating the gradient
|
||||
// graph. Exit nodes that are part of while loop gradient graphs will not have
|
||||
// this set.)
|
||||
WhileContext* while_ctx_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(Node);
|
||||
};
|
||||
|
||||
@ -530,6 +545,16 @@ class Graph {
|
||||
// node->num_outputs()
|
||||
Status IsValidOutputTensor(const Node* node, int idx) const;
|
||||
|
||||
// Create and return a new WhileContext owned by this graph. This is called
|
||||
// when a new while loop is created. `frame_name` must be unique among
|
||||
// WhileContexts in this graph.
|
||||
Status AddWhileContext(StringPiece frame_name, std::vector<Node*> enter_nodes,
|
||||
std::vector<Node*> exit_nodes,
|
||||
OutputTensor cond_output,
|
||||
std::vector<OutputTensor> body_inputs,
|
||||
std::vector<OutputTensor> body_outputs,
|
||||
WhileContext** result);
|
||||
|
||||
// TODO(josh11b): uint64 hash() const;
|
||||
|
||||
private:
|
||||
@ -596,6 +621,12 @@ class Graph {
|
||||
// Maps unique device names to indices within device_names_[i].
|
||||
std::unordered_map<string, int> device_names_map_;
|
||||
|
||||
// All the while contexts owned by this graph, keyed by frame name,
|
||||
// corresonding to all the while loops contained in this graph (including
|
||||
// nested loops). The stored contexts are usually accessed via
|
||||
// AddWhileContext() or Node::while_ctx(), but this manages the lifetime.
|
||||
std::map<string, WhileContext> while_ctxs_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(Graph);
|
||||
};
|
||||
|
||||
|
38
tensorflow/core/graph/while_context.cc
Normal file
38
tensorflow/core/graph/while_context.cc
Normal file
@ -0,0 +1,38 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/graph/while_context.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
WhileContext::WhileContext(StringPiece frame_name,
|
||||
std::vector<Node*> enter_nodes,
|
||||
std::vector<Node*> exit_nodes,
|
||||
OutputTensor cond_output,
|
||||
std::vector<OutputTensor> body_inputs,
|
||||
std::vector<OutputTensor> body_outputs)
|
||||
: frame_name_(frame_name.ToString()),
|
||||
enter_nodes_(std::move(enter_nodes)),
|
||||
exit_nodes_(std::move(exit_nodes)),
|
||||
cond_output_(cond_output),
|
||||
body_inputs_(std::move(body_inputs)),
|
||||
body_outputs_(std::move(body_outputs)) {
|
||||
const size_t num_loop_vars = enter_nodes_.size();
|
||||
DCHECK_EQ(exit_nodes_.size(), num_loop_vars);
|
||||
DCHECK_EQ(body_inputs_.size(), num_loop_vars);
|
||||
DCHECK_EQ(body_outputs_.size(), num_loop_vars);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
76
tensorflow/core/graph/while_context.h
Normal file
76
tensorflow/core/graph/while_context.h
Normal file
@ -0,0 +1,76 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_GRAPH_WHILE_CONTEXT_H_
|
||||
#define TENSORFLOW_GRAPH_WHILE_CONTEXT_H_
|
||||
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Information about a while loop. Every user-defined while loop has an
|
||||
// associated WhileContext, i.e., there is a WhileContext for every execution
|
||||
// frame. Created with the while loop and used during gradient
|
||||
// construction. Note that the gradient graph of while loop contains while loops
|
||||
// itself, but these do not generate separate WhileContexts.
|
||||
//
|
||||
// TODO(skyewm): this is currently insufficient to handle nested loops and
|
||||
// conditionals (and possibly other requirements). This may change a lot in the
|
||||
// future to support these features.
|
||||
//
|
||||
// TODO(skyewm): de/serialize in MetaGraphDef so imported while loops will be
|
||||
// differentiable. Figure out backwards compatability story.
|
||||
class WhileContext {
|
||||
public:
|
||||
WhileContext(StringPiece frame_name, std::vector<Node*> enter_nodes,
|
||||
std::vector<Node*> exit_nodes, OutputTensor cond_output,
|
||||
std::vector<OutputTensor> body_inputs,
|
||||
std::vector<OutputTensor> body_outputs);
|
||||
|
||||
const string& frame_name() const { return frame_name_; }
|
||||
const std::vector<Node*>& enter_nodes() const { return enter_nodes_; }
|
||||
const std::vector<Node*>& exit_nodes() const { return exit_nodes_; }
|
||||
const OutputTensor& cond_output() const { return cond_output_; }
|
||||
const std::vector<OutputTensor>& body_inputs() const { return body_inputs_; }
|
||||
const std::vector<OutputTensor>& body_outputs() const {
|
||||
return body_outputs_;
|
||||
}
|
||||
|
||||
private:
|
||||
// Each user-defined while loop defines a new execution frame, which is
|
||||
// uniquely identified by its frame name. Frames are used by the executor to
|
||||
// manage the iterations of a loop. See the FrameState comment in
|
||||
// core/common_runtime/executor.cc for more details.
|
||||
const string frame_name_;
|
||||
|
||||
// The enter nodes defining the input loop variables to the while loop. This
|
||||
// vector defines the order of the loop variables.
|
||||
const std::vector<Node*> enter_nodes_;
|
||||
|
||||
// The exit nodes defining the outputs of the while loop. These are in loop
|
||||
// variable order.
|
||||
const std::vector<Node*> exit_nodes_;
|
||||
|
||||
// The boolean output of the loop predicate.
|
||||
const OutputTensor cond_output_;
|
||||
|
||||
// The inputs and outputs to the loop body.
|
||||
const std::vector<OutputTensor> body_inputs_;
|
||||
const std::vector<OutputTensor> body_outputs_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_GRAPH_GRAPH_H_
|
Loading…
Reference in New Issue
Block a user