Keep a constant reference to VSpace in ForwardAccumulator instead of a constant pointer

Also check that a VSpace has already been registered instead of segfaulting.

PiperOrigin-RevId: 253879165
Allen Lavoie 2019-06-18 15:19:37 -07:00 committed by TensorFlower Gardener
parent 26b14a84d3
commit e985e958b3
2 changed files with 24 additions and 18 deletions
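
The core of this change is switching a const-pointer member to a const-reference member, which turns "the VSpace must be non-null" into a compile-time part of the constructor's contract and changes call sites from `->` to `.`. Below is a minimal, self-contained sketch of that pattern; `VSpaceLike` and `AccumulatorLike` are illustrative stand-ins, not the actual TensorFlow types.

#include <iostream>

// Stand-in for the VSpace interface; illustrative only.
class VSpaceLike {
 public:
  void DeleteGradient(int id) const { std::cout << "delete " << id << "\n"; }
};

// Before this commit the accumulator held `const VSpaceLike* vspace_`, every
// use went through `vspace_->...`, and nothing prevented constructing it from
// a null pointer. With a const reference, a valid object must be supplied at
// construction time and members are accessed with `.`.
class AccumulatorLike {
 public:
  // Does not take ownership of `vspace`, which must outlive this object.
  explicit AccumulatorLike(const VSpaceLike& vspace) : vspace_(vspace) {}
  void Drop(int id) { vspace_.DeleteGradient(id); }

 private:
  const VSpaceLike& vspace_;  // not owned; non-null by construction
};

int main() {
  VSpaceLike vspace;            // must outlive the accumulator
  AccumulatorLike acc(vspace);  // callers holding a pointer now pass *ptr
  acc.Drop(42);
  return 0;
}
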

View File

@@ -193,12 +193,12 @@ class ForwardAccumulator {
   // Does not take ownership of `vspace`, which must outlive the
   // ForwardAccumulator.
   explicit ForwardAccumulator(
-      const VSpace<Gradient, BackwardFunction, TapeTensor>* vspace)
+      const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace)
       : vspace_(vspace), backward_tape_(nullptr), accumulating_(false) {}
   virtual ~ForwardAccumulator() {
     for (auto accumulated : accumulated_gradients_) {
-      vspace_->DeleteGradient(accumulated.second);
+      vspace_.DeleteGradient(accumulated.second);
     }
   }
@@ -276,7 +276,7 @@ class ForwardAccumulator {
   std::unordered_map<int64, Gradient*> accumulated_gradients_;
   // Not owned; provides operations on Tensors which are currently only
   // available in language bindings (e.g. Python).
-  const VSpace<Gradient, BackwardFunction, TapeTensor>* vspace_;
+  const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace_;
   // Set temporarily while in the Accumulate method; if backward_tape_ is not
   // nullptr then we forward op executions to it so Accumulate can compute a
   // backward pass on its backward function.
@@ -865,9 +865,9 @@ ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ForwardpropFromTape(
   sources.reserve(output_tensors.size());
   for (const TapeTensor& output_tensor : output_tensors) {
     // Ownership of `aid` transferred to CallBackwardFunction below.
-    Gradient* aid = vspace_->Ones(output_tensor);
+    Gradient* aid = vspace_.Ones(output_tensor);
     forwardprop_aids.push_back(aid);
-    int64 aid_id = vspace_->TensorId(aid);
+    int64 aid_id = vspace_.TensorId(aid);
     sources.push_back(aid_id);
     sources_set.insert(aid_id);
     tape->Watch(aid_id);
@@ -875,7 +875,7 @@ ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ForwardpropFromTape(
   std::vector<Gradient*> grad;
   auto delete_grad = gtl::MakeCleanup([&grad, this] {
     for (Gradient* tensor : grad) {
-      this->vspace_->DeleteGradient(tensor);
+      this->vspace_.DeleteGradient(tensor);
     }
   });
   {
@@ -883,7 +883,7 @@ ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ForwardpropFromTape(
     std::unique_ptr<BackwardFunction, std::function<void(BackwardFunction*)>>
         backward_function(backward_function_getter(),
                           backward_function_deleter);
-    TF_RETURN_IF_ERROR(vspace_->CallBackwardFunction(
+    TF_RETURN_IF_ERROR(vspace_.CallBackwardFunction(
         backward_function.get(), unneeded_gradients, forwardprop_aids, &grad));
   }
@@ -894,11 +894,11 @@ ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ForwardpropFromTape(
   std::unordered_map<int64, TapeTensor> sources_that_are_targets;
   for (Gradient* grad_tensor : grad) {
     if (grad_tensor != nullptr) {
-      int64 tensor_id = vspace_->TensorId(grad_tensor);
+      int64 tensor_id = vspace_.TensorId(grad_tensor);
       targets.push_back(tensor_id);
       if (sources_set.find(tensor_id) != sources_set.end()) {
         sources_that_are_targets.emplace(
-            tensor_id, vspace_->TapeTensorFromGradient(grad_tensor));
+            tensor_id, vspace_.TapeTensorFromGradient(grad_tensor));
       }
     }
   }
@@ -911,11 +911,11 @@ ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ForwardpropFromTape(
     Gradient* grad_tensor = grad[target_index];
     if (grad_tensor != nullptr && in_grad != nullptr) {
       // ComputeGradient steals a reference
-      vspace_->MarkAsResult(in_grad);
+      vspace_.MarkAsResult(in_grad);
     }
   }
-  return tape->ComputeGradient(*vspace_, targets, sources,
+  return tape->ComputeGradient(vspace_, targets, sources,
                                sources_that_are_targets, in_grads, out_grads);
 }
@@ -952,7 +952,7 @@ Status ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::Accumulate(
   std::vector<Gradient*> new_zeros;
   auto delete_new_zeros = gtl::MakeCleanup([&new_zeros, this] {
     for (Gradient* tensor : new_zeros) {
-      this->vspace_->DeleteGradient(tensor);
+      this->vspace_.DeleteGradient(tensor);
     }
   });
   std::vector<Gradient*> in_grads;
@@ -965,7 +965,7 @@ Status ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::Accumulate(
     if (IsDtypeTrainable(input_tensors[target_index].GetDType())) {
       // ForwardAccumulator defaults to zeros for unwatched Tensors, unlike
       // GradientTape which uses ones.
-      Gradient* zero = vspace_->Zeros(input_tensors[target_index]);
+      Gradient* zero = vspace_.Zeros(input_tensors[target_index]);
       new_zeros.push_back(zero);
       in_grads.push_back(zero);
     } else {
@@ -990,7 +990,7 @@ Status ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::Accumulate(
     int64 tensor_id = output_tensors[i].GetID();
     auto existing = accumulated_gradients_.find(tensor_id);
     if (existing != accumulated_gradients_.end()) {
-      vspace_->DeleteGradient(existing->second);
+      vspace_.DeleteGradient(existing->second);
     }
     accumulated_gradients_[output_tensors[i].GetID()] = forward_grads[i];
   }
@@ -1002,14 +1002,14 @@ template <typename Gradient, typename BackwardFunction, typename TapeTensor>
 void ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::Watch(
     int64 tensor_id, Gradient* tangent) {
   auto existing = accumulated_gradients_.find(tensor_id);
-  vspace_->MarkAsResult(tangent);
+  vspace_.MarkAsResult(tangent);
   if (existing == accumulated_gradients_.end()) {
     accumulated_gradients_.emplace(tensor_id, tangent);
   } else {
     std::array<Gradient*, 2> to_aggregate({tangent, existing->second});
     // AggregateGradients steals a reference to each of its arguments. We
     // MarkAsResult on `tangent` above so we don't steal a reference to it.
-    existing->second = vspace_->AggregateGradients(to_aggregate);
+    existing->second = vspace_.AggregateGradients(to_aggregate);
   }
 }
@@ -1018,7 +1018,7 @@ void ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::DeleteGradient(
     int64 tensor_id) {
   auto existing = accumulated_gradients_.find(tensor_id);
   if (existing != accumulated_gradients_.end()) {
-    vspace_->DeleteGradient(existing->second);
+    vspace_.DeleteGradient(existing->second);
     accumulated_gradients_.erase(existing);
   }
 }

View File

@@ -1947,7 +1947,13 @@ PyObject* TFE_Py_ForwardAccumulatorNew() {
   if (PyType_Ready(&TFE_Py_ForwardAccumulator_Type) < 0) return nullptr;
   TFE_Py_ForwardAccumulator* accumulator =
       PyObject_NEW(TFE_Py_ForwardAccumulator, &TFE_Py_ForwardAccumulator_Type);
-  accumulator->accumulator = new ForwardAccumulator(py_vspace);
+  if (py_vspace == nullptr) {
+    MaybeRaiseExceptionFromStatus(
+        tensorflow::errors::Internal(
+            "ForwardAccumulator requires a PyVSpace to be registered."),
+        nullptr);
+  }
+  accumulator->accumulator = new ForwardAccumulator(*py_vspace);
   Py_INCREF(accumulator);
   GetAccumulatorSet()->insert(
       reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator));
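
Note on this second hunk: because the ForwardAccumulator constructor now takes a reference, the Python binding must dereference `py_vspace`, so it first checks that a PyVSpace was registered and raises an Internal error instead of dereferencing a null pointer. This matches the commit message's intent of surfacing a missing registration as an error rather than a segfault.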