diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index 028de608803..0d2c9f2d195 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -258,6 +258,7 @@ tf_cc_test( ":client_session", ":testutil", ":while_loop", + "//tensorflow/core:core_cpu", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", diff --git a/tensorflow/cc/ops/while_loop.cc b/tensorflow/cc/ops/while_loop.cc index e3e39da85e6..e0251efb2a4 100644 --- a/tensorflow/cc/ops/while_loop.cc +++ b/tensorflow/cc/ops/while_loop.cc @@ -172,7 +172,8 @@ Status CreateBody(const Scope& scope, const BodyGraphBuilderFn& body, Status BuildWhileLoop(const Scope& scope, const std::vector& inputs, const CondGraphBuilderFn& cond, const BodyGraphBuilderFn& body, const string& frame_name, - OutputList* outputs) { + OutputList* outputs, bool create_while_ctx, + Output* cond_output) { DCHECK(!inputs.empty()); DCHECK(outputs != nullptr); DCHECK(outputs->empty()); @@ -194,6 +195,7 @@ Status BuildWhileLoop(const Scope& scope, const std::vector& inputs, Output cond_out; TF_RETURN_IF_ERROR(CreateCond(scope, cond, merge_outputs, &cond_out)); + if (cond_output != nullptr) *cond_output = cond_out; std::vector switch_trues(num_loop_vars); std::vector switch_falses(num_loop_vars); @@ -226,7 +228,22 @@ Status BuildWhileLoop(const Scope& scope, const std::vector& inputs, for (int i = 0; i < num_loop_vars; ++i) { (*outputs)[i] = internal::Exit(scope, switch_falses[i]); } - return scope.status(); + TF_RETURN_IF_ERROR(scope.status()); + + if (create_while_ctx) { + WhileContext* while_ctx; + TF_RETURN_IF_ERROR(scope.graph()->AddWhileContext( + frame_name, ToNodes(enter_outputs), ToNodes(*outputs), + ToOutputTensor(cond_out), ToOutputTensors(switch_trues), + ToOutputTensors(body_outputs), &while_ctx)); + + // Set while_ctx for all exit nodes. We currently don't require knowing the + // while_ctx for any other nodes. + for (int i = 0; i < num_loop_vars; ++i) { + (*outputs)[i].node()->set_while_ctx(while_ctx); + } + } + return Status::OK(); } } // namespace ops diff --git a/tensorflow/cc/ops/while_loop.h b/tensorflow/cc/ops/while_loop.h index 253d5d8935c..82181516d6d 100644 --- a/tensorflow/cc/ops/while_loop.h +++ b/tensorflow/cc/ops/while_loop.h @@ -48,6 +48,10 @@ typedef std::function& inputs, // unique name. This will be used as a prefix for created operations. // * outputs: output param that returns final loop variable outputs in non-error // case. Must be non-null and empty. +// * create_while_ctx: if true, a WhileContext is created and populated for this +// loop. See core/graph/while_context.h for more details. +// * cond_output: if non-null, the output of the predicate is returned. This +// will always be a LoopCond node. // // Returns an error if the while loop could not be fully constructed. // @@ -56,7 +60,8 @@ typedef std::function& inputs, Status BuildWhileLoop(const Scope& scope, const std::vector& inputs, const CondGraphBuilderFn& cond, const BodyGraphBuilderFn& body, const string& frame_name, - OutputList* outputs); + OutputList* outputs, bool create_while_ctx = true, + Output* cond_output = nullptr); } // namespace ops } // namespace tensorflow diff --git a/tensorflow/cc/ops/while_loop_test.cc b/tensorflow/cc/ops/while_loop_test.cc index 77028b5c41d..e3f6523c190 100644 --- a/tensorflow/cc/ops/while_loop_test.cc +++ b/tensorflow/cc/ops/while_loop_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/cc/client/client_session.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/graph/while_context.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" @@ -38,8 +39,8 @@ class WhileLoopTest : public ::testing::Test { const ops::BodyGraphBuilderFn& body, error::Code error_code = error::OK, const string& error_msg = "") { - Status s = ops::BuildWhileLoop(scope_, inputs_, cond, body, "test_loop", - &outputs_); + Status s = + ops::BuildWhileLoop(scope_, inputs_, cond, body, kFrameName, &outputs_); EXPECT_EQ(s.code(), error_code); EXPECT_EQ(s.error_message(), error_msg); } @@ -69,8 +70,12 @@ class WhileLoopTest : public ::testing::Test { Scope scope_; std::vector inputs_; std::vector outputs_; + + static const char* const kFrameName; }; +const char* const WhileLoopTest::kFrameName = "test_loop"; + Status LessThanTenCond(const Scope& s, const std::vector& inputs, Output* output) { *output = ops::Less(s, inputs[0], 10); @@ -87,6 +92,23 @@ TEST_F(WhileLoopTest, Basic) { // Create loop: while (i < 10) i += 1 Init(1); CreateLoop(LessThanTenCond, AddOneBody); + + // Verify some output invariants + WhileContext* while_ctx; + for (int i = 0; i < outputs_.size(); ++i) { + Node* node = outputs_[i].node(); + ASSERT_TRUE(node->IsExit()) << "Output node " << i << ":\n" + << node->DebugString(); + ASSERT_TRUE(node->while_ctx() != nullptr) << i; + if (i == 0) { + while_ctx = node->while_ctx(); + EXPECT_EQ(while_ctx->frame_name(), kFrameName); + } else { + EXPECT_EQ(node->while_ctx(), while_ctx) << i; + } + } + + // Run the loop and test we get the expected results Run({1}, {10}); Run({11}, {11}); } diff --git a/tensorflow/contrib/cmake/tf_core_cpu.cmake b/tensorflow/contrib/cmake/tf_core_cpu.cmake index c76f124892c..5c01ca382fb 100644 --- a/tensorflow/contrib/cmake/tf_core_cpu.cmake +++ b/tensorflow/contrib/cmake/tf_core_cpu.cmake @@ -50,6 +50,8 @@ file(GLOB_RECURSE tf_core_cpu_exclude_srcs "${tensorflow_source_dir}/tensorflow/core/graph/edgeset.cc" "${tensorflow_source_dir}/tensorflow/core/graph/graph.h" "${tensorflow_source_dir}/tensorflow/core/graph/graph.cc" + "${tensorflow_source_dir}/tensorflow/core/graph/while_context.h" + "${tensorflow_source_dir}/tensorflow/core/graph/while_context.cc" "${tensorflow_source_dir}/tensorflow/core/grappler/clusters/single_machine.h" "${tensorflow_source_dir}/tensorflow/core/grappler/clusters/single_machine.cc" "${tensorflow_source_dir}/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h" diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake index f7470d3bce8..53d64133102 100644 --- a/tensorflow/contrib/cmake/tf_core_framework.cmake +++ b/tensorflow/contrib/cmake/tf_core_framework.cmake @@ -241,6 +241,8 @@ file(GLOB_RECURSE tf_core_framework_srcs "${tensorflow_source_dir}/tensorflow/core/graph/edgeset.cc" "${tensorflow_source_dir}/tensorflow/core/graph/graph.h" "${tensorflow_source_dir}/tensorflow/core/graph/graph.cc" + "${tensorflow_source_dir}/tensorflow/core/graph/while_context.h" + "${tensorflow_source_dir}/tensorflow/core/graph/while_context.cc" "${tensorflow_source_dir}/tensorflow/core/util/*.h" "${tensorflow_source_dir}/tensorflow/core/util/*.cc" "${tensorflow_source_dir}/tensorflow/core/common_runtime/session.cc" diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 52fe59a03e1..87cb212ad0f 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -721,6 +721,7 @@ tf_cuda_library( "graph/graph_def_builder.h", "graph/node_builder.h", "graph/validate.h", + "graph/while_context.h", "public/session.h", "public/session_options.h", ], @@ -1581,6 +1582,8 @@ tf_cuda_library( "graph/edgeset.cc", "graph/graph.h", "graph/graph.cc", + "graph/while_context.h", + "graph/while_context.cc", ], exclude = [ "**/*test*", @@ -1717,6 +1720,7 @@ CORE_CPU_BASE_HDRS = [ "graph/testlib.h", "graph/types.h", "graph/validate.h", + "graph/while_context.h", ] tf_cuda_library( diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc index a274c799704..599f802ee06 100644 --- a/tensorflow/core/graph/graph.cc +++ b/tensorflow/core/graph/graph.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/versions.pb.h" +#include "tensorflow/core/graph/while_context.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -110,7 +111,8 @@ Node::Node() cost_id_(-1), class_(NC_UNINITIALIZED), props_(nullptr), - assigned_device_name_index_(0) {} + assigned_device_name_index_(0), + while_ctx_(nullptr) {} void Node::Initialize(int id, int cost_id, std::shared_ptr props) { @@ -582,6 +584,27 @@ int Graph::InternDeviceName(const string& device_name) { return index; } +Status Graph::AddWhileContext(StringPiece frame_name, + std::vector enter_nodes, + std::vector exit_nodes, + OutputTensor cond_output, + std::vector body_inputs, + std::vector body_outputs, + WhileContext** result) { + auto pair = while_ctxs_.insert(std::pair( + frame_name.ToString(), + WhileContext(frame_name, std::move(enter_nodes), std::move(exit_nodes), + cond_output, std::move(body_inputs), + std::move(body_outputs)))); + if (!pair.second) { + *result = nullptr; + return errors::InvalidArgument("WhileContext with frame name '", frame_name, + "' already exists"); + } + *result = &pair.first->second; + return Status::OK(); +} + string Edge::DebugString() const { return strings::Printf("[id=%d %s:%d -> %s:%d]", id_, src_->name().c_str(), src_output_, dst_->name().c_str(), dst_input_); diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h index f825675392a..3aee6f21df6 100644 --- a/tensorflow/core/graph/graph.h +++ b/tensorflow/core/graph/graph.h @@ -60,6 +60,7 @@ class Graph; class GraphDef; class Node; class VersionDef; +class WhileContext; class NeighborIter; // Declared below class NodeIter; // Declared below @@ -182,6 +183,13 @@ class Node { Status input_node(int idx, const Node** n) const; Status input_node(int idx, Node** n) const; + WhileContext* while_ctx() const { return while_ctx_; } + void set_while_ctx(WhileContext* while_ctx) { + DCHECK(IsExit()); + DCHECK(while_ctx_ == nullptr); + while_ctx_ = while_ctx; + } + private: friend class Graph; Node(); @@ -254,6 +262,13 @@ class Node { // field and reclaim that memory. Graph* graph_; + // Set if this is an exit node of a while loop with an associated + // WhileContext. Otherwise null. (This is only set for exit nodes because + // they're the first nodes of a loop encountered while creating the gradient + // graph. Exit nodes that are part of while loop gradient graphs will not have + // this set.) + WhileContext* while_ctx_; + TF_DISALLOW_COPY_AND_ASSIGN(Node); }; @@ -530,6 +545,16 @@ class Graph { // node->num_outputs() Status IsValidOutputTensor(const Node* node, int idx) const; + // Create and return a new WhileContext owned by this graph. This is called + // when a new while loop is created. `frame_name` must be unique among + // WhileContexts in this graph. + Status AddWhileContext(StringPiece frame_name, std::vector enter_nodes, + std::vector exit_nodes, + OutputTensor cond_output, + std::vector body_inputs, + std::vector body_outputs, + WhileContext** result); + // TODO(josh11b): uint64 hash() const; private: @@ -596,6 +621,12 @@ class Graph { // Maps unique device names to indices within device_names_[i]. std::unordered_map device_names_map_; + // All the while contexts owned by this graph, keyed by frame name, + // corresonding to all the while loops contained in this graph (including + // nested loops). The stored contexts are usually accessed via + // AddWhileContext() or Node::while_ctx(), but this manages the lifetime. + std::map while_ctxs_; + TF_DISALLOW_COPY_AND_ASSIGN(Graph); }; diff --git a/tensorflow/core/graph/while_context.cc b/tensorflow/core/graph/while_context.cc new file mode 100644 index 00000000000..10a2b67f378 --- /dev/null +++ b/tensorflow/core/graph/while_context.cc @@ -0,0 +1,38 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/graph/while_context.h" + +namespace tensorflow { + +WhileContext::WhileContext(StringPiece frame_name, + std::vector enter_nodes, + std::vector exit_nodes, + OutputTensor cond_output, + std::vector body_inputs, + std::vector body_outputs) + : frame_name_(frame_name.ToString()), + enter_nodes_(std::move(enter_nodes)), + exit_nodes_(std::move(exit_nodes)), + cond_output_(cond_output), + body_inputs_(std::move(body_inputs)), + body_outputs_(std::move(body_outputs)) { + const size_t num_loop_vars = enter_nodes_.size(); + DCHECK_EQ(exit_nodes_.size(), num_loop_vars); + DCHECK_EQ(body_inputs_.size(), num_loop_vars); + DCHECK_EQ(body_outputs_.size(), num_loop_vars); +} + +} // namespace tensorflow diff --git a/tensorflow/core/graph/while_context.h b/tensorflow/core/graph/while_context.h new file mode 100644 index 00000000000..5944e368979 --- /dev/null +++ b/tensorflow/core/graph/while_context.h @@ -0,0 +1,76 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_GRAPH_WHILE_CONTEXT_H_ +#define TENSORFLOW_GRAPH_WHILE_CONTEXT_H_ + +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { + +// Information about a while loop. Every user-defined while loop has an +// associated WhileContext, i.e., there is a WhileContext for every execution +// frame. Created with the while loop and used during gradient +// construction. Note that the gradient graph of while loop contains while loops +// itself, but these do not generate separate WhileContexts. +// +// TODO(skyewm): this is currently insufficient to handle nested loops and +// conditionals (and possibly other requirements). This may change a lot in the +// future to support these features. +// +// TODO(skyewm): de/serialize in MetaGraphDef so imported while loops will be +// differentiable. Figure out backwards compatability story. +class WhileContext { + public: + WhileContext(StringPiece frame_name, std::vector enter_nodes, + std::vector exit_nodes, OutputTensor cond_output, + std::vector body_inputs, + std::vector body_outputs); + + const string& frame_name() const { return frame_name_; } + const std::vector& enter_nodes() const { return enter_nodes_; } + const std::vector& exit_nodes() const { return exit_nodes_; } + const OutputTensor& cond_output() const { return cond_output_; } + const std::vector& body_inputs() const { return body_inputs_; } + const std::vector& body_outputs() const { + return body_outputs_; + } + + private: + // Each user-defined while loop defines a new execution frame, which is + // uniquely identified by its frame name. Frames are used by the executor to + // manage the iterations of a loop. See the FrameState comment in + // core/common_runtime/executor.cc for more details. + const string frame_name_; + + // The enter nodes defining the input loop variables to the while loop. This + // vector defines the order of the loop variables. + const std::vector enter_nodes_; + + // The exit nodes defining the outputs of the while loop. These are in loop + // variable order. + const std::vector exit_nodes_; + + // The boolean output of the loop predicate. + const OutputTensor cond_output_; + + // The inputs and outputs to the loop body. + const std::vector body_inputs_; + const std::vector body_outputs_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_GRAPH_GRAPH_H_