Add WhileContext class and add plumbing for creating them.
This change introduces WhileContext, which stores information about a while loop and will be used in future changes to generate while loop gradient graphs. Exit nodes in a while loop now have a pointer to their associated WhileContext. This will be used to retrieve the context for a given loop. This change adds an optional parameter to BuildWhileLoop() to create a WhileContext for the while loop (currently this is always true, but gradients will generate while loops without associated contexts). This change also adds a as-yet-unused option to BuildWhileLoop() to return the predicate output. PiperOrigin-RevId: 168562303
This commit is contained in:
parent
a4f6e7c1af
commit
92362d0f05
@ -258,6 +258,7 @@ tf_cc_test(
|
|||||||
":client_session",
|
":client_session",
|
||||||
":testutil",
|
":testutil",
|
||||||
":while_loop",
|
":while_loop",
|
||||||
|
"//tensorflow/core:core_cpu",
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
"//tensorflow/core:testlib",
|
"//tensorflow/core:testlib",
|
||||||
|
@ -172,7 +172,8 @@ Status CreateBody(const Scope& scope, const BodyGraphBuilderFn& body,
|
|||||||
Status BuildWhileLoop(const Scope& scope, const std::vector<Output>& inputs,
|
Status BuildWhileLoop(const Scope& scope, const std::vector<Output>& inputs,
|
||||||
const CondGraphBuilderFn& cond,
|
const CondGraphBuilderFn& cond,
|
||||||
const BodyGraphBuilderFn& body, const string& frame_name,
|
const BodyGraphBuilderFn& body, const string& frame_name,
|
||||||
OutputList* outputs) {
|
OutputList* outputs, bool create_while_ctx,
|
||||||
|
Output* cond_output) {
|
||||||
DCHECK(!inputs.empty());
|
DCHECK(!inputs.empty());
|
||||||
DCHECK(outputs != nullptr);
|
DCHECK(outputs != nullptr);
|
||||||
DCHECK(outputs->empty());
|
DCHECK(outputs->empty());
|
||||||
@ -194,6 +195,7 @@ Status BuildWhileLoop(const Scope& scope, const std::vector<Output>& inputs,
|
|||||||
|
|
||||||
Output cond_out;
|
Output cond_out;
|
||||||
TF_RETURN_IF_ERROR(CreateCond(scope, cond, merge_outputs, &cond_out));
|
TF_RETURN_IF_ERROR(CreateCond(scope, cond, merge_outputs, &cond_out));
|
||||||
|
if (cond_output != nullptr) *cond_output = cond_out;
|
||||||
|
|
||||||
std::vector<Output> switch_trues(num_loop_vars);
|
std::vector<Output> switch_trues(num_loop_vars);
|
||||||
std::vector<Output> switch_falses(num_loop_vars);
|
std::vector<Output> switch_falses(num_loop_vars);
|
||||||
@ -226,7 +228,22 @@ Status BuildWhileLoop(const Scope& scope, const std::vector<Output>& inputs,
|
|||||||
for (int i = 0; i < num_loop_vars; ++i) {
|
for (int i = 0; i < num_loop_vars; ++i) {
|
||||||
(*outputs)[i] = internal::Exit(scope, switch_falses[i]);
|
(*outputs)[i] = internal::Exit(scope, switch_falses[i]);
|
||||||
}
|
}
|
||||||
return scope.status();
|
TF_RETURN_IF_ERROR(scope.status());
|
||||||
|
|
||||||
|
if (create_while_ctx) {
|
||||||
|
WhileContext* while_ctx;
|
||||||
|
TF_RETURN_IF_ERROR(scope.graph()->AddWhileContext(
|
||||||
|
frame_name, ToNodes(enter_outputs), ToNodes(*outputs),
|
||||||
|
ToOutputTensor(cond_out), ToOutputTensors(switch_trues),
|
||||||
|
ToOutputTensors(body_outputs), &while_ctx));
|
||||||
|
|
||||||
|
// Set while_ctx for all exit nodes. We currently don't require knowing the
|
||||||
|
// while_ctx for any other nodes.
|
||||||
|
for (int i = 0; i < num_loop_vars; ++i) {
|
||||||
|
(*outputs)[i].node()->set_while_ctx(while_ctx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
|
@ -48,6 +48,10 @@ typedef std::function<Status(const Scope&, const std::vector<Output>& inputs,
|
|||||||
// unique name. This will be used as a prefix for created operations.
|
// unique name. This will be used as a prefix for created operations.
|
||||||
// * outputs: output param that returns final loop variable outputs in non-error
|
// * outputs: output param that returns final loop variable outputs in non-error
|
||||||
// case. Must be non-null and empty.
|
// case. Must be non-null and empty.
|
||||||
|
// * create_while_ctx: if true, a WhileContext is created and populated for this
|
||||||
|
// loop. See core/graph/while_context.h for more details.
|
||||||
|
// * cond_output: if non-null, the output of the predicate is returned. This
|
||||||
|
// will always be a LoopCond node.
|
||||||
//
|
//
|
||||||
// Returns an error if the while loop could not be fully constructed.
|
// Returns an error if the while loop could not be fully constructed.
|
||||||
//
|
//
|
||||||
@ -56,7 +60,8 @@ typedef std::function<Status(const Scope&, const std::vector<Output>& inputs,
|
|||||||
Status BuildWhileLoop(const Scope& scope, const std::vector<Output>& inputs,
|
Status BuildWhileLoop(const Scope& scope, const std::vector<Output>& inputs,
|
||||||
const CondGraphBuilderFn& cond,
|
const CondGraphBuilderFn& cond,
|
||||||
const BodyGraphBuilderFn& body, const string& frame_name,
|
const BodyGraphBuilderFn& body, const string& frame_name,
|
||||||
OutputList* outputs);
|
OutputList* outputs, bool create_while_ctx = true,
|
||||||
|
Output* cond_output = nullptr);
|
||||||
|
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/cc/client/client_session.h"
|
#include "tensorflow/cc/client/client_session.h"
|
||||||
#include "tensorflow/cc/ops/standard_ops.h"
|
#include "tensorflow/cc/ops/standard_ops.h"
|
||||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||||
|
#include "tensorflow/core/graph/while_context.h"
|
||||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
#include "tensorflow/core/platform/test.h"
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
|
||||||
@ -38,8 +39,8 @@ class WhileLoopTest : public ::testing::Test {
|
|||||||
const ops::BodyGraphBuilderFn& body,
|
const ops::BodyGraphBuilderFn& body,
|
||||||
error::Code error_code = error::OK,
|
error::Code error_code = error::OK,
|
||||||
const string& error_msg = "") {
|
const string& error_msg = "") {
|
||||||
Status s = ops::BuildWhileLoop(scope_, inputs_, cond, body, "test_loop",
|
Status s =
|
||||||
&outputs_);
|
ops::BuildWhileLoop(scope_, inputs_, cond, body, kFrameName, &outputs_);
|
||||||
EXPECT_EQ(s.code(), error_code);
|
EXPECT_EQ(s.code(), error_code);
|
||||||
EXPECT_EQ(s.error_message(), error_msg);
|
EXPECT_EQ(s.error_message(), error_msg);
|
||||||
}
|
}
|
||||||
@ -69,8 +70,12 @@ class WhileLoopTest : public ::testing::Test {
|
|||||||
Scope scope_;
|
Scope scope_;
|
||||||
std::vector<Output> inputs_;
|
std::vector<Output> inputs_;
|
||||||
std::vector<Output> outputs_;
|
std::vector<Output> outputs_;
|
||||||
|
|
||||||
|
static const char* const kFrameName;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const char* const WhileLoopTest::kFrameName = "test_loop";
|
||||||
|
|
||||||
Status LessThanTenCond(const Scope& s, const std::vector<Output>& inputs,
|
Status LessThanTenCond(const Scope& s, const std::vector<Output>& inputs,
|
||||||
Output* output) {
|
Output* output) {
|
||||||
*output = ops::Less(s, inputs[0], 10);
|
*output = ops::Less(s, inputs[0], 10);
|
||||||
@ -87,6 +92,23 @@ TEST_F(WhileLoopTest, Basic) {
|
|||||||
// Create loop: while (i < 10) i += 1
|
// Create loop: while (i < 10) i += 1
|
||||||
Init(1);
|
Init(1);
|
||||||
CreateLoop(LessThanTenCond, AddOneBody);
|
CreateLoop(LessThanTenCond, AddOneBody);
|
||||||
|
|
||||||
|
// Verify some output invariants
|
||||||
|
WhileContext* while_ctx;
|
||||||
|
for (int i = 0; i < outputs_.size(); ++i) {
|
||||||
|
Node* node = outputs_[i].node();
|
||||||
|
ASSERT_TRUE(node->IsExit()) << "Output node " << i << ":\n"
|
||||||
|
<< node->DebugString();
|
||||||
|
ASSERT_TRUE(node->while_ctx() != nullptr) << i;
|
||||||
|
if (i == 0) {
|
||||||
|
while_ctx = node->while_ctx();
|
||||||
|
EXPECT_EQ(while_ctx->frame_name(), kFrameName);
|
||||||
|
} else {
|
||||||
|
EXPECT_EQ(node->while_ctx(), while_ctx) << i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run the loop and test we get the expected results
|
||||||
Run<int>({1}, {10});
|
Run<int>({1}, {10});
|
||||||
Run<int>({11}, {11});
|
Run<int>({11}, {11});
|
||||||
}
|
}
|
||||||
|
@ -50,6 +50,8 @@ file(GLOB_RECURSE tf_core_cpu_exclude_srcs
|
|||||||
"${tensorflow_source_dir}/tensorflow/core/graph/edgeset.cc"
|
"${tensorflow_source_dir}/tensorflow/core/graph/edgeset.cc"
|
||||||
"${tensorflow_source_dir}/tensorflow/core/graph/graph.h"
|
"${tensorflow_source_dir}/tensorflow/core/graph/graph.h"
|
||||||
"${tensorflow_source_dir}/tensorflow/core/graph/graph.cc"
|
"${tensorflow_source_dir}/tensorflow/core/graph/graph.cc"
|
||||||
|
"${tensorflow_source_dir}/tensorflow/core/graph/while_context.h"
|
||||||
|
"${tensorflow_source_dir}/tensorflow/core/graph/while_context.cc"
|
||||||
"${tensorflow_source_dir}/tensorflow/core/grappler/clusters/single_machine.h"
|
"${tensorflow_source_dir}/tensorflow/core/grappler/clusters/single_machine.h"
|
||||||
"${tensorflow_source_dir}/tensorflow/core/grappler/clusters/single_machine.cc"
|
"${tensorflow_source_dir}/tensorflow/core/grappler/clusters/single_machine.cc"
|
||||||
"${tensorflow_source_dir}/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
|
"${tensorflow_source_dir}/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
|
||||||
|
@ -241,6 +241,8 @@ file(GLOB_RECURSE tf_core_framework_srcs
|
|||||||
"${tensorflow_source_dir}/tensorflow/core/graph/edgeset.cc"
|
"${tensorflow_source_dir}/tensorflow/core/graph/edgeset.cc"
|
||||||
"${tensorflow_source_dir}/tensorflow/core/graph/graph.h"
|
"${tensorflow_source_dir}/tensorflow/core/graph/graph.h"
|
||||||
"${tensorflow_source_dir}/tensorflow/core/graph/graph.cc"
|
"${tensorflow_source_dir}/tensorflow/core/graph/graph.cc"
|
||||||
|
"${tensorflow_source_dir}/tensorflow/core/graph/while_context.h"
|
||||||
|
"${tensorflow_source_dir}/tensorflow/core/graph/while_context.cc"
|
||||||
"${tensorflow_source_dir}/tensorflow/core/util/*.h"
|
"${tensorflow_source_dir}/tensorflow/core/util/*.h"
|
||||||
"${tensorflow_source_dir}/tensorflow/core/util/*.cc"
|
"${tensorflow_source_dir}/tensorflow/core/util/*.cc"
|
||||||
"${tensorflow_source_dir}/tensorflow/core/common_runtime/session.cc"
|
"${tensorflow_source_dir}/tensorflow/core/common_runtime/session.cc"
|
||||||
|
@ -721,6 +721,7 @@ tf_cuda_library(
|
|||||||
"graph/graph_def_builder.h",
|
"graph/graph_def_builder.h",
|
||||||
"graph/node_builder.h",
|
"graph/node_builder.h",
|
||||||
"graph/validate.h",
|
"graph/validate.h",
|
||||||
|
"graph/while_context.h",
|
||||||
"public/session.h",
|
"public/session.h",
|
||||||
"public/session_options.h",
|
"public/session_options.h",
|
||||||
],
|
],
|
||||||
@ -1581,6 +1582,8 @@ tf_cuda_library(
|
|||||||
"graph/edgeset.cc",
|
"graph/edgeset.cc",
|
||||||
"graph/graph.h",
|
"graph/graph.h",
|
||||||
"graph/graph.cc",
|
"graph/graph.cc",
|
||||||
|
"graph/while_context.h",
|
||||||
|
"graph/while_context.cc",
|
||||||
],
|
],
|
||||||
exclude = [
|
exclude = [
|
||||||
"**/*test*",
|
"**/*test*",
|
||||||
@ -1717,6 +1720,7 @@ CORE_CPU_BASE_HDRS = [
|
|||||||
"graph/testlib.h",
|
"graph/testlib.h",
|
||||||
"graph/types.h",
|
"graph/types.h",
|
||||||
"graph/validate.h",
|
"graph/validate.h",
|
||||||
|
"graph/while_context.h",
|
||||||
]
|
]
|
||||||
|
|
||||||
tf_cuda_library(
|
tf_cuda_library(
|
||||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/node_def_util.h"
|
#include "tensorflow/core/framework/node_def_util.h"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/framework/versions.pb.h"
|
#include "tensorflow/core/framework/versions.pb.h"
|
||||||
|
#include "tensorflow/core/graph/while_context.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||||
#include "tensorflow/core/lib/strings/strcat.h"
|
#include "tensorflow/core/lib/strings/strcat.h"
|
||||||
@ -110,7 +111,8 @@ Node::Node()
|
|||||||
cost_id_(-1),
|
cost_id_(-1),
|
||||||
class_(NC_UNINITIALIZED),
|
class_(NC_UNINITIALIZED),
|
||||||
props_(nullptr),
|
props_(nullptr),
|
||||||
assigned_device_name_index_(0) {}
|
assigned_device_name_index_(0),
|
||||||
|
while_ctx_(nullptr) {}
|
||||||
|
|
||||||
void Node::Initialize(int id, int cost_id,
|
void Node::Initialize(int id, int cost_id,
|
||||||
std::shared_ptr<NodeProperties> props) {
|
std::shared_ptr<NodeProperties> props) {
|
||||||
@ -582,6 +584,27 @@ int Graph::InternDeviceName(const string& device_name) {
|
|||||||
return index;
|
return index;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status Graph::AddWhileContext(StringPiece frame_name,
|
||||||
|
std::vector<Node*> enter_nodes,
|
||||||
|
std::vector<Node*> exit_nodes,
|
||||||
|
OutputTensor cond_output,
|
||||||
|
std::vector<OutputTensor> body_inputs,
|
||||||
|
std::vector<OutputTensor> body_outputs,
|
||||||
|
WhileContext** result) {
|
||||||
|
auto pair = while_ctxs_.insert(std::pair<string, WhileContext>(
|
||||||
|
frame_name.ToString(),
|
||||||
|
WhileContext(frame_name, std::move(enter_nodes), std::move(exit_nodes),
|
||||||
|
cond_output, std::move(body_inputs),
|
||||||
|
std::move(body_outputs))));
|
||||||
|
if (!pair.second) {
|
||||||
|
*result = nullptr;
|
||||||
|
return errors::InvalidArgument("WhileContext with frame name '", frame_name,
|
||||||
|
"' already exists");
|
||||||
|
}
|
||||||
|
*result = &pair.first->second;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
string Edge::DebugString() const {
|
string Edge::DebugString() const {
|
||||||
return strings::Printf("[id=%d %s:%d -> %s:%d]", id_, src_->name().c_str(),
|
return strings::Printf("[id=%d %s:%d -> %s:%d]", id_, src_->name().c_str(),
|
||||||
src_output_, dst_->name().c_str(), dst_input_);
|
src_output_, dst_->name().c_str(), dst_input_);
|
||||||
|
@ -60,6 +60,7 @@ class Graph;
|
|||||||
class GraphDef;
|
class GraphDef;
|
||||||
class Node;
|
class Node;
|
||||||
class VersionDef;
|
class VersionDef;
|
||||||
|
class WhileContext;
|
||||||
|
|
||||||
class NeighborIter; // Declared below
|
class NeighborIter; // Declared below
|
||||||
class NodeIter; // Declared below
|
class NodeIter; // Declared below
|
||||||
@ -182,6 +183,13 @@ class Node {
|
|||||||
Status input_node(int idx, const Node** n) const;
|
Status input_node(int idx, const Node** n) const;
|
||||||
Status input_node(int idx, Node** n) const;
|
Status input_node(int idx, Node** n) const;
|
||||||
|
|
||||||
|
WhileContext* while_ctx() const { return while_ctx_; }
|
||||||
|
void set_while_ctx(WhileContext* while_ctx) {
|
||||||
|
DCHECK(IsExit());
|
||||||
|
DCHECK(while_ctx_ == nullptr);
|
||||||
|
while_ctx_ = while_ctx;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
friend class Graph;
|
friend class Graph;
|
||||||
Node();
|
Node();
|
||||||
@ -254,6 +262,13 @@ class Node {
|
|||||||
// field and reclaim that memory.
|
// field and reclaim that memory.
|
||||||
Graph* graph_;
|
Graph* graph_;
|
||||||
|
|
||||||
|
// Set if this is an exit node of a while loop with an associated
|
||||||
|
// WhileContext. Otherwise null. (This is only set for exit nodes because
|
||||||
|
// they're the first nodes of a loop encountered while creating the gradient
|
||||||
|
// graph. Exit nodes that are part of while loop gradient graphs will not have
|
||||||
|
// this set.)
|
||||||
|
WhileContext* while_ctx_;
|
||||||
|
|
||||||
TF_DISALLOW_COPY_AND_ASSIGN(Node);
|
TF_DISALLOW_COPY_AND_ASSIGN(Node);
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -530,6 +545,16 @@ class Graph {
|
|||||||
// node->num_outputs()
|
// node->num_outputs()
|
||||||
Status IsValidOutputTensor(const Node* node, int idx) const;
|
Status IsValidOutputTensor(const Node* node, int idx) const;
|
||||||
|
|
||||||
|
// Create and return a new WhileContext owned by this graph. This is called
|
||||||
|
// when a new while loop is created. `frame_name` must be unique among
|
||||||
|
// WhileContexts in this graph.
|
||||||
|
Status AddWhileContext(StringPiece frame_name, std::vector<Node*> enter_nodes,
|
||||||
|
std::vector<Node*> exit_nodes,
|
||||||
|
OutputTensor cond_output,
|
||||||
|
std::vector<OutputTensor> body_inputs,
|
||||||
|
std::vector<OutputTensor> body_outputs,
|
||||||
|
WhileContext** result);
|
||||||
|
|
||||||
// TODO(josh11b): uint64 hash() const;
|
// TODO(josh11b): uint64 hash() const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@ -596,6 +621,12 @@ class Graph {
|
|||||||
// Maps unique device names to indices within device_names_[i].
|
// Maps unique device names to indices within device_names_[i].
|
||||||
std::unordered_map<string, int> device_names_map_;
|
std::unordered_map<string, int> device_names_map_;
|
||||||
|
|
||||||
|
// All the while contexts owned by this graph, keyed by frame name,
|
||||||
|
// corresonding to all the while loops contained in this graph (including
|
||||||
|
// nested loops). The stored contexts are usually accessed via
|
||||||
|
// AddWhileContext() or Node::while_ctx(), but this manages the lifetime.
|
||||||
|
std::map<string, WhileContext> while_ctxs_;
|
||||||
|
|
||||||
TF_DISALLOW_COPY_AND_ASSIGN(Graph);
|
TF_DISALLOW_COPY_AND_ASSIGN(Graph);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
38
tensorflow/core/graph/while_context.cc
Normal file
38
tensorflow/core/graph/while_context.cc
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/graph/while_context.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
WhileContext::WhileContext(StringPiece frame_name,
|
||||||
|
std::vector<Node*> enter_nodes,
|
||||||
|
std::vector<Node*> exit_nodes,
|
||||||
|
OutputTensor cond_output,
|
||||||
|
std::vector<OutputTensor> body_inputs,
|
||||||
|
std::vector<OutputTensor> body_outputs)
|
||||||
|
: frame_name_(frame_name.ToString()),
|
||||||
|
enter_nodes_(std::move(enter_nodes)),
|
||||||
|
exit_nodes_(std::move(exit_nodes)),
|
||||||
|
cond_output_(cond_output),
|
||||||
|
body_inputs_(std::move(body_inputs)),
|
||||||
|
body_outputs_(std::move(body_outputs)) {
|
||||||
|
const size_t num_loop_vars = enter_nodes_.size();
|
||||||
|
DCHECK_EQ(exit_nodes_.size(), num_loop_vars);
|
||||||
|
DCHECK_EQ(body_inputs_.size(), num_loop_vars);
|
||||||
|
DCHECK_EQ(body_outputs_.size(), num_loop_vars);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
76
tensorflow/core/graph/while_context.h
Normal file
76
tensorflow/core/graph/while_context.h
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_GRAPH_WHILE_CONTEXT_H_
|
||||||
|
#define TENSORFLOW_GRAPH_WHILE_CONTEXT_H_
|
||||||
|
|
||||||
|
#include "tensorflow/core/graph/graph.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
// Information about a while loop. Every user-defined while loop has an
|
||||||
|
// associated WhileContext, i.e., there is a WhileContext for every execution
|
||||||
|
// frame. Created with the while loop and used during gradient
|
||||||
|
// construction. Note that the gradient graph of while loop contains while loops
|
||||||
|
// itself, but these do not generate separate WhileContexts.
|
||||||
|
//
|
||||||
|
// TODO(skyewm): this is currently insufficient to handle nested loops and
|
||||||
|
// conditionals (and possibly other requirements). This may change a lot in the
|
||||||
|
// future to support these features.
|
||||||
|
//
|
||||||
|
// TODO(skyewm): de/serialize in MetaGraphDef so imported while loops will be
|
||||||
|
// differentiable. Figure out backwards compatability story.
|
||||||
|
class WhileContext {
|
||||||
|
public:
|
||||||
|
WhileContext(StringPiece frame_name, std::vector<Node*> enter_nodes,
|
||||||
|
std::vector<Node*> exit_nodes, OutputTensor cond_output,
|
||||||
|
std::vector<OutputTensor> body_inputs,
|
||||||
|
std::vector<OutputTensor> body_outputs);
|
||||||
|
|
||||||
|
const string& frame_name() const { return frame_name_; }
|
||||||
|
const std::vector<Node*>& enter_nodes() const { return enter_nodes_; }
|
||||||
|
const std::vector<Node*>& exit_nodes() const { return exit_nodes_; }
|
||||||
|
const OutputTensor& cond_output() const { return cond_output_; }
|
||||||
|
const std::vector<OutputTensor>& body_inputs() const { return body_inputs_; }
|
||||||
|
const std::vector<OutputTensor>& body_outputs() const {
|
||||||
|
return body_outputs_;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Each user-defined while loop defines a new execution frame, which is
|
||||||
|
// uniquely identified by its frame name. Frames are used by the executor to
|
||||||
|
// manage the iterations of a loop. See the FrameState comment in
|
||||||
|
// core/common_runtime/executor.cc for more details.
|
||||||
|
const string frame_name_;
|
||||||
|
|
||||||
|
// The enter nodes defining the input loop variables to the while loop. This
|
||||||
|
// vector defines the order of the loop variables.
|
||||||
|
const std::vector<Node*> enter_nodes_;
|
||||||
|
|
||||||
|
// The exit nodes defining the outputs of the while loop. These are in loop
|
||||||
|
// variable order.
|
||||||
|
const std::vector<Node*> exit_nodes_;
|
||||||
|
|
||||||
|
// The boolean output of the loop predicate.
|
||||||
|
const OutputTensor cond_output_;
|
||||||
|
|
||||||
|
// The inputs and outputs to the loop body.
|
||||||
|
const std::vector<OutputTensor> body_inputs_;
|
||||||
|
const std::vector<OutputTensor> body_outputs_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_GRAPH_GRAPH_H_
|
Loading…
Reference in New Issue
Block a user