Removes void*s from the tape gradient code, replacing with templates.
PiperOrigin-RevId: 175155685
This commit is contained in:
parent
18d5c3e4cf
commit
efab2e1d91
tensorflow
@ -106,7 +106,6 @@ tf_cc_test(
|
||||
|
||||
cc_library(
|
||||
name = "tape",
|
||||
srcs = ["tape.cc"],
|
||||
hdrs = ["tape.h"],
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
|
@ -1,410 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <unordered_set>
|
||||
|
||||
#include "tensorflow/c/eager/tape.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace eager {
|
||||
|
||||
bool GradientTape::ShouldRecord(gtl::ArraySlice<int64> tensor_ids) {
|
||||
for (int64 i : tensor_ids) {
|
||||
if (tensor_tape_.find(i) != tensor_tape_.end()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void GradientTape::Watch(int64 tensor_id) {
|
||||
tensor_tape_.emplace(tensor_id, -1);
|
||||
}
|
||||
|
||||
void GradientTape::RecordOperation(
|
||||
const string& op_type, gtl::ArraySlice<TapeTensor> output_tensors,
|
||||
gtl::ArraySlice<int64> input_tensor_id, void* backward_function,
|
||||
const std::function<void()>& backward_function_deleter) {
|
||||
if (!ShouldRecord(input_tensor_id)) {
|
||||
backward_function_deleter();
|
||||
return;
|
||||
}
|
||||
std::vector<int64> ids;
|
||||
ids.reserve(input_tensor_id.size());
|
||||
for (int64 i : input_tensor_id) {
|
||||
tensor_usage_[i]++;
|
||||
ids.push_back(i);
|
||||
}
|
||||
const int64 op_id = next_op_id_++;
|
||||
std::vector<TapeTensor> tensors;
|
||||
tensors.reserve(output_tensors.size());
|
||||
for (const TapeTensor& o : output_tensors) {
|
||||
// Note: the tensor can have already been watched and hence be in the tape,
|
||||
// so we cannot check that we're inserting it here.
|
||||
tensor_tape_[o.id] = op_id;
|
||||
tensor_usage_[o.id] = 1;
|
||||
tensors.push_back(o);
|
||||
}
|
||||
op_tape_[op_id] = OpTapeEntry{op_type, tensors, ids, backward_function,
|
||||
backward_function_deleter};
|
||||
}
|
||||
|
||||
void GradientTape::DeleteTrace(int64 tensor_id) {
|
||||
auto it = tensor_usage_.find(tensor_id);
|
||||
if (it == tensor_usage_.end()) {
|
||||
return;
|
||||
}
|
||||
it->second--;
|
||||
if (it->second != 0) {
|
||||
return;
|
||||
}
|
||||
tensor_usage_.erase(it);
|
||||
auto tensor_op_it = tensor_tape_.find(tensor_id);
|
||||
if (tensor_op_it == tensor_tape_.end()) {
|
||||
return;
|
||||
}
|
||||
const int64 op_id = tensor_op_it->second;
|
||||
if (op_id == -1) {
|
||||
// Do not delete watched tensors.
|
||||
return;
|
||||
}
|
||||
tensor_tape_.erase(tensor_op_it);
|
||||
auto op_it = op_tape_.find(op_id);
|
||||
CHECK(op_it != op_tape_.end());
|
||||
for (const auto& output : op_it->second.output_tensor_info) {
|
||||
if (tensor_usage_.find(output.id) != tensor_usage_.end()) {
|
||||
// Found a usage for an output, so cannot delete the op.
|
||||
return;
|
||||
}
|
||||
}
|
||||
for (int64 id : op_it->second.input_tensor_id) {
|
||||
DeleteTrace(id);
|
||||
}
|
||||
op_it->second.backward_function_deleter();
|
||||
op_tape_.erase(op_it);
|
||||
}
|
||||
|
||||
// Terminology:
|
||||
//
|
||||
// - op: a possibly composite operation, which has an entry in the tape
|
||||
// - target: dy in dx/dy
|
||||
// - source: dx in dx/dy
|
||||
// - tensor: one of the many inputs or outputs of an operation
|
||||
//
|
||||
// Below here we do the gradient algorithm. It works as follows:
|
||||
//
|
||||
// First we filter the tape to just the subset of operations we want to
|
||||
// differentiate. In the process of doing so we count how many times each Tensor
|
||||
// is used as an input to an op (so we know when we're done computing gradients
|
||||
// for that Tensor). We also count, for each tape entry, how many of its output
|
||||
// Tensors need gradients to be computed (Tensors which are not used do not need
|
||||
// any gradients to be computed).
|
||||
//
|
||||
// Finally, we start a backprop stack with a set of tape entries for which we
|
||||
// have all gradients available. This set usually is a subset of the set of
|
||||
// targets (not all since targets which have outputs in the tape will not have
|
||||
// gradients available initially).
|
||||
//
|
||||
// Then we repeatedly pop an entry from the stack, run its backprop, and update
|
||||
// the gradients of its inputs. Once we have computed all gradients for a single
|
||||
// input we can mark this input as done, and this can trigger adding an entry to
|
||||
// the stack if all outputs of that entry are now done.
|
||||
//
|
||||
// When the stack is empty we have gradients for all tensors we're interested
|
||||
// in.
|
||||
|
||||
struct BackpropInitialState {
|
||||
OpTape op_tape;
|
||||
|
||||
// Map from tensor ID to how many references still exist for this tensor in
|
||||
// the tape.
|
||||
std::unordered_map<int64, int64> tensor_usage_counts;
|
||||
|
||||
// Maps from op ID to how many output tensors of this op still need to have
|
||||
// their gradients computed.
|
||||
std::unordered_map<int64, int64> op_missing_tensor;
|
||||
};
|
||||
|
||||
BackpropInitialState PrepareBackprop(
|
||||
gtl::ArraySlice<int64> target, const TensorTape& tensor_tape,
|
||||
OpTape op_tape, const std::unordered_set<int64>& sources_set) {
|
||||
std::vector<int64> tensor_stack;
|
||||
tensor_stack.reserve(target.size());
|
||||
for (auto t : target) {
|
||||
tensor_stack.push_back(t);
|
||||
}
|
||||
BackpropInitialState result;
|
||||
while (!tensor_stack.empty()) {
|
||||
int64 tensor_id = tensor_stack.back();
|
||||
tensor_stack.pop_back();
|
||||
auto op_id_it = tensor_tape.find(tensor_id);
|
||||
if (op_id_it == tensor_tape.end()) {
|
||||
continue;
|
||||
}
|
||||
int64 op_id = op_id_it->second;
|
||||
auto op_it = op_tape.find(op_id);
|
||||
auto result_op_it = result.op_tape.find(op_id);
|
||||
if (op_id == -1 || op_it == op_tape.end() ||
|
||||
result_op_it != result.op_tape.end()) {
|
||||
continue;
|
||||
}
|
||||
CHECK(result.op_tape.emplace(op_id, op_it->second).second);
|
||||
for (auto it : op_it->second.input_tensor_id) {
|
||||
auto count_it = result.tensor_usage_counts.find(it);
|
||||
if (count_it != result.tensor_usage_counts.end()) {
|
||||
count_it->second++;
|
||||
} else {
|
||||
result.tensor_usage_counts[it] = 1;
|
||||
if (sources_set.find(it) == sources_set.end() &&
|
||||
tensor_tape.find(it) != tensor_tape.end()) {
|
||||
tensor_stack.push_back(it);
|
||||
}
|
||||
}
|
||||
}
|
||||
op_tape.erase(op_it);
|
||||
}
|
||||
for (auto& pair : result.tensor_usage_counts) {
|
||||
auto it = tensor_tape.find(pair.first);
|
||||
if (it != tensor_tape.end() && it->second != -1) {
|
||||
result.op_missing_tensor[it->second] += 1;
|
||||
}
|
||||
}
|
||||
// Call destructors for all unneeded gradient functions.
|
||||
for (const auto& op_pair : op_tape) {
|
||||
op_pair.second.backward_function_deleter();
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<int64> InitialStack(
|
||||
const OpTape& op_tape,
|
||||
const std::unordered_map<int64, int64>& op_missing_tensor) {
|
||||
std::vector<int64> result;
|
||||
for (auto& op_entry : op_tape) {
|
||||
if (op_missing_tensor.find(op_entry.first) == op_missing_tensor.end()) {
|
||||
result.push_back(op_entry.first);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
Status InitialGradients(const VSpace& vspace, gtl::ArraySlice<void*> target,
|
||||
gtl::ArraySlice<void*> output_gradients,
|
||||
std::unordered_map<int64, int64> tensor_usage_counts,
|
||||
std::unordered_map<int64, std::vector<void*>>* result) {
|
||||
for (int i = 0; i < target.size(); ++i) {
|
||||
int64 id = vspace.TensorId(target[i]);
|
||||
if (tensor_usage_counts.find(id) != tensor_usage_counts.end()) {
|
||||
if (!output_gradients.empty() && output_gradients[i] != nullptr) {
|
||||
// TODO(apassos) figure out how to print debugging information here.
|
||||
return errors::InvalidArgument(
|
||||
"A gradient was provided for a tensor which is used as part of the "
|
||||
"computation.");
|
||||
}
|
||||
} else {
|
||||
if (output_gradients.empty() || output_gradients[i] == nullptr) {
|
||||
(*result)[id].push_back(vspace.OnesLike(target[i]));
|
||||
} else {
|
||||
(*result)[id].push_back(output_gradients[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// If over kMinAggregateCount gradients are accumulated and the total
|
||||
// memory consumption is over kMinAggregateBytes, do an early aggregation
|
||||
// so as to release the gradient tensor to save memory.
|
||||
static const int kMinAggregateCount = 4;
|
||||
static const int kMinAggregateBytes = 128 * 1024 * 1024;
|
||||
|
||||
Status GradientTape::Gradient(const VSpace& vspace,
|
||||
gtl::ArraySlice<void*> target,
|
||||
gtl::ArraySlice<void*> sources,
|
||||
gtl::ArraySlice<void*> output_gradients,
|
||||
std::vector<void*>* result) {
|
||||
std::vector<int64> id_sources;
|
||||
id_sources.reserve(sources.size());
|
||||
for (void* s : sources) {
|
||||
id_sources.push_back(vspace.TensorId(s));
|
||||
}
|
||||
std::unordered_set<int64> sources_set(id_sources.begin(), id_sources.end());
|
||||
std::vector<int64> id_targets;
|
||||
id_sources.reserve(target.size());
|
||||
for (void* t : target) {
|
||||
id_targets.push_back(vspace.TensorId(t));
|
||||
}
|
||||
BackpropInitialState state = PrepareBackprop(
|
||||
id_targets, tensor_tape_, std::move(op_tape_), sources_set);
|
||||
std::vector<int64> op_stack =
|
||||
InitialStack(state.op_tape, state.op_missing_tensor);
|
||||
std::unordered_map<int64, std::vector<void*>> gradients;
|
||||
Status s = InitialGradients(vspace, target, output_gradients,
|
||||
state.tensor_usage_counts, &gradients);
|
||||
auto cleanup = [&state]() {
|
||||
// Release all backprop functions
|
||||
for (const auto& pair : state.op_tape) {
|
||||
pair.second.backward_function_deleter();
|
||||
}
|
||||
};
|
||||
if (!s.ok()) {
|
||||
cleanup();
|
||||
return s;
|
||||
}
|
||||
std::unordered_map<int64, int64> gradients_size;
|
||||
// TODO(apassos) multiple threads could be dequeuing from op_stack at the same
|
||||
// time, for better CPU backprop performance.
|
||||
VLOG(1) << "Initial stack:";
|
||||
if (VLOG_IS_ON(1)) {
|
||||
for (auto t : op_stack) {
|
||||
VLOG(1) << " " << t;
|
||||
}
|
||||
}
|
||||
std::unordered_map<string, std::unordered_set<int>>
|
||||
functions_accept_none_for_indices({
|
||||
{"SoftmaxCrossEntropyWithLogits", {1}},
|
||||
{"FusedBatchNorm", {1, 2, 3, 4}},
|
||||
});
|
||||
while (!op_stack.empty()) {
|
||||
const int64 op = op_stack.back();
|
||||
VLOG(1) << "Popped " << op;
|
||||
op_stack.pop_back();
|
||||
auto op_it = state.op_tape.find(op);
|
||||
if (op_it == state.op_tape.end()) {
|
||||
// It is possible for ops to end up on the stack if they are unrelated to
|
||||
// the target; we should just skip them.
|
||||
continue;
|
||||
}
|
||||
auto trace = std::move(op_it->second);
|
||||
state.op_tape.erase(op_it);
|
||||
std::vector<void*> out_gradients;
|
||||
out_gradients.reserve(trace.output_tensor_info.size());
|
||||
for (int i = 0; i < trace.output_tensor_info.size(); ++i) {
|
||||
const int64 id = trace.output_tensor_info[i].id;
|
||||
auto grad_it = gradients.find(id);
|
||||
if (grad_it == gradients.end()) {
|
||||
auto func_name_it =
|
||||
functions_accept_none_for_indices.find(trace.op_type);
|
||||
if (func_name_it != functions_accept_none_for_indices.end() &&
|
||||
func_name_it->second.find(i) != func_name_it->second.end()) {
|
||||
out_gradients.push_back(nullptr);
|
||||
} else {
|
||||
out_gradients.push_back(
|
||||
vspace.Zeros(trace.output_tensor_info[i].shape,
|
||||
trace.output_tensor_info[i].dtype));
|
||||
}
|
||||
} else {
|
||||
out_gradients.push_back(vspace.AggregateGradients(grad_it->second));
|
||||
if (sources_set.find(grad_it->first) == sources_set.end()) {
|
||||
gradients.erase(grad_it);
|
||||
}
|
||||
}
|
||||
}
|
||||
std::vector<void*> in_gradients;
|
||||
Status s = vspace.CallBackwardFunction(trace.backward_function,
|
||||
out_gradients, &in_gradients);
|
||||
if (!s.ok()) {
|
||||
VLOG(1) << "Gradient function failed.";
|
||||
cleanup();
|
||||
return s;
|
||||
}
|
||||
VLOG(1) << "Got " << in_gradients.size() << " in_gradients for "
|
||||
<< trace.input_tensor_id.size() << " sources";
|
||||
for (int i = 0; i < in_gradients.size(); ++i) {
|
||||
const int64 id = trace.input_tensor_id[i];
|
||||
if (in_gradients[i] != nullptr) {
|
||||
auto& unaggregated_grads = gradients[id];
|
||||
unaggregated_grads.push_back(in_gradients[i]);
|
||||
if (unaggregated_grads.size() > kMinAggregateCount) {
|
||||
auto size_it = gradients_size.find(id);
|
||||
int64 size;
|
||||
if (size_it == gradients_size.end()) {
|
||||
size = vspace.NumElements(unaggregated_grads[0]);
|
||||
gradients_size.emplace(id, size);
|
||||
} else {
|
||||
size = size_it->second;
|
||||
}
|
||||
if (unaggregated_grads.size() * size * 4 > kMinAggregateBytes) {
|
||||
void* tensor = vspace.AggregateGradients(unaggregated_grads);
|
||||
unaggregated_grads.clear();
|
||||
unaggregated_grads.push_back(tensor);
|
||||
}
|
||||
}
|
||||
}
|
||||
auto usage_count_it = state.tensor_usage_counts.find(id);
|
||||
if (usage_count_it == state.tensor_usage_counts.end()) {
|
||||
VLOG(1) << "Tensor " << id << " not used";
|
||||
continue;
|
||||
}
|
||||
usage_count_it->second--;
|
||||
if (usage_count_it->second > 0) {
|
||||
VLOG(1) << "Tensor " << id << " usage count " << usage_count_it->second;
|
||||
continue;
|
||||
}
|
||||
auto tape_it = tensor_tape_.find(id);
|
||||
if (tape_it == tensor_tape_.end()) {
|
||||
VLOG(1) << "Tensor " << id
|
||||
<< " has no associated op. Deleting gradient";
|
||||
auto grad_it = gradients.find(id);
|
||||
if (grad_it != gradients.end()) {
|
||||
for (auto g : grad_it->second) {
|
||||
vspace.DeleteTensor(g);
|
||||
}
|
||||
gradients.erase(grad_it);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
const int64 op_id = tape_it->second;
|
||||
if (op_id == -1) {
|
||||
VLOG(1) << "Tensor " << id << " is source";
|
||||
continue;
|
||||
}
|
||||
auto missing_it = state.op_missing_tensor.find(op_id);
|
||||
if (missing_it != state.op_missing_tensor.end()) {
|
||||
missing_it->second--;
|
||||
VLOG(1) << "Op " << op_id << " missing " << missing_it->second
|
||||
<< " output gradients";
|
||||
if (missing_it->second == 0) {
|
||||
op_stack.push_back(op_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
CHECK(state.op_tape.empty());
|
||||
result->reserve(sources.size());
|
||||
for (auto is : id_sources) {
|
||||
auto grad_it = gradients.find(is);
|
||||
if (grad_it == gradients.end()) {
|
||||
result->push_back(nullptr);
|
||||
} else {
|
||||
if (grad_it->second.size() == 1) {
|
||||
result->push_back(grad_it->second[0]);
|
||||
} else {
|
||||
result->push_back(vspace.AggregateGradients(grad_it->second));
|
||||
}
|
||||
gradients.erase(grad_it);
|
||||
}
|
||||
}
|
||||
VLOG(1) << "Final gradients size: " << gradients.size();
|
||||
for (auto grad_pair : gradients) {
|
||||
for (const auto& g : grad_pair.second) {
|
||||
vspace.DeleteTensor(g);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace eager
|
||||
} // namespace tensorflow
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
// maintains the data structures required to do so.
|
||||
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
@ -36,13 +37,14 @@ struct TapeTensor {
|
||||
};
|
||||
|
||||
// Represents an entry in the tape.
|
||||
template <typename BackwardFunction>
|
||||
struct OpTapeEntry {
|
||||
string op_type;
|
||||
std::vector<TapeTensor> output_tensor_info;
|
||||
std::vector<int64> input_tensor_id;
|
||||
|
||||
// TODO(apassos) consider narrowing down this interface.
|
||||
void* backward_function;
|
||||
BackwardFunction* backward_function;
|
||||
|
||||
// Should be called before deleting the backward function. TODO(apassos) use
|
||||
// unique_ptrs to ensure this happens.
|
||||
@ -55,51 +57,67 @@ struct OpTapeEntry {
|
||||
using TensorTape = std::unordered_map<int64, int64>;
|
||||
|
||||
// Map from operation-id to tape entry.
|
||||
using OpTape = std::unordered_map<int64, OpTapeEntry>;
|
||||
template <typename BackwardFunction>
|
||||
using OpTape = std::unordered_map<int64, OpTapeEntry<BackwardFunction>>;
|
||||
|
||||
// Operations the tape needs to perform on tensors to do backpropagation. Named
|
||||
// "vspace" because a subset of these are related to a vector space, such as
|
||||
// adding gradients, getting zeroes, etc. Currently cannot be implemented
|
||||
// without using tensorflow python code, hence left unspecified here.
|
||||
//
|
||||
// We currently use void* for tensors, backward functions, and gradients (which
|
||||
// can be but are not required to be tensors). TODO(apassos) replace this first
|
||||
// with templates to allow for pyobject specialization in the client followed by
|
||||
// a TFE_TensorHandle specialization, which is blocked by quite a few things
|
||||
// still.
|
||||
// Tensor is a representation of a tensor. We need to take its ID, and it needs
|
||||
// to match IDs in the tape.
|
||||
//
|
||||
// Gradient is the type returned by gradient functions. In Python TF it's either
|
||||
// Tensor or IndexedSlices or None, which here we map to nullptr. Gradients need
|
||||
// to allow their size to be computed and they need to be passable to a backward
|
||||
// function and deleted (as the backprop code creates lots of gradients the user
|
||||
// is not interested in).
|
||||
//
|
||||
// BackwardFunction needs to be a closure which stores intermediate activations
|
||||
// from the forward computation and calls a vector-jacobian product function
|
||||
// (also known as adjoint function) to compute, given downstream gradients,
|
||||
// upstream gradients.
|
||||
//
|
||||
// TODO(apassos) provide concrete template instantiations for TFE_TensorHandle
|
||||
// specialization, which is blocked by quite a few things needing to loop back
|
||||
// into python now.
|
||||
template <typename Tensor, typename Gradient, typename BackwardFunction>
|
||||
class VSpace {
|
||||
public:
|
||||
virtual ~VSpace() {}
|
||||
|
||||
// Returns the number of elements in the tensor.
|
||||
virtual int64 NumElements(void* tensor) const = 0;
|
||||
// Returns the number of elements in the gradient tensor.
|
||||
virtual int64 NumElements(Gradient* tensor) const = 0;
|
||||
|
||||
// Consumes references to the tensors in the gradient_tensors list and returns
|
||||
// a tensor with the result.
|
||||
virtual void* AggregateGradients(
|
||||
gtl::ArraySlice<void*> gradient_tensors) const = 0;
|
||||
virtual Gradient* AggregateGradients(
|
||||
gtl::ArraySlice<Gradient*> gradient_tensors) const = 0;
|
||||
|
||||
// Returns a tensor of the right shape and dtype filled with zeros.
|
||||
virtual void* Zeros(TensorShape shape, DataType dtype) const = 0;
|
||||
virtual Gradient* Zeros(TensorShape shape, DataType dtype) const = 0;
|
||||
|
||||
// Returns a Tensor which is filled with ones and like the input.
|
||||
virtual void* OnesLike(void*) const = 0;
|
||||
virtual Gradient* OnesLike(Tensor*) const = 0;
|
||||
|
||||
// Returns an integer which is a unique-to-within-this-program handle for this
|
||||
// tensor.
|
||||
virtual int64 TensorId(void* tensor) const = 0;
|
||||
virtual int64 TensorId(Tensor* tensor) const = 0;
|
||||
|
||||
// Calls the passed-in backward function.
|
||||
virtual Status CallBackwardFunction(void* backward_function,
|
||||
gtl::ArraySlice<void*> output_gradients,
|
||||
std::vector<void*>* result) const = 0;
|
||||
virtual Status CallBackwardFunction(
|
||||
BackwardFunction* backward_function,
|
||||
gtl::ArraySlice<Gradient*> output_gradients,
|
||||
std::vector<Gradient*>* result) const = 0;
|
||||
|
||||
// Deletes the input tensor.
|
||||
virtual void DeleteTensor(void* tensor) const = 0;
|
||||
virtual void DeleteGradient(Gradient* gradient) const = 0;
|
||||
};
|
||||
|
||||
// Traces the execution of operations, doing eager garbage collection, and
|
||||
// exporting a full trace so other code can do backpropagation. Not thread-safe.
|
||||
template <typename Tensor, typename Gradient, typename BackwardFunction>
|
||||
class GradientTape {
|
||||
public:
|
||||
GradientTape() {}
|
||||
@ -116,7 +134,7 @@ class GradientTape {
|
||||
void RecordOperation(const string& op_type,
|
||||
gtl::ArraySlice<TapeTensor> output_tensors,
|
||||
gtl::ArraySlice<int64> input_tensor_id,
|
||||
void* backward_function,
|
||||
BackwardFunction* backward_function,
|
||||
const std::function<void()>& backward_function_deleter);
|
||||
|
||||
void DeleteTrace(int64 tensor_id);
|
||||
@ -125,14 +143,15 @@ class GradientTape {
|
||||
// once) and produces the gradient of the target tensors with respect to the
|
||||
// source tensors. The output gradients are used if not empty and not
|
||||
// null. The result is populated with one tensor per target element.
|
||||
Status Gradient(const VSpace& vspace, gtl::ArraySlice<void*> target,
|
||||
gtl::ArraySlice<void*> sources,
|
||||
gtl::ArraySlice<void*> output_gradients,
|
||||
std::vector<void*>* result);
|
||||
Status ComputeGradient(
|
||||
const VSpace<Tensor, Gradient, BackwardFunction>& vspace,
|
||||
gtl::ArraySlice<Tensor*> target, gtl::ArraySlice<Tensor*> sources,
|
||||
gtl::ArraySlice<Gradient*> output_gradients,
|
||||
std::vector<Gradient*>* result);
|
||||
|
||||
private:
|
||||
TensorTape tensor_tape_;
|
||||
OpTape op_tape_;
|
||||
OpTape<BackwardFunction> op_tape_;
|
||||
int64 next_op_id_{0};
|
||||
|
||||
// Map from tensor id to number of remaining usages (i.e. how many entries in
|
||||
@ -140,6 +159,412 @@ class GradientTape {
|
||||
std::unordered_map<int64, int64> tensor_usage_;
|
||||
};
|
||||
|
||||
// Template instantiations here
|
||||
|
||||
template <typename Tensor, typename Gradient, typename BackwardFunction>
|
||||
bool GradientTape<Tensor, Gradient, BackwardFunction>::ShouldRecord(
|
||||
gtl::ArraySlice<int64> tensor_ids) {
|
||||
for (int64 i : tensor_ids) {
|
||||
if (tensor_tape_.find(i) != tensor_tape_.end()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
template <typename Tensor, typename Gradient, typename BackwardFunction>
|
||||
void GradientTape<Tensor, Gradient, BackwardFunction>::Watch(int64 tensor_id) {
|
||||
tensor_tape_.emplace(tensor_id, -1);
|
||||
}
|
||||
|
||||
template <typename Tensor, typename Gradient, typename BackwardFunction>
|
||||
void GradientTape<Tensor, Gradient, BackwardFunction>::RecordOperation(
|
||||
const string& op_type, gtl::ArraySlice<TapeTensor> output_tensors,
|
||||
gtl::ArraySlice<int64> input_tensor_id, BackwardFunction* backward_function,
|
||||
const std::function<void()>& backward_function_deleter) {
|
||||
if (!ShouldRecord(input_tensor_id)) {
|
||||
backward_function_deleter();
|
||||
return;
|
||||
}
|
||||
std::vector<int64> ids;
|
||||
ids.reserve(input_tensor_id.size());
|
||||
for (int64 i : input_tensor_id) {
|
||||
tensor_usage_[i]++;
|
||||
ids.push_back(i);
|
||||
}
|
||||
const int64 op_id = next_op_id_++;
|
||||
std::vector<TapeTensor> tensors;
|
||||
tensors.reserve(output_tensors.size());
|
||||
for (const TapeTensor& o : output_tensors) {
|
||||
// Note: the tensor can have already been watched and hence be in the tape,
|
||||
// so we cannot check that we're inserting it here.
|
||||
tensor_tape_[o.id] = op_id;
|
||||
tensor_usage_[o.id] = 1;
|
||||
tensors.push_back(o);
|
||||
}
|
||||
op_tape_[op_id] = OpTapeEntry<BackwardFunction>{
|
||||
op_type, tensors, ids, backward_function, backward_function_deleter};
|
||||
}
|
||||
|
||||
template <typename Tensor, typename Gradient, typename BackwardFunction>
|
||||
void GradientTape<Tensor, Gradient, BackwardFunction>::DeleteTrace(
|
||||
int64 tensor_id) {
|
||||
auto it = tensor_usage_.find(tensor_id);
|
||||
if (it == tensor_usage_.end()) {
|
||||
return;
|
||||
}
|
||||
it->second--;
|
||||
if (it->second != 0) {
|
||||
return;
|
||||
}
|
||||
tensor_usage_.erase(it);
|
||||
auto tensor_op_it = tensor_tape_.find(tensor_id);
|
||||
if (tensor_op_it == tensor_tape_.end()) {
|
||||
return;
|
||||
}
|
||||
const int64 op_id = tensor_op_it->second;
|
||||
if (op_id == -1) {
|
||||
// Do not delete watched tensors.
|
||||
return;
|
||||
}
|
||||
tensor_tape_.erase(tensor_op_it);
|
||||
auto op_it = op_tape_.find(op_id);
|
||||
CHECK(op_it != op_tape_.end());
|
||||
for (const auto& output : op_it->second.output_tensor_info) {
|
||||
if (tensor_usage_.find(output.id) != tensor_usage_.end()) {
|
||||
// Found a usage for an output, so cannot delete the op.
|
||||
return;
|
||||
}
|
||||
}
|
||||
for (int64 id : op_it->second.input_tensor_id) {
|
||||
DeleteTrace(id);
|
||||
}
|
||||
op_it->second.backward_function_deleter();
|
||||
op_tape_.erase(op_it);
|
||||
}
|
||||
|
||||
// Terminology:
|
||||
//
|
||||
// - op: a possibly composite operation, which has an entry in the tape
|
||||
// - target: dy in dx/dy
|
||||
// - source: dx in dx/dy
|
||||
// - tensor: one of the many inputs or outputs of an operation
|
||||
//
|
||||
// Below here we do the gradient algorithm. It works as follows:
|
||||
//
|
||||
// First we filter the tape to just the subset of operations we want to
|
||||
// differentiate. In the process of doing so we count how many times each Tensor
|
||||
// is used as an input to an op (so we know when we're done computing gradients
|
||||
// for that Tensor). We also count, for each tape entry, how many of its output
|
||||
// Tensors need gradients to be computed (Tensors which are not used do not need
|
||||
// any gradients to be computed).
|
||||
//
|
||||
// Finally, we start a backprop stack with a set of tape entries for which we
|
||||
// have all gradients available. This set usually is a subset of the set of
|
||||
// targets (not all since targets which have outputs in the tape will not have
|
||||
// gradients available initially).
|
||||
//
|
||||
// Then we repeatedly pop an entry from the stack, run its backprop, and update
|
||||
// the gradients of its inputs. Once we have computed all gradients for a single
|
||||
// input we can mark this input as done, and this can trigger adding an entry to
|
||||
// the stack if all outputs of that entry are now done.
|
||||
//
|
||||
// When the stack is empty we have gradients for all tensors we're interested
|
||||
// in.
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename BackwardFunction>
|
||||
struct BackpropInitialState {
|
||||
OpTape<BackwardFunction> op_tape;
|
||||
|
||||
// Map from tensor ID to how many references still exist for this tensor in
|
||||
// the tape.
|
||||
std::unordered_map<int64, int64> tensor_usage_counts;
|
||||
|
||||
// Maps from op ID to how many output tensors of this op still need to have
|
||||
// their gradients computed.
|
||||
std::unordered_map<int64, int64> op_missing_tensor;
|
||||
};
|
||||
|
||||
template <typename BackwardFunction>
|
||||
BackpropInitialState<BackwardFunction> PrepareBackprop(
|
||||
gtl::ArraySlice<int64> target, const TensorTape& tensor_tape,
|
||||
OpTape<BackwardFunction> op_tape,
|
||||
const std::unordered_set<int64>& sources_set) {
|
||||
std::vector<int64> tensor_stack;
|
||||
tensor_stack.reserve(target.size());
|
||||
for (auto t : target) {
|
||||
tensor_stack.push_back(t);
|
||||
}
|
||||
BackpropInitialState<BackwardFunction> result;
|
||||
while (!tensor_stack.empty()) {
|
||||
int64 tensor_id = tensor_stack.back();
|
||||
tensor_stack.pop_back();
|
||||
auto op_id_it = tensor_tape.find(tensor_id);
|
||||
if (op_id_it == tensor_tape.end()) {
|
||||
continue;
|
||||
}
|
||||
int64 op_id = op_id_it->second;
|
||||
auto op_it = op_tape.find(op_id);
|
||||
auto result_op_it = result.op_tape.find(op_id);
|
||||
if (op_id == -1 || op_it == op_tape.end() ||
|
||||
result_op_it != result.op_tape.end()) {
|
||||
continue;
|
||||
}
|
||||
CHECK(result.op_tape.emplace(op_id, op_it->second).second);
|
||||
for (auto it : op_it->second.input_tensor_id) {
|
||||
auto count_it = result.tensor_usage_counts.find(it);
|
||||
if (count_it != result.tensor_usage_counts.end()) {
|
||||
count_it->second++;
|
||||
} else {
|
||||
result.tensor_usage_counts[it] = 1;
|
||||
if (sources_set.find(it) == sources_set.end() &&
|
||||
tensor_tape.find(it) != tensor_tape.end()) {
|
||||
tensor_stack.push_back(it);
|
||||
}
|
||||
}
|
||||
}
|
||||
op_tape.erase(op_it);
|
||||
}
|
||||
for (auto& pair : result.tensor_usage_counts) {
|
||||
auto it = tensor_tape.find(pair.first);
|
||||
if (it != tensor_tape.end() && it->second != -1) {
|
||||
result.op_missing_tensor[it->second] += 1;
|
||||
}
|
||||
}
|
||||
// Call destructors for all unneeded gradient functions.
|
||||
for (const auto& op_pair : op_tape) {
|
||||
op_pair.second.backward_function_deleter();
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename BackwardFunction>
|
||||
std::vector<int64> InitialStack(
|
||||
const OpTape<BackwardFunction>& op_tape,
|
||||
const std::unordered_map<int64, int64>& op_missing_tensor) {
|
||||
std::vector<int64> result;
|
||||
for (auto& op_entry : op_tape) {
|
||||
if (op_missing_tensor.find(op_entry.first) == op_missing_tensor.end()) {
|
||||
result.push_back(op_entry.first);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename Tensor, typename Gradient, typename BackwardFunction>
|
||||
Status InitialGradients(
|
||||
const VSpace<Tensor, Gradient, BackwardFunction>& vspace,
|
||||
gtl::ArraySlice<Tensor*> target,
|
||||
gtl::ArraySlice<Gradient*> output_gradients,
|
||||
std::unordered_map<int64, int64> tensor_usage_counts,
|
||||
std::unordered_map<int64, std::vector<Gradient*>>* result) {
|
||||
for (int i = 0; i < target.size(); ++i) {
|
||||
int64 id = vspace.TensorId(target[i]);
|
||||
if (tensor_usage_counts.find(id) != tensor_usage_counts.end()) {
|
||||
if (!output_gradients.empty() && output_gradients[i] != nullptr) {
|
||||
// TODO(apassos) figure out how to print debugging information here.
|
||||
return errors::InvalidArgument(
|
||||
"A gradient was provided for a tensor which is used as part of the "
|
||||
"computation.");
|
||||
}
|
||||
} else {
|
||||
if (output_gradients.empty() || output_gradients[i] == nullptr) {
|
||||
(*result)[id].push_back(vspace.OnesLike(target[i]));
|
||||
} else {
|
||||
(*result)[id].push_back(output_gradients[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// If over kMinAggregateCount gradients are accumulated and the total
|
||||
// memory consumption is over kMinAggregateBytes, do an early aggregation
|
||||
// so as to release the gradient tensor to save memory.
|
||||
constexpr int kMinAggregateCount = 4;
|
||||
constexpr int kMinAggregateBytes = 128 * 1024 * 1024;
|
||||
|
||||
template <typename Tensor, typename Gradient, typename BackwardFunction>
|
||||
Status GradientTape<Tensor, Gradient, BackwardFunction>::ComputeGradient(
|
||||
const VSpace<Tensor, Gradient, BackwardFunction>& vspace,
|
||||
gtl::ArraySlice<Tensor*> target, gtl::ArraySlice<Tensor*> sources,
|
||||
gtl::ArraySlice<Gradient*> output_gradients,
|
||||
std::vector<Gradient*>* result) {
|
||||
std::vector<int64> id_sources;
|
||||
id_sources.reserve(sources.size());
|
||||
for (Tensor* s : sources) {
|
||||
id_sources.push_back(vspace.TensorId(s));
|
||||
}
|
||||
std::unordered_set<int64> sources_set(id_sources.begin(), id_sources.end());
|
||||
std::vector<int64> id_targets;
|
||||
id_sources.reserve(target.size());
|
||||
for (Tensor* t : target) {
|
||||
id_targets.push_back(vspace.TensorId(t));
|
||||
}
|
||||
BackpropInitialState<BackwardFunction> state = PrepareBackprop(
|
||||
id_targets, tensor_tape_, std::move(op_tape_), sources_set);
|
||||
std::vector<int64> op_stack =
|
||||
InitialStack(state.op_tape, state.op_missing_tensor);
|
||||
std::unordered_map<int64, std::vector<Gradient*>> gradients;
|
||||
Status s = InitialGradients(vspace, target, output_gradients,
|
||||
state.tensor_usage_counts, &gradients);
|
||||
auto cleanup = [&state]() {
|
||||
// Release all backprop functions
|
||||
for (const auto& pair : state.op_tape) {
|
||||
pair.second.backward_function_deleter();
|
||||
}
|
||||
};
|
||||
if (!s.ok()) {
|
||||
cleanup();
|
||||
return s;
|
||||
}
|
||||
std::unordered_map<int64, int64> gradients_size;
|
||||
// TODO(apassos) multiple threads could be dequeuing from op_stack at the same
|
||||
// time, for better CPU backprop performance.
|
||||
VLOG(1) << "Initial stack:";
|
||||
if (VLOG_IS_ON(1)) {
|
||||
for (auto t : op_stack) {
|
||||
VLOG(1) << " " << t;
|
||||
}
|
||||
}
|
||||
std::unordered_map<string, std::unordered_set<int>>
|
||||
functions_accept_none_for_indices({
|
||||
{"SoftmaxCrossEntropyWithLogits", {1}},
|
||||
{"FusedBatchNorm", {1, 2, 3, 4}},
|
||||
});
|
||||
while (!op_stack.empty()) {
|
||||
const int64 op = op_stack.back();
|
||||
VLOG(1) << "Popped " << op;
|
||||
op_stack.pop_back();
|
||||
auto op_it = state.op_tape.find(op);
|
||||
if (op_it == state.op_tape.end()) {
|
||||
// It is possible for ops to end up on the stack if they are unrelated to
|
||||
// the target; we should just skip them.
|
||||
continue;
|
||||
}
|
||||
auto trace = std::move(op_it->second);
|
||||
state.op_tape.erase(op_it);
|
||||
std::vector<Gradient*> out_gradients;
|
||||
out_gradients.reserve(trace.output_tensor_info.size());
|
||||
for (int i = 0; i < trace.output_tensor_info.size(); ++i) {
|
||||
const int64 id = trace.output_tensor_info[i].id;
|
||||
auto grad_it = gradients.find(id);
|
||||
if (grad_it == gradients.end()) {
|
||||
auto func_name_it =
|
||||
functions_accept_none_for_indices.find(trace.op_type);
|
||||
if (func_name_it != functions_accept_none_for_indices.end() &&
|
||||
func_name_it->second.find(i) != func_name_it->second.end()) {
|
||||
out_gradients.push_back(nullptr);
|
||||
} else {
|
||||
out_gradients.push_back(
|
||||
vspace.Zeros(trace.output_tensor_info[i].shape,
|
||||
trace.output_tensor_info[i].dtype));
|
||||
}
|
||||
} else {
|
||||
out_gradients.push_back(vspace.AggregateGradients(grad_it->second));
|
||||
if (sources_set.find(grad_it->first) == sources_set.end()) {
|
||||
gradients.erase(grad_it);
|
||||
}
|
||||
}
|
||||
}
|
||||
std::vector<Gradient*> in_gradients;
|
||||
Status s = vspace.CallBackwardFunction(trace.backward_function,
|
||||
out_gradients, &in_gradients);
|
||||
if (!s.ok()) {
|
||||
VLOG(1) << "Gradient function failed.";
|
||||
cleanup();
|
||||
return s;
|
||||
}
|
||||
VLOG(1) << "Got " << in_gradients.size() << " in_gradients for "
|
||||
<< trace.input_tensor_id.size() << " sources";
|
||||
for (int i = 0; i < in_gradients.size(); ++i) {
|
||||
const int64 id = trace.input_tensor_id[i];
|
||||
if (in_gradients[i] != nullptr) {
|
||||
auto& unaggregated_grads = gradients[id];
|
||||
unaggregated_grads.push_back(in_gradients[i]);
|
||||
if (unaggregated_grads.size() > kMinAggregateCount) {
|
||||
auto size_it = gradients_size.find(id);
|
||||
int64 size;
|
||||
if (size_it == gradients_size.end()) {
|
||||
size = vspace.NumElements(unaggregated_grads[0]);
|
||||
gradients_size.emplace(id, size);
|
||||
} else {
|
||||
size = size_it->second;
|
||||
}
|
||||
if (unaggregated_grads.size() * size * 4 > kMinAggregateBytes) {
|
||||
Gradient* grad = vspace.AggregateGradients(unaggregated_grads);
|
||||
unaggregated_grads.clear();
|
||||
unaggregated_grads.push_back(grad);
|
||||
}
|
||||
}
|
||||
}
|
||||
auto usage_count_it = state.tensor_usage_counts.find(id);
|
||||
if (usage_count_it == state.tensor_usage_counts.end()) {
|
||||
VLOG(1) << "Tensor " << id << " not used";
|
||||
continue;
|
||||
}
|
||||
usage_count_it->second--;
|
||||
if (usage_count_it->second > 0) {
|
||||
VLOG(1) << "Tensor " << id << " usage count " << usage_count_it->second;
|
||||
continue;
|
||||
}
|
||||
auto tape_it = tensor_tape_.find(id);
|
||||
if (tape_it == tensor_tape_.end()) {
|
||||
VLOG(1) << "Tensor " << id
|
||||
<< " has no associated op. Deleting gradient";
|
||||
auto grad_it = gradients.find(id);
|
||||
if (grad_it != gradients.end()) {
|
||||
for (auto g : grad_it->second) {
|
||||
vspace.DeleteGradient(g);
|
||||
}
|
||||
gradients.erase(grad_it);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
const int64 op_id = tape_it->second;
|
||||
if (op_id == -1) {
|
||||
VLOG(1) << "Tensor " << id << " is source";
|
||||
continue;
|
||||
}
|
||||
auto missing_it = state.op_missing_tensor.find(op_id);
|
||||
if (missing_it != state.op_missing_tensor.end()) {
|
||||
missing_it->second--;
|
||||
VLOG(1) << "Op " << op_id << " missing " << missing_it->second
|
||||
<< " output gradients";
|
||||
if (missing_it->second == 0) {
|
||||
op_stack.push_back(op_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
CHECK(state.op_tape.empty());
|
||||
result->reserve(sources.size());
|
||||
for (auto is : id_sources) {
|
||||
auto grad_it = gradients.find(is);
|
||||
if (grad_it == gradients.end()) {
|
||||
result->push_back(nullptr);
|
||||
} else {
|
||||
if (grad_it->second.size() == 1) {
|
||||
result->push_back(grad_it->second[0]);
|
||||
} else {
|
||||
result->push_back(vspace.AggregateGradients(grad_it->second));
|
||||
}
|
||||
gradients.erase(grad_it);
|
||||
}
|
||||
}
|
||||
VLOG(1) << "Final gradients size: " << gradients.size();
|
||||
for (auto grad_pair : gradients) {
|
||||
for (const auto& g : grad_pair.second) {
|
||||
vspace.DeleteGradient(g);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace eager
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -443,10 +443,13 @@ void TFE_DeleteContextCapsule(PyObject* context) {
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
|
||||
using GradientTape =
|
||||
tensorflow::eager::GradientTape<PyObject, PyObject, PyObject>;
|
||||
|
||||
typedef struct {
|
||||
PyObject_HEAD
|
||||
/* Type-specific fields go here. */
|
||||
tensorflow::eager::GradientTape* tape;
|
||||
GradientTape* tape;
|
||||
} TFE_Py_Tape;
|
||||
|
||||
static void TFE_Py_Tape_Delete(PyObject* tape) {
|
||||
@ -481,7 +484,7 @@ PyObject* TFE_Py_NewTape() {
|
||||
TFE_Py_Tape_Type.tp_new = PyType_GenericNew;
|
||||
if (PyType_Ready(&TFE_Py_Tape_Type) < 0) return nullptr;
|
||||
TFE_Py_Tape* tape = PyObject_NEW(TFE_Py_Tape, &TFE_Py_Tape_Type);
|
||||
tape->tape = new tensorflow::eager::GradientTape();
|
||||
tape->tape = new GradientTape();
|
||||
return reinterpret_cast<PyObject*>(tape);
|
||||
}
|
||||
|
||||
@ -627,9 +630,8 @@ void TFE_Py_TapeDeleteTrace(PyObject* tape, tensorflow::int64 tensor_id) {
|
||||
reinterpret_cast<TFE_Py_Tape*>(tape)->tape->DeleteTrace(tensor_id);
|
||||
}
|
||||
|
||||
// TODO(apassos): cache the attribute lookups as member variables and decref
|
||||
// them in the destructor.
|
||||
class PyVSpace : public tensorflow::eager::VSpace {
|
||||
class PyVSpace
|
||||
: public tensorflow::eager::VSpace<PyObject, PyObject, PyObject> {
|
||||
public:
|
||||
explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) {}
|
||||
|
||||
@ -661,7 +663,7 @@ class PyVSpace : public tensorflow::eager::VSpace {
|
||||
Py_XDECREF(ones_like_);
|
||||
}
|
||||
|
||||
tensorflow::int64 NumElements(void* tensor) const final {
|
||||
tensorflow::int64 NumElements(PyObject* tensor) const final {
|
||||
PyObject* arglist =
|
||||
Py_BuildValue("(O)", reinterpret_cast<PyObject*>(tensor));
|
||||
PyObject* result = PyEval_CallObject(num_elements_, arglist);
|
||||
@ -671,8 +673,8 @@ class PyVSpace : public tensorflow::eager::VSpace {
|
||||
return r;
|
||||
}
|
||||
|
||||
void* AggregateGradients(
|
||||
tensorflow::gtl::ArraySlice<void*> gradient_tensors) const final {
|
||||
PyObject* AggregateGradients(
|
||||
tensorflow::gtl::ArraySlice<PyObject*> gradient_tensors) const final {
|
||||
PyObject* list = PyList_New(gradient_tensors.size());
|
||||
for (int i = 0; i < gradient_tensors.size(); ++i) {
|
||||
// Note: stealing a reference to the gradient tensors.
|
||||
@ -689,8 +691,8 @@ class PyVSpace : public tensorflow::eager::VSpace {
|
||||
return result;
|
||||
}
|
||||
|
||||
void* Zeros(tensorflow::TensorShape shape,
|
||||
tensorflow::DataType dtype) const final {
|
||||
PyObject* Zeros(tensorflow::TensorShape shape,
|
||||
tensorflow::DataType dtype) const final {
|
||||
PyObject* py_shape = PyTuple_New(shape.dims());
|
||||
for (int i = 0; i < shape.dims(); ++i) {
|
||||
PyTuple_SET_ITEM(py_shape, i, PyLong_FromLong(shape.dim_size(i)));
|
||||
@ -701,20 +703,20 @@ class PyVSpace : public tensorflow::eager::VSpace {
|
||||
Py_DECREF(arg_list);
|
||||
Py_DECREF(py_dtype);
|
||||
Py_DECREF(py_shape);
|
||||
return reinterpret_cast<void*>(result);
|
||||
return reinterpret_cast<PyObject*>(result);
|
||||
}
|
||||
|
||||
void* OnesLike(void* tensor) const final {
|
||||
PyObject* OnesLike(PyObject* tensor) const final {
|
||||
PyObject* arg_list = Py_BuildValue("(O)", tensor);
|
||||
PyObject* result = PyEval_CallObject(ones_like_, arg_list);
|
||||
if (result == nullptr) {
|
||||
VLOG(1) << "Call to ones_like failed";
|
||||
}
|
||||
Py_DECREF(arg_list);
|
||||
return reinterpret_cast<void*>(result);
|
||||
return result;
|
||||
}
|
||||
|
||||
tensorflow::int64 TensorId(void* tensor) const final {
|
||||
tensorflow::int64 TensorId(PyObject* tensor) const final {
|
||||
PyObject* py_tensor = reinterpret_cast<PyObject*>(tensor);
|
||||
PyObject* id_field = PyObject_GetAttrString(py_tensor, "_id");
|
||||
tensorflow::int64 id = MakeInt(id_field);
|
||||
@ -723,9 +725,9 @@ class PyVSpace : public tensorflow::eager::VSpace {
|
||||
}
|
||||
|
||||
tensorflow::Status CallBackwardFunction(
|
||||
void* backward_function,
|
||||
tensorflow::gtl::ArraySlice<void*> output_gradients,
|
||||
std::vector<void*>* result) const final {
|
||||
PyObject* backward_function,
|
||||
tensorflow::gtl::ArraySlice<PyObject*> output_gradients,
|
||||
std::vector<PyObject*>* result) const final {
|
||||
PyObject* grads = PyTuple_New(output_gradients.size());
|
||||
for (int i = 0; i < output_gradients.size(); ++i) {
|
||||
if (output_gradients[i] == nullptr) {
|
||||
@ -771,9 +773,7 @@ class PyVSpace : public tensorflow::eager::VSpace {
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
void DeleteTensor(void* tensor) const final {
|
||||
Py_XDECREF(reinterpret_cast<PyObject*>(tensor));
|
||||
}
|
||||
void DeleteGradient(PyObject* tensor) const final { Py_XDECREF(tensor); }
|
||||
|
||||
private:
|
||||
PyObject* py_vspace_;
|
||||
@ -784,13 +784,13 @@ class PyVSpace : public tensorflow::eager::VSpace {
|
||||
PyObject* ones_like_;
|
||||
};
|
||||
|
||||
std::vector<void*> MakeTensorList(PyObject* tensors) {
|
||||
std::vector<PyObject*> MakeTensorList(PyObject* tensors) {
|
||||
PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
|
||||
if (seq == nullptr) {
|
||||
return {};
|
||||
}
|
||||
int len = PySequence_Fast_GET_SIZE(seq);
|
||||
std::vector<void*> list;
|
||||
std::vector<PyObject*> list;
|
||||
list.reserve(len);
|
||||
for (int i = 0; i < len; ++i) {
|
||||
list.push_back(PySequence_Fast_GET_ITEM(seq, i));
|
||||
@ -807,30 +807,30 @@ PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace,
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::vector<void*> target_vec = MakeTensorList(target);
|
||||
std::vector<PyObject*> target_vec = MakeTensorList(target);
|
||||
if (PyErr_Occurred()) {
|
||||
return nullptr;
|
||||
}
|
||||
std::vector<void*> sources_vec = MakeTensorList(sources);
|
||||
std::vector<PyObject*> sources_vec = MakeTensorList(sources);
|
||||
if (PyErr_Occurred()) {
|
||||
return nullptr;
|
||||
}
|
||||
std::vector<void*> outgrad_vec;
|
||||
std::vector<PyObject*> outgrad_vec;
|
||||
if (output_gradients != Py_None) {
|
||||
outgrad_vec = MakeTensorList(output_gradients);
|
||||
if (PyErr_Occurred()) {
|
||||
return nullptr;
|
||||
}
|
||||
for (void* tensor : outgrad_vec) {
|
||||
for (PyObject* tensor : outgrad_vec) {
|
||||
// Calling the backward function will eat a reference to the tensors in
|
||||
// outgrad_vec, so we need to increase their reference count.
|
||||
Py_INCREF(reinterpret_cast<PyObject*>(tensor));
|
||||
Py_INCREF(tensor);
|
||||
}
|
||||
}
|
||||
TFE_Py_Tape* tape_obj = reinterpret_cast<TFE_Py_Tape*>(tape);
|
||||
std::vector<void*> result;
|
||||
status->status = tape_obj->tape->Gradient(c_vspace, target_vec, sources_vec,
|
||||
outgrad_vec, &result);
|
||||
std::vector<PyObject*> result;
|
||||
status->status = tape_obj->tape->ComputeGradient(
|
||||
c_vspace, target_vec, sources_vec, outgrad_vec, &result);
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user