199 lines
8.0 KiB
C++
199 lines
8.0 KiB
C++
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
#include "tensorflow/cc/framework/while_gradients.h"
|
|
|
|
#include "tensorflow/cc/framework/gradients.h"
|
|
#include "tensorflow/cc/framework/scope_internal.h"
|
|
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
|
|
#include "tensorflow/cc/ops/standard_ops.h"
|
|
#include "tensorflow/cc/ops/while_loop.h"
|
|
|
|
namespace tensorflow {
|
|
namespace {
|
|
|
|
using ops::BodyGraphBuilderFn;
|
|
using ops::BuildWhileLoop;
|
|
using ops::CondGraphBuilderFn;
|
|
|
|
Output ToOutput(OutputTensor output_tensor) {
|
|
return Output(const_cast<Node*>(output_tensor.node), output_tensor.index);
|
|
}
|
|
|
|
std::vector<Output> ToOutputVector(
|
|
const std::vector<OutputTensor>& output_tensors) {
|
|
size_t n = output_tensors.size();
|
|
std::vector<Output> result;
|
|
result.reserve(n);
|
|
for (int i = 0; i < n; ++i) result.push_back(ToOutput(output_tensors[i]));
|
|
return result;
|
|
}
|
|
|
|
// The backprop loop counter and main backprop loop run in their own execution
|
|
// frame (conceptually, the main forward loop and forward loop counter run
|
|
// together in a frame, then the backprop loop counter and backprop loop run
|
|
// together in a different frame). This returns the frame name to use for the
|
|
// backprop while loops.
|
|
// TODO(skyewm): make sure this is unique among existing frame names
|
|
string BackPropFrameName(const string& forward_frame_name) {
|
|
return strings::StrCat(forward_frame_name, "_backprop");
|
|
}
|
|
|
|
// Creates a loop that counts the number of iterations performed by the
|
|
// while loop associated with `while_ctx`. The returned output yields the
|
|
// iteration count.
|
|
Status AddForwardLoopCounter(WhileContext* while_ctx, const Scope& scope,
|
|
Output* count) {
|
|
// Create while loop:
|
|
// i = 0
|
|
// while forward loop predicate is true:
|
|
// ++i
|
|
|
|
Output zero = ops::Const(scope, 0, {});
|
|
|
|
// Condition function that returns condition output from original while loop.
|
|
CondGraphBuilderFn cond_fn = [while_ctx](const Scope& scope,
|
|
const std::vector<Output>& inputs,
|
|
Output* output) {
|
|
*output = ToOutput(while_ctx->cond_output());
|
|
return Status::OK();
|
|
};
|
|
|
|
// Body function that adds one to input.
|
|
BodyGraphBuilderFn body_fn = [](const Scope& scope,
|
|
const std::vector<Output>& inputs,
|
|
std::vector<Output>* outputs) {
|
|
DCHECK_EQ(inputs.size(), 1);
|
|
outputs->emplace_back(ops::Add(scope, inputs[0], 1));
|
|
return scope.status();
|
|
};
|
|
|
|
// Note that this loop runs in the same execution frame as the forward loop.
|
|
std::vector<Output> outputs;
|
|
TF_RETURN_IF_ERROR(BuildWhileLoop(scope, {zero}, cond_fn, body_fn,
|
|
while_ctx->frame_name(), &outputs,
|
|
/* create_while_ctx */ false));
|
|
*count = outputs[0];
|
|
return Status::OK();
|
|
}
|
|
|
|
// Creates a loop that executes `loop_count` times. The returned output is the
|
|
// boolean predicate indicating if the loop is still executing. This is used to
|
|
// drive the gradient computation for the while loop associated with
|
|
// `while_ctx`.
|
|
Status AddBackPropLoopCounter(WhileContext* while_ctx, const Output& loop_count,
|
|
const Scope& scope,
|
|
Output* backprop_execution_pred) {
|
|
// Create while loop:
|
|
// n = loop_count
|
|
// while n > 0:
|
|
// --n
|
|
|
|
// Condition function that returns input > 0.
|
|
CondGraphBuilderFn cond_fn = [](const Scope& scope,
|
|
const std::vector<Output>& inputs,
|
|
Output* output) {
|
|
DCHECK_EQ(inputs.size(), 1);
|
|
*output = ops::Greater(scope, inputs[0], 0);
|
|
return scope.status();
|
|
};
|
|
|
|
// Body function that subtracts one from input.
|
|
BodyGraphBuilderFn body_fn = [](const Scope& scope,
|
|
const std::vector<Output>& inputs,
|
|
std::vector<Output>* outputs) {
|
|
DCHECK_EQ(inputs.size(), 1);
|
|
outputs->emplace_back(ops::Subtract(scope, inputs[0], 1));
|
|
return scope.status();
|
|
};
|
|
|
|
string frame_name = BackPropFrameName(while_ctx->frame_name());
|
|
std::vector<Output> outputs;
|
|
TF_RETURN_IF_ERROR(BuildWhileLoop(
|
|
scope, {loop_count}, cond_fn, body_fn, frame_name, &outputs,
|
|
/* create_while_ctx */ false, backprop_execution_pred));
|
|
return Status::OK();
|
|
}
|
|
|
|
// Creates the main backprop loop that computes the gradient of the loop
|
|
// associated with `while_ctx`. `grad_inputs` are the partial derivatives
|
|
// w.r.t. the loop outputs, i.e. the exit nodes. `backprop_execution_pred` is
|
|
// the predicate to use for the backprop loop (see AddBackPropLoopCounter()).
|
|
// The partial derivatives w.r.t. the loop inputs, i.e. the input loop vars, are
|
|
// returned in `grad_outputs`.
|
|
Status AddWhileGradientLoop(WhileContext* while_ctx,
|
|
const std::vector<Output>& grad_inputs,
|
|
const Output& backprop_execution_pred,
|
|
const Scope& parent_scope,
|
|
std::vector<Output>* grad_outputs) {
|
|
DCHECK_EQ(grad_inputs.size(), while_ctx->body_outputs().size());
|
|
DCHECK_EQ(while_ctx->body_inputs().size(), while_ctx->body_outputs().size());
|
|
|
|
Scope scope = parent_scope.NewSubScope("while");
|
|
|
|
// Create while loop:
|
|
// while backprop_execution_pred:
|
|
// forward loop body gradient
|
|
|
|
// Condition function that returns 'backprop_execution_pred'.
|
|
CondGraphBuilderFn cond_fn = [backprop_execution_pred](
|
|
const Scope& scope,
|
|
const std::vector<Output>& inputs,
|
|
Output* output) {
|
|
*output = backprop_execution_pred;
|
|
return Status::OK();
|
|
};
|
|
|
|
// Body function that builds while body gradient subgraph.
|
|
BodyGraphBuilderFn body_fn = [while_ctx](const Scope& scope,
|
|
const std::vector<Output>& inputs,
|
|
std::vector<Output>* outputs) {
|
|
std::vector<Output> body_outputs =
|
|
ToOutputVector(while_ctx->body_outputs());
|
|
std::vector<Output> body_inputs = ToOutputVector(while_ctx->body_inputs());
|
|
return AddSymbolicGradients(scope, body_outputs, body_inputs, inputs,
|
|
outputs);
|
|
};
|
|
|
|
string frame_name = BackPropFrameName(while_ctx->frame_name());
|
|
TF_RETURN_IF_ERROR(BuildWhileLoop(scope, grad_inputs, cond_fn, body_fn,
|
|
frame_name, grad_outputs,
|
|
/* create_while_ctx */ false));
|
|
return Status::OK();
|
|
}
|
|
|
|
} // namespace
|
|
|
|
Status AddWhileLoopGradient(WhileContext* while_ctx, const Scope& scope,
|
|
const std::vector<Output>& grad_inputs,
|
|
std::vector<Output>* grad_outputs) {
|
|
Output forward_loop_count;
|
|
TF_RETURN_IF_ERROR(AddForwardLoopCounter(
|
|
while_ctx, scope.NewSubScope("ForwardLoopCounter"), &forward_loop_count));
|
|
|
|
// TODO(skyewm): can we combine the backprop loop counter and main gradient
|
|
// loop into a single loop? The original Python code doesn't combine the
|
|
// loops, but I'm not sure why.
|
|
Output backprop_counter_cond;
|
|
TF_RETURN_IF_ERROR(AddBackPropLoopCounter(
|
|
while_ctx, forward_loop_count, scope.NewSubScope("BackPropLoopCounter"),
|
|
&backprop_counter_cond));
|
|
|
|
return AddWhileGradientLoop(while_ctx, grad_inputs, backprop_counter_cond,
|
|
scope, grad_outputs);
|
|
}
|
|
|
|
} // namespace tensorflow
|