Basic while loop gradient functionality in C++
This change introduces the basic framework to create the gradient graph of a while loop using the C++ API. This supports building the gradient graph as long as the body function of the while loop contains no ops whose gradient function requires a stack. In other words, it doesn't support gradient functions that use the input values to the op (e.g. add will work, but multiply will not). It also doesn't support nested while loops, and doesn't detect all error cases. PiperOrigin-RevId: 170243281
This commit is contained in:
parent
545e3572f7
commit
301b14c240
tensorflow
c
cc
contrib/cmake
core
@ -73,6 +73,11 @@ class CApiWhileLoopTest : public ::testing::Test {
|
||||
}
|
||||
|
||||
void Run(std::initializer_list<int> input_values) {
|
||||
Run(outputs_, input_values);
|
||||
}
|
||||
|
||||
void Run(const std::vector<TF_Output>& run_outputs,
|
||||
std::initializer_list<int> input_values) {
|
||||
DCHECK_EQ(inputs_.size(), input_values.size());
|
||||
std::vector<std::pair<TF_Operation*, TF_Tensor*>> inputs(inputs_.size());
|
||||
int i = 0;
|
||||
@ -82,7 +87,7 @@ class CApiWhileLoopTest : public ::testing::Test {
|
||||
}
|
||||
csession_.reset(new CSession(graph_, s_));
|
||||
csession_->SetInputs(inputs);
|
||||
csession_->SetOutputs(outputs_);
|
||||
csession_->SetOutputs(run_outputs);
|
||||
csession_->Run(s_);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
|
||||
}
|
||||
@ -402,4 +407,36 @@ TEST_F(CApiWhileLoopTest, BadTypes) {
|
||||
TF_AbortWhile(params_.get());
|
||||
}
|
||||
|
||||
// This is a basic test to make sure the C++ gradient code can handle while
|
||||
// loops created by the C API (which calls the C++ API under the hood). There
|
||||
// are more while loop gradient tests in cc/framework/while_gradients_test.cc.
|
||||
TEST_F(CApiWhileLoopTest, Gradients) {
|
||||
Init(1);
|
||||
|
||||
// Create loop: while (i < 10) i += 1
|
||||
TF_Operation* ten = ScalarConst(10, params_->cond_graph, s_);
|
||||
TF_Operation* less_than =
|
||||
LessThan(params_->cond_inputs[0], {ten, 0}, params_->cond_graph, s_);
|
||||
DCHECK_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
|
||||
params_->cond_output = {less_than, 0};
|
||||
|
||||
TF_Operation* one = ScalarConst(1, params_->body_graph, s_);
|
||||
TF_Operation* add =
|
||||
Add(params_->body_inputs[0], {one, 0}, params_->body_graph, s_);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
|
||||
params_->body_outputs[0] = {add, 0};
|
||||
|
||||
ExpectOK();
|
||||
|
||||
// Create backprop graph
|
||||
TF_Output grad_output;
|
||||
TF_AddGradients(graph_, outputs_.data(), outputs_.size(), inputs_.data(), 1,
|
||||
nullptr, s_, &grad_output);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
|
||||
|
||||
// Run gradient
|
||||
Run({grad_output}, {0});
|
||||
ExpectOutputValue(0, 1);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -19,13 +19,20 @@ load(
|
||||
|
||||
cc_library(
|
||||
name = "gradients",
|
||||
srcs = ["framework/gradients.cc"],
|
||||
srcs = [
|
||||
"framework/gradients.cc",
|
||||
"framework/while_gradients.cc",
|
||||
"framework/while_gradients.h",
|
||||
],
|
||||
hdrs = ["framework/gradients.h"],
|
||||
deps = [
|
||||
":cc_ops",
|
||||
":cc_ops_internal",
|
||||
":grad_op_registry",
|
||||
":ops",
|
||||
":scope",
|
||||
":scope_internal",
|
||||
":while_loop",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
@ -52,6 +59,28 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "framework_while_gradients_test",
|
||||
size = "small",
|
||||
srcs = ["framework/while_gradients_test.cc"],
|
||||
deps = [
|
||||
":cc_ops",
|
||||
":client_session",
|
||||
":grad_op_registry",
|
||||
":grad_ops",
|
||||
":gradients",
|
||||
":testutil",
|
||||
":while_loop",
|
||||
"//tensorflow/core:all_kernels",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gradient_checker",
|
||||
srcs = ["framework/gradient_checker.cc"],
|
||||
|
@ -16,8 +16,9 @@ limitations under the License.
|
||||
#include <deque>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/cc/framework/gradients.h"
|
||||
#include "tensorflow/cc/framework/grad_op_registry.h"
|
||||
#include "tensorflow/cc/framework/gradients.h"
|
||||
#include "tensorflow/cc/framework/while_gradients.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
@ -25,6 +26,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/graph/algorithm.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/graph/while_context.h"
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
|
||||
@ -82,6 +84,13 @@ class SymbolicGradientBuilder {
|
||||
// from outputs_. Keyed by node id.
|
||||
std::vector<bool> GetReachableNodes();
|
||||
|
||||
// Creates the gradient subgraph for a while loop (or just stores
|
||||
// `summed_grads` if not all incoming gradients are available yet). All exit
|
||||
// nodes (which are the first nodes of a loop encountered in the backwards
|
||||
// pass) are passed to this function rather than processed normally.
|
||||
// `summed_grads` is the sum of `exit_node`s gradients.
|
||||
Status ProcessWhileLoop(Node* exit_node, const Output& summed_grads);
|
||||
|
||||
const Scope& scope_;
|
||||
const ops::GradOpRegistry* registry_;
|
||||
const std::vector<Output>& outputs_;
|
||||
@ -89,8 +98,7 @@ class SymbolicGradientBuilder {
|
||||
const std::vector<Output>& grad_inputs_;
|
||||
std::vector<Output>* grad_outputs_;
|
||||
|
||||
// A vector of output endpoints which represents backpropagated
|
||||
// gradients
|
||||
// A vector of output endpoints which represents backpropagated gradients
|
||||
typedef std::vector<Output> BackpropedGradients;
|
||||
|
||||
// backprops_ is a map from a node output to its accumulated
|
||||
@ -117,6 +125,12 @@ class SymbolicGradientBuilder {
|
||||
// frontier. Maps from Output -> index into `grad_outputs_`.
|
||||
std::unordered_map<Output, int, OutputHash, OutputEq> input_nodes_;
|
||||
|
||||
// For each while loop in the graph, collects the summed gradients for each of
|
||||
// the loop's exit nodes. Note that unlike backprops_, this map contains the
|
||||
// output of SumGradients(), not the input (i.e. each exit node may have
|
||||
// multiple incoming gradients, but we only store the combined Output here).
|
||||
std::map<WhileContext*, std::map<Node*, Output>> while_backprops_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(SymbolicGradientBuilder);
|
||||
};
|
||||
|
||||
@ -150,6 +164,7 @@ Status SymbolicGradientBuilder::BackpropAlongEdge(const Output& dst_grad,
|
||||
std::vector<bool> SymbolicGradientBuilder::GetReachableNodes() {
|
||||
std::vector<bool> reachable_nodes(scope_.graph()->num_node_ids(), false);
|
||||
std::deque<Node*> queue;
|
||||
std::vector<bool> visited(scope_.graph()->num_node_ids(), false);
|
||||
for (const Output& out : outputs_) {
|
||||
if (!reachable_nodes[out.node()->id()]) {
|
||||
queue.push_back(out.node());
|
||||
@ -162,8 +177,10 @@ std::vector<bool> SymbolicGradientBuilder::GetReachableNodes() {
|
||||
queue.pop_front();
|
||||
for (const Edge* e : n->in_edges()) {
|
||||
if (e->IsControlEdge()) continue;
|
||||
if (visited[e->src()->id()]) continue;
|
||||
queue.push_back(e->src());
|
||||
reachable_nodes[e->src()->id()] = true;
|
||||
visited[e->src()->id()] = true;
|
||||
}
|
||||
}
|
||||
return reachable_nodes;
|
||||
@ -304,6 +321,53 @@ Status SymbolicGradientBuilder::CallGradFunction(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status SymbolicGradientBuilder::ProcessWhileLoop(Node* exit_node,
|
||||
const Output& summed_grads) {
|
||||
// TOOD(skyewm): detect second-order gradient and return bad status
|
||||
// TODO(skyewm): handle (or at least detect) nested while loops
|
||||
|
||||
// TODO(skyewm): handle NoGradient in while loop
|
||||
if (summed_grads == NoGradient()) {
|
||||
return errors::Unimplemented(
|
||||
"Missing gradient into while loop not yet implemented");
|
||||
}
|
||||
|
||||
DCHECK(exit_node->IsExit());
|
||||
WhileContext* while_ctx = exit_node->while_ctx();
|
||||
DCHECK(while_ctx != nullptr);
|
||||
|
||||
// Record 'summed_grads' as the backprop input associated with 'exit_node'
|
||||
std::map<Node*, Output>& backprops = while_backprops_[while_ctx];
|
||||
DCHECK(backprops.find(exit_node) == backprops.end());
|
||||
backprops[exit_node] = summed_grads;
|
||||
|
||||
// Wait until we have all exit nodes' backprops collected before processing
|
||||
// the while loop.
|
||||
// TODO(skyewm): what if not all the exit nodes are reachable?
|
||||
if (backprops.size() < while_ctx->exit_nodes().size()) return Status::OK();
|
||||
|
||||
// We've seen all the exit nodes for this loop and have collected all the
|
||||
// backprops. Create the gradient graph for the while loop.
|
||||
Scope while_scope =
|
||||
scope_.NewSubScope(strings::StrCat(while_ctx->frame_name(), "_grad"));
|
||||
std::vector<Output> dy;
|
||||
for (Node* n : while_ctx->exit_nodes()) dy.push_back(backprops[n]);
|
||||
std::vector<Output> dx;
|
||||
TF_RETURN_IF_ERROR(AddWhileLoopGradient(while_ctx, while_scope, dy, &dx));
|
||||
|
||||
// Backprop along the in edges to the while loop (i.e. the inputs to the enter
|
||||
// nodes)
|
||||
DCHECK_EQ(dx.size(), while_ctx->enter_nodes().size());
|
||||
for (int i = 0; i < dx.size(); ++i) {
|
||||
Node* enter_node = while_ctx->enter_nodes()[i];
|
||||
for (const Edge* e : enter_node->in_edges()) {
|
||||
if (e->IsControlEdge()) continue;
|
||||
TF_RETURN_IF_ERROR(BackpropAlongEdge(dx[i], {e->src(), e->src_output()}));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status SymbolicGradientBuilder::AddGradients() {
|
||||
// Initialize backprops.
|
||||
TF_RETURN_IF_ERROR(Initialize());
|
||||
@ -346,6 +410,18 @@ Status SymbolicGradientBuilder::AddGradients() {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Special case: if we find an exit node, process the associated while loop.
|
||||
// Note that ProcessWhileLoop() calls BackpropAlongEdge() if necessary
|
||||
// (which updates ready_), and we skip all the regular processing below
|
||||
// after calling it.
|
||||
if (n->IsExit()) {
|
||||
DCHECK_EQ(dy.size(), 1);
|
||||
TF_RETURN_IF_ERROR(ProcessWhileLoop(n, dy[0]));
|
||||
continue;
|
||||
}
|
||||
// All loop-specific control flow ops should have been handled above
|
||||
DCHECK(!n->IsEnter() && !n->IsNextIteration()) << n->DebugString();
|
||||
|
||||
const size_t num_no_grad = no_grad_dy_indices.size();
|
||||
if (IsPrimitiveOpWithNoGrad(n->type_string()) || num_no_grad == num_y) {
|
||||
// No grad defined for this op, or all outputs returned 'NoGradient':
|
||||
|
197
tensorflow/cc/framework/while_gradients.cc
Normal file
197
tensorflow/cc/framework/while_gradients.cc
Normal file
@ -0,0 +1,197 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/cc/framework/while_gradients.h"
|
||||
|
||||
#include "tensorflow/cc/framework/gradients.h"
|
||||
#include "tensorflow/cc/framework/scope_internal.h"
|
||||
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/cc/ops/while_loop.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
using ops::BodyGraphBuilderFn;
|
||||
using ops::BuildWhileLoop;
|
||||
using ops::CondGraphBuilderFn;
|
||||
|
||||
Output ToOutput(OutputTensor output_tensor) {
|
||||
return Output(const_cast<Node*>(output_tensor.node), output_tensor.index);
|
||||
}
|
||||
|
||||
std::vector<Output> ToOutputVector(
|
||||
const std::vector<OutputTensor>& output_tensors) {
|
||||
size_t n = output_tensors.size();
|
||||
std::vector<Output> result(n);
|
||||
for (int i = 0; i < n; ++i) result[i] = ToOutput(output_tensors[i]);
|
||||
return result;
|
||||
}
|
||||
|
||||
// The backprop loop counter and main backprop loop run in their own execution
|
||||
// frame (conceptually, the main forward loop and forward loop counter run
|
||||
// together in a frame, then the backprop loop counter and backprop loop run
|
||||
// together in a different frame). This returns the frame name to use for the
|
||||
// backprop while loops.
|
||||
// TODO(skyewm): make sure this is unique among existing frame names
|
||||
string BackPropFrameName(const string& forward_frame_name) {
|
||||
return strings::StrCat(forward_frame_name, "_backprop");
|
||||
}
|
||||
|
||||
// Creates a loop that counts the number of iterations performed by the
|
||||
// while loop associated with `while_ctx`. The returned output yields the
|
||||
// iteration count.
|
||||
Status AddForwardLoopCounter(WhileContext* while_ctx, const Scope& scope,
|
||||
Output* count) {
|
||||
// Create while loop:
|
||||
// i = 0
|
||||
// while forward loop predicate is true:
|
||||
// ++i
|
||||
|
||||
Output zero = ops::Const(scope, 0, {});
|
||||
|
||||
// Condition function that returns condition output from original while loop.
|
||||
CondGraphBuilderFn cond_fn = [while_ctx](const Scope& scope,
|
||||
const std::vector<Output>& inputs,
|
||||
Output* output) {
|
||||
*output = ToOutput(while_ctx->cond_output());
|
||||
return Status::OK();
|
||||
};
|
||||
|
||||
// Body function that adds one to input.
|
||||
BodyGraphBuilderFn body_fn = [while_ctx](const Scope& scope,
|
||||
const std::vector<Output>& inputs,
|
||||
std::vector<Output>* outputs) {
|
||||
DCHECK_EQ(inputs.size(), 1);
|
||||
outputs->emplace_back(ops::Add(scope, inputs[0], 1));
|
||||
return scope.status();
|
||||
};
|
||||
|
||||
// Note that this loop runs in the same execution frame as the forward loop.
|
||||
std::vector<Output> outputs;
|
||||
TF_RETURN_IF_ERROR(BuildWhileLoop(scope, {zero}, cond_fn, body_fn,
|
||||
while_ctx->frame_name(), &outputs,
|
||||
/* create_while_ctx */ false));
|
||||
*count = outputs[0];
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Creates a loop that executes `loop_count` times. The returned output is the
|
||||
// boolean predicate indicating if the loop is still executing. This is used to
|
||||
// drive the gradient computation for the while loop associated with
|
||||
// `while_ctx`.
|
||||
Status AddBackPropLoopCounter(WhileContext* while_ctx, const Output& loop_count,
|
||||
const Scope& scope,
|
||||
Output* backprop_execution_pred) {
|
||||
// Create while loop:
|
||||
// n = loop_count
|
||||
// while n > 0:
|
||||
// --n
|
||||
|
||||
// Condition function that returns input > 0.
|
||||
CondGraphBuilderFn cond_fn = [](const Scope& scope,
|
||||
const std::vector<Output>& inputs,
|
||||
Output* output) {
|
||||
DCHECK_EQ(inputs.size(), 1);
|
||||
*output = ops::Greater(scope, inputs[0], 0);
|
||||
return scope.status();
|
||||
};
|
||||
|
||||
// Body function that subtracts one from input.
|
||||
BodyGraphBuilderFn body_fn = [](const Scope& scope,
|
||||
const std::vector<Output>& inputs,
|
||||
std::vector<Output>* outputs) {
|
||||
DCHECK_EQ(inputs.size(), 1);
|
||||
outputs->emplace_back(ops::Subtract(scope, inputs[0], 1));
|
||||
return scope.status();
|
||||
};
|
||||
|
||||
string frame_name = BackPropFrameName(while_ctx->frame_name());
|
||||
std::vector<Output> outputs; // unused
|
||||
TF_RETURN_IF_ERROR(BuildWhileLoop(
|
||||
scope, {loop_count}, cond_fn, body_fn, frame_name, &outputs,
|
||||
/* create_while_ctx */ false, backprop_execution_pred));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Creates the main backprop loop that computes the gradient of the loop
|
||||
// associated with `while_ctx`. `grad_inputs` are the partial derivatives
|
||||
// w.r.t. the loop outputs, i.e. the exit nodes. `backprop_execution_pred` is
|
||||
// the predicate to use for the backprop loop (see AddBackPropLoopCounter()).
|
||||
// The partial derivatives w.r.t. the loop inputs, i.e. the input loop vars, are
|
||||
// returned in `grad_outputs`.
|
||||
Status AddWhileGradientLoop(WhileContext* while_ctx,
|
||||
const std::vector<Output>& grad_inputs,
|
||||
const Output& backprop_execution_pred,
|
||||
const Scope& parent_scope,
|
||||
std::vector<Output>* grad_outputs) {
|
||||
DCHECK_EQ(grad_inputs.size(), while_ctx->body_outputs().size());
|
||||
DCHECK_EQ(while_ctx->body_inputs().size(), while_ctx->body_outputs().size());
|
||||
|
||||
Scope scope = parent_scope.NewSubScope("while");
|
||||
|
||||
// Create while loop:
|
||||
// while backprop_execution_pred:
|
||||
// forward loop body gradient
|
||||
|
||||
// Condition function that returns 'backprop_execution_pred'.
|
||||
CondGraphBuilderFn cond_fn = [backprop_execution_pred](
|
||||
const Scope& scope,
|
||||
const std::vector<Output>& inputs,
|
||||
Output* output) {
|
||||
*output = backprop_execution_pred;
|
||||
return Status::OK();
|
||||
};
|
||||
|
||||
// Body function that builds while body gradient subgraph.
|
||||
BodyGraphBuilderFn body_fn = [while_ctx](const Scope& scope,
|
||||
const std::vector<Output>& inputs,
|
||||
std::vector<Output>* outputs) {
|
||||
std::vector<Output> body_outputs =
|
||||
ToOutputVector(while_ctx->body_outputs());
|
||||
std::vector<Output> body_inputs = ToOutputVector(while_ctx->body_inputs());
|
||||
return AddSymbolicGradients(scope, body_outputs, body_inputs, inputs,
|
||||
outputs);
|
||||
};
|
||||
|
||||
string frame_name = BackPropFrameName(while_ctx->frame_name());
|
||||
TF_RETURN_IF_ERROR(BuildWhileLoop(scope, grad_inputs, cond_fn, body_fn,
|
||||
frame_name, grad_outputs,
|
||||
/* create_while_ctx */ false));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status AddWhileLoopGradient(WhileContext* while_ctx, const Scope& scope,
|
||||
const std::vector<Output>& grad_inputs,
|
||||
std::vector<Output>* grad_outputs) {
|
||||
Output forward_loop_count;
|
||||
TF_RETURN_IF_ERROR(AddForwardLoopCounter(
|
||||
while_ctx, scope.NewSubScope("ForwardLoopCounter"), &forward_loop_count));
|
||||
|
||||
// TODO(skyewm): can we combine the backprop loop counter and main gradient
|
||||
// loop into a single loop? The original Python code doesn't combine the
|
||||
// loops, but I'm not sure why.
|
||||
Output backprop_counter_cond;
|
||||
TF_RETURN_IF_ERROR(AddBackPropLoopCounter(
|
||||
while_ctx, forward_loop_count, scope.NewSubScope("BackPropLoopCounter"),
|
||||
&backprop_counter_cond));
|
||||
|
||||
return AddWhileGradientLoop(while_ctx, grad_inputs, backprop_counter_cond,
|
||||
scope, grad_outputs);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
40
tensorflow/cc/framework/while_gradients.h
Normal file
40
tensorflow/cc/framework/while_gradients.h
Normal file
@ -0,0 +1,40 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_WHILE_GRADIENTS_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_WHILE_GRADIENTS_H_
|
||||
|
||||
#include "tensorflow/cc/framework/ops.h"
|
||||
#include "tensorflow/cc/framework/scope.h"
|
||||
#include "tensorflow/core/graph/while_context.h"
|
||||
|
||||
// Utility functions for constructing while loop gradients
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Adds the gradient computation for the while loop associated with
|
||||
// `while_ctx`. `grad_inputs` are the partial derivatives w.r.t. the loop
|
||||
// outputs, i.e. the exit nodes. The partial derivatives w.r.t. the loop
|
||||
// inputs, i.e. the input loop vars, are returned in `grad_outputs`.
|
||||
// `grad_inputs` and `grad_outputs` are both in loop-variable order, as defined
|
||||
// by the original inputs to BuildWhileLoop().
|
||||
// TODO(skyewm): maybe comment on NoGradient once it's supported
|
||||
Status AddWhileLoopGradient(WhileContext* while_ctx, const Scope& scope,
|
||||
const std::vector<Output>& grad_inputs,
|
||||
std::vector<Output>* grad_outputs);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_WHILE_GRADIENTS_H_
|
233
tensorflow/cc/framework/while_gradients_test.cc
Normal file
233
tensorflow/cc/framework/while_gradients_test.cc
Normal file
@ -0,0 +1,233 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/cc/client/client_session.h"
|
||||
#include "tensorflow/cc/framework/gradients.h"
|
||||
#include "tensorflow/cc/framework/testutil.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/cc/ops/while_loop.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
class WhileGradientsTest : public ::testing::Test {
|
||||
protected:
|
||||
WhileGradientsTest() : scope_(Scope::NewRootScope()) {}
|
||||
|
||||
void Init(int num_inputs, DataType dtype = DT_INT32) {
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
inputs_.push_back(ops::Placeholder(scope_, dtype));
|
||||
}
|
||||
}
|
||||
|
||||
void CreateLoop(const ops::CondGraphBuilderFn& cond,
|
||||
const ops::BodyGraphBuilderFn& body,
|
||||
const std::vector<Output>* inputs = nullptr) {
|
||||
if (inputs == nullptr) inputs = &inputs_;
|
||||
TF_ASSERT_OK(ops::BuildWhileLoop(scope_, *inputs, cond, body, "test_loop",
|
||||
&outputs_));
|
||||
}
|
||||
|
||||
void CreateBackprop() {
|
||||
TF_ASSERT_OK(
|
||||
AddSymbolicGradients(scope_, outputs_, inputs_, &grad_outputs_));
|
||||
ASSERT_EQ(grad_outputs_.size(), inputs_.size());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void Run(const std::vector<Input::Initializer>& input_values,
|
||||
const std::vector<T>& expected_grad_values) {
|
||||
Run<T>(ClientSession(scope_), input_values, expected_grad_values);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void Run(const ClientSession& session,
|
||||
const std::vector<Input::Initializer>& input_values,
|
||||
const std::vector<T>& expected_grad_values,
|
||||
const RunOptions& run_options = RunOptions(),
|
||||
RunMetadata* run_metadata = nullptr) {
|
||||
DCHECK_EQ(input_values.size(), inputs_.size());
|
||||
ClientSession::FeedType feeds;
|
||||
for (int i = 0; i < inputs_.size(); ++i) {
|
||||
feeds.emplace(inputs_[i], input_values[i]);
|
||||
}
|
||||
|
||||
std::vector<Operation> run_outputs;
|
||||
std::vector<Tensor> out_tensors;
|
||||
TF_ASSERT_OK(session.Run(run_options, feeds, grad_outputs_, run_outputs,
|
||||
&out_tensors, run_metadata));
|
||||
ASSERT_EQ(out_tensors.size(), grad_outputs_.size());
|
||||
|
||||
DCHECK_EQ(expected_grad_values.size(), out_tensors.size());
|
||||
for (int i = 0; i < out_tensors.size(); ++i) {
|
||||
test::ExpectTensorEqual<T>(
|
||||
out_tensors[i], test::AsTensor<T>({expected_grad_values[i]}, {}));
|
||||
}
|
||||
}
|
||||
|
||||
Scope scope_;
|
||||
std::vector<Output> inputs_;
|
||||
std::vector<Output> outputs_;
|
||||
std::vector<Output> grad_outputs_;
|
||||
};
|
||||
|
||||
TEST_F(WhileGradientsTest, Basic) {
|
||||
// Create loop: while (i < 10) i += 1
|
||||
Init(1);
|
||||
CreateLoop(
|
||||
[](const Scope& s, const std::vector<Output>& inputs, Output* output) {
|
||||
*output = ops::Less(s, inputs[0], 10);
|
||||
return s.status();
|
||||
},
|
||||
[](const Scope& s, const std::vector<Output>& inputs,
|
||||
std::vector<Output>* outputs) {
|
||||
// Use AddN, rather than Add, because the gradient function doesn't
|
||||
// depend on the input shapes, and thus we do not need to store
|
||||
// intermediate values in a stack.
|
||||
outputs->push_back(ops::AddN(s, {inputs[0], 1}));
|
||||
return s.status();
|
||||
});
|
||||
CreateBackprop();
|
||||
|
||||
Run<int>({1}, {1});
|
||||
Run<int>({11}, {1});
|
||||
}
|
||||
|
||||
TEST_F(WhileGradientsTest, MultipleLoopVars) {
|
||||
// Create loop: while (i < 10) i += j; j += 1; k = k
|
||||
Init(3);
|
||||
CreateLoop(
|
||||
[](const Scope& s, const std::vector<Output>& inputs, Output* output) {
|
||||
*output = ops::Less(s, inputs[0], 10);
|
||||
return s.status();
|
||||
},
|
||||
[](const Scope& s, const std::vector<Output>& inputs,
|
||||
std::vector<Output>* outputs) {
|
||||
outputs->push_back(ops::AddN(s, {inputs[0], inputs[1]}));
|
||||
outputs->push_back(ops::AddN(s, {inputs[1], 1}));
|
||||
outputs->push_back(inputs[2]);
|
||||
return s.status();
|
||||
});
|
||||
CreateBackprop();
|
||||
|
||||
// The following execution traces illustrate why we expect dF/dj to be 5:
|
||||
//
|
||||
// i j k
|
||||
// ---------
|
||||
// 0 1 2 <-- initial values
|
||||
// 1 2 2
|
||||
// 3 3 2
|
||||
// 6 4 2
|
||||
// 10 5 2 <-- while output values
|
||||
// outputs sum = 17
|
||||
//
|
||||
// i j k
|
||||
// ---------
|
||||
// 0 2 2 <-- initial values (add 1 to j)
|
||||
// 2 3 2
|
||||
// 5 4 2
|
||||
// 9 5 2
|
||||
// 14 6 2 <-- while output values
|
||||
// outputs sum = 22
|
||||
//
|
||||
// Calculate the "slope" between j=1 and j=2:
|
||||
// 22 - 17 = 5 => dF/dj = 5
|
||||
Run<int>({0, 1, 2}, {1, 5, 1});
|
||||
|
||||
Run<int>({1, 1, 0}, {1, 5, 1});
|
||||
Run<int>({0, 0, 0}, {1, 6, 1});
|
||||
}
|
||||
|
||||
TEST_F(WhileGradientsTest, Chaining) {
|
||||
Init(2, DT_DOUBLE);
|
||||
|
||||
// Multiply each input by 2 before passing to while loop to make sure chaining
|
||||
// works properly
|
||||
std::vector<Output> loop_inputs = {ops::Multiply(scope_, inputs_[0], 2.0),
|
||||
ops::Multiply(scope_, inputs_[1], 2.0)};
|
||||
|
||||
// Create loop: while (i > 0 && j > 0) i -= 1
|
||||
CreateLoop(
|
||||
[](const Scope& s, const std::vector<Output>& inputs, Output* output) {
|
||||
*output = ops::LogicalAnd(s, ops::Greater(s, inputs[0], 0.0),
|
||||
ops::Greater(s, inputs[1], 0.0));
|
||||
return s.status();
|
||||
},
|
||||
[](const Scope& s, const std::vector<Output>& inputs,
|
||||
std::vector<Output>* outputs) {
|
||||
outputs->push_back(ops::AddN(s, {inputs[0], -1.0}));
|
||||
outputs->push_back(inputs[1]);
|
||||
return s.status();
|
||||
},
|
||||
&loop_inputs);
|
||||
|
||||
// Take negative of first output to make sure chaining works properly
|
||||
outputs_[0] = ops::Neg(scope_, outputs_[0]);
|
||||
|
||||
CreateBackprop();
|
||||
|
||||
Run<double>({1.0, 1.0}, {-2.0, 2.0});
|
||||
Run<double>({0.0, 0.0}, {-2.0, 2.0});
|
||||
}
|
||||
|
||||
TEST_F(WhileGradientsTest, MultipleDevices) {
|
||||
// Make sure loop is created on cpu0
|
||||
scope_ = scope_.WithDevice("/cpu:0");
|
||||
|
||||
// Create loop: while (i < 10) i += j
|
||||
Init(2);
|
||||
CreateLoop(
|
||||
[](const Scope& s, const std::vector<Output>& inputs, Output* output) {
|
||||
*output = ops::Less(s, inputs[0], 10);
|
||||
return s.status();
|
||||
},
|
||||
[](const Scope& s, const std::vector<Output>& inputs,
|
||||
std::vector<Output>* outputs) {
|
||||
// Place body on cpu1
|
||||
Scope cpu1_scope = s.WithDevice("/cpu:1");
|
||||
outputs->push_back(ops::AddN(cpu1_scope, {inputs[0], inputs[1]}));
|
||||
outputs->push_back(inputs[1]);
|
||||
return cpu1_scope.status();
|
||||
});
|
||||
|
||||
// Build gradient graph on cpu1
|
||||
Scope cpu1_scope = scope_.WithDevice("/cpu:1");
|
||||
TF_ASSERT_OK(
|
||||
AddSymbolicGradients(cpu1_scope, outputs_, inputs_, &grad_outputs_));
|
||||
ASSERT_EQ(grad_outputs_.size(), inputs_.size());
|
||||
|
||||
// Run with two CPU devices and output partition graphs
|
||||
SessionOptions session_options;
|
||||
(*session_options.config.mutable_device_count())["CPU"] = 2;
|
||||
RunOptions run_options;
|
||||
run_options.set_output_partition_graphs(true);
|
||||
RunMetadata run_metadata;
|
||||
Run<int>(ClientSession(scope_, session_options), {0, 1}, {1, 11}, run_options,
|
||||
&run_metadata);
|
||||
|
||||
// Check that at least one node ran on each device
|
||||
ASSERT_EQ(run_metadata.partition_graphs().size(), 2);
|
||||
for (const GraphDef& partition_graph : run_metadata.partition_graphs()) {
|
||||
EXPECT_GE(partition_graph.node().size(), 1);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -49,7 +49,12 @@ typedef std::function<Status(const Scope&, const std::vector<Output>& inputs,
|
||||
// * outputs: output param that returns final loop variable outputs in non-error
|
||||
// case. Must be non-null and empty.
|
||||
// * create_while_ctx: if true, a WhileContext is created and populated for this
|
||||
// loop. See core/graph/while_context.h for more details.
|
||||
// loop. See core/graph/while_context.h for more details on
|
||||
// WhileContexts. This is set to false for loops used as part of gradient
|
||||
// computations, since they're part of the gradient for a loop in the
|
||||
// forward-pass.
|
||||
// TODO(skyewm): revisit this. Should we create WhileContexts for all loops,
|
||||
// even if we don't need them?
|
||||
// * cond_output: if non-null, the output of the predicate is returned. This
|
||||
// will always be a LoopCond node.
|
||||
//
|
||||
|
@ -135,6 +135,8 @@ set(tf_cc_srcs
|
||||
"${tensorflow_source_dir}/tensorflow/cc/framework/gradient_checker.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/cc/framework/gradients.h"
|
||||
"${tensorflow_source_dir}/tensorflow/cc/framework/gradients.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/cc/framework/while_gradients.h"
|
||||
"${tensorflow_source_dir}/tensorflow/cc/framework/while_gradients.cc"
|
||||
)
|
||||
|
||||
file(GLOB_RECURSE tf_cc_test_srcs
|
||||
|
@ -2613,6 +2613,7 @@ tf_cc_tests(
|
||||
"//tensorflow/cc:cc_ops_internal",
|
||||
"//tensorflow/cc:scope",
|
||||
"//tensorflow/cc:sendrecv_ops",
|
||||
"//tensorflow/cc:while_loop",
|
||||
"//tensorflow/core/kernels:ops_util",
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
||||
#include "tensorflow/cc/ops/math_ops.h"
|
||||
#include "tensorflow/cc/ops/random_ops.h"
|
||||
#include "tensorflow/cc/ops/sendrecv_ops.h"
|
||||
#include "tensorflow/cc/ops/while_loop.h"
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/function_testlib.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
@ -72,10 +73,13 @@ void Partition(const GraphDef& graph_def,
|
||||
GraphConstructorOptions opts;
|
||||
TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &g));
|
||||
|
||||
// Assigns devices to each node. Uses 1st letter of the node name as
|
||||
// the device index.
|
||||
// Assigns devices to each node. Uses 1st letter of the node name as the
|
||||
// device index if no device is specified.
|
||||
for (Node* node : g.nodes()) {
|
||||
node->set_assigned_device_name(DeviceName(node));
|
||||
string device_name = !node->requested_device().empty()
|
||||
? node->requested_device()
|
||||
: DeviceName(node);
|
||||
node->set_assigned_device_name(device_name);
|
||||
}
|
||||
|
||||
PartitionOptions popts;
|
||||
@ -368,7 +372,7 @@ TEST_F(GraphPartitionTest, CrossDevice_DataControl) {
|
||||
ExpectMatchB();
|
||||
}
|
||||
|
||||
TEST_F(GraphPartitionTest, CrossDeviceLoop) {
|
||||
TEST_F(GraphPartitionTest, CrossDeviceLoopSimple) {
|
||||
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
|
||||
auto a1 = BoolInput(in_.WithOpName("A1"));
|
||||
auto a2 = ::tensorflow::ops::internal::Enter(in_.WithOpName("A2"), a1, "foo");
|
||||
@ -382,7 +386,7 @@ TEST_F(GraphPartitionTest, CrossDeviceLoop) {
|
||||
CheckLoopConstruction(ToGraphDef());
|
||||
}
|
||||
|
||||
TEST_F(GraphPartitionTest, CrossDeviceLoop1) {
|
||||
TEST_F(GraphPartitionTest, CrossDeviceLoopSimple1) {
|
||||
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
|
||||
auto a1 = BoolInput(in_.WithOpName("A1"));
|
||||
auto a2 = ::tensorflow::ops::internal::Enter(in_.WithOpName("B2"), a1, "foo");
|
||||
@ -407,6 +411,29 @@ TEST_F(GraphPartitionTest, CrossDeviceLoop1) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(GraphPartitionTest, CrossDeviceLoopFull) {
|
||||
Scope cpu0 = in_.WithDevice("/job:a/replica:0/task:0/cpu:0");
|
||||
auto p1 = ops::Placeholder(cpu0, DT_INT32);
|
||||
auto p2 = ops::Placeholder(cpu0, DT_INT32);
|
||||
OutputList outputs;
|
||||
// while i1 < 10: i1 += i2
|
||||
TF_ASSERT_OK(ops::BuildWhileLoop(
|
||||
cpu0, {p1, p2},
|
||||
[](const Scope& s, const std::vector<Output>& inputs, Output* output) {
|
||||
*output = ops::Less(s, inputs[0], 10);
|
||||
return s.status();
|
||||
},
|
||||
[](const Scope& s, const std::vector<Output>& inputs,
|
||||
std::vector<Output>* outputs) {
|
||||
Scope cpu1 = s.WithDevice("/job:a/replica:0/task:0/cpu:1");
|
||||
outputs->push_back(ops::AddN(cpu1, {inputs[0], inputs[1]}));
|
||||
outputs->push_back(inputs[1]);
|
||||
return s.status();
|
||||
},
|
||||
"test_loop", &outputs));
|
||||
CheckLoopConstruction(ToGraphDef());
|
||||
}
|
||||
|
||||
TEST_F(GraphPartitionTest, PartitionIncompleteGraph) {
|
||||
NodeDef ndef;
|
||||
Graph g(OpRegistry::Global());
|
||||
|
Loading…
Reference in New Issue
Block a user