Keep a constant reference to VSpace in ForwardAccumulator instead of a constant pointer
Also checks that one was already registered instead of segfaulting.

PiperOrigin-RevId: 253879165
commit e985e958b3
parent 26b14a84d3
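
Below is a minimal, hypothetical sketch of the member change this commit makes: switching a non-owning `const T*` member to a `const T&` member. The names `VSpaceLike` and `AccumulatorSketch` are stand-ins and are not part of the TensorFlow sources. A reference member cannot be null and cannot be reseated, so the `->` call sites inside the class become `.` and no null checks are needed.

// Minimal sketch, not TensorFlow code: VSpaceLike and AccumulatorSketch are
// hypothetical stand-ins illustrating the pointer-to-reference member change.
#include <cstdio>

struct VSpaceLike {
  void DeleteGradient(int* gradient) const { delete gradient; }
};

class AccumulatorSketch {
 public:
  // Before: explicit AccumulatorSketch(const VSpaceLike* vspace)
  //           : vspace_(vspace) {}   // caller could pass nullptr
  // After: the constructor takes a reference, so a valid object is required.
  explicit AccumulatorSketch(const VSpaceLike& vspace) : vspace_(vspace) {}
  ~AccumulatorSketch() { vspace_.DeleteGradient(gradient_); }  // '.' not '->'

 private:
  const VSpaceLike& vspace_;    // not owned; must outlive this object
  int* gradient_ = new int(0);  // stand-in for an accumulated gradient
};

int main() {
  VSpaceLike vspace;
  AccumulatorSketch accumulator(vspace);  // vspace outlives the accumulator
  std::printf("constructed accumulator with a reference to vspace\n");
  return 0;
}

A reference member also deletes the implicitly generated copy assignment operator; presumably that is acceptable here, since the accumulator is heap-allocated by the Python binding and never reassigned.
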
@@ -193,12 +193,12 @@ class ForwardAccumulator {
   // Does not take ownership of `vspace`, which must outlive the
   // ForwardAccumulator.
   explicit ForwardAccumulator(
-      const VSpace<Gradient, BackwardFunction, TapeTensor>* vspace)
+      const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace)
       : vspace_(vspace), backward_tape_(nullptr), accumulating_(false) {}
 
   virtual ~ForwardAccumulator() {
     for (auto accumulated : accumulated_gradients_) {
-      vspace_->DeleteGradient(accumulated.second);
+      vspace_.DeleteGradient(accumulated.second);
     }
   }
 
@@ -276,7 +276,7 @@ class ForwardAccumulator {
   std::unordered_map<int64, Gradient*> accumulated_gradients_;
   // Not owned; provides operations on Tensors which are currently only
   // available in language bindings (e.g. Python).
-  const VSpace<Gradient, BackwardFunction, TapeTensor>* vspace_;
+  const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace_;
   // Set temporarily while in the Accumulate method; if backward_tape_ is not
   // nullptr then we forward op executions to it so Accumulate can compute a
   // backward pass on its backward function.
@@ -865,9 +865,9 @@ ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ForwardpropFromTape(
   sources.reserve(output_tensors.size());
   for (const TapeTensor& output_tensor : output_tensors) {
     // Ownership of `aid` transferred to CallBackwardFunction below.
-    Gradient* aid = vspace_->Ones(output_tensor);
+    Gradient* aid = vspace_.Ones(output_tensor);
     forwardprop_aids.push_back(aid);
-    int64 aid_id = vspace_->TensorId(aid);
+    int64 aid_id = vspace_.TensorId(aid);
     sources.push_back(aid_id);
     sources_set.insert(aid_id);
     tape->Watch(aid_id);
@@ -875,7 +875,7 @@ ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ForwardpropFromTape(
   std::vector<Gradient*> grad;
   auto delete_grad = gtl::MakeCleanup([&grad, this] {
     for (Gradient* tensor : grad) {
-      this->vspace_->DeleteGradient(tensor);
+      this->vspace_.DeleteGradient(tensor);
     }
   });
   {
@@ -883,7 +883,7 @@ ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ForwardpropFromTape(
     std::unique_ptr<BackwardFunction, std::function<void(BackwardFunction*)>>
         backward_function(backward_function_getter(),
                           backward_function_deleter);
-    TF_RETURN_IF_ERROR(vspace_->CallBackwardFunction(
+    TF_RETURN_IF_ERROR(vspace_.CallBackwardFunction(
         backward_function.get(), unneeded_gradients, forwardprop_aids, &grad));
   }
 
@@ -894,11 +894,11 @@ ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ForwardpropFromTape(
   std::unordered_map<int64, TapeTensor> sources_that_are_targets;
   for (Gradient* grad_tensor : grad) {
     if (grad_tensor != nullptr) {
-      int64 tensor_id = vspace_->TensorId(grad_tensor);
+      int64 tensor_id = vspace_.TensorId(grad_tensor);
       targets.push_back(tensor_id);
       if (sources_set.find(tensor_id) != sources_set.end()) {
         sources_that_are_targets.emplace(
-            tensor_id, vspace_->TapeTensorFromGradient(grad_tensor));
+            tensor_id, vspace_.TapeTensorFromGradient(grad_tensor));
       }
     }
   }
@@ -911,11 +911,11 @@ ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ForwardpropFromTape(
     Gradient* grad_tensor = grad[target_index];
     if (grad_tensor != nullptr && in_grad != nullptr) {
       // ComputeGradient steals a reference
-      vspace_->MarkAsResult(in_grad);
+      vspace_.MarkAsResult(in_grad);
     }
   }
 
-  return tape->ComputeGradient(*vspace_, targets, sources,
+  return tape->ComputeGradient(vspace_, targets, sources,
                                sources_that_are_targets, in_grads, out_grads);
 }
 
@@ -952,7 +952,7 @@ Status ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::Accumulate(
   std::vector<Gradient*> new_zeros;
   auto delete_new_zeros = gtl::MakeCleanup([&new_zeros, this] {
     for (Gradient* tensor : new_zeros) {
-      this->vspace_->DeleteGradient(tensor);
+      this->vspace_.DeleteGradient(tensor);
     }
   });
   std::vector<Gradient*> in_grads;
@@ -965,7 +965,7 @@ Status ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::Accumulate(
       if (IsDtypeTrainable(input_tensors[target_index].GetDType())) {
         // ForwardAccumulator defaults to zeros for unwatched Tensors, unlike
        // GradientTape which uses ones.
-        Gradient* zero = vspace_->Zeros(input_tensors[target_index]);
+        Gradient* zero = vspace_.Zeros(input_tensors[target_index]);
        new_zeros.push_back(zero);
        in_grads.push_back(zero);
      } else {
@@ -990,7 +990,7 @@ Status ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::Accumulate(
       int64 tensor_id = output_tensors[i].GetID();
       auto existing = accumulated_gradients_.find(tensor_id);
       if (existing != accumulated_gradients_.end()) {
-        vspace_->DeleteGradient(existing->second);
+        vspace_.DeleteGradient(existing->second);
       }
       accumulated_gradients_[output_tensors[i].GetID()] = forward_grads[i];
     }
@@ -1002,14 +1002,14 @@ template <typename Gradient, typename BackwardFunction, typename TapeTensor>
 void ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::Watch(
     int64 tensor_id, Gradient* tangent) {
   auto existing = accumulated_gradients_.find(tensor_id);
-  vspace_->MarkAsResult(tangent);
+  vspace_.MarkAsResult(tangent);
   if (existing == accumulated_gradients_.end()) {
     accumulated_gradients_.emplace(tensor_id, tangent);
   } else {
     std::array<Gradient*, 2> to_aggregate({tangent, existing->second});
     // AggregateGradients steals a reference to each of its arguments. We
     // MarkAsResult on `tangent` above so we don't steal a reference to it.
-    existing->second = vspace_->AggregateGradients(to_aggregate);
+    existing->second = vspace_.AggregateGradients(to_aggregate);
   }
 }
 
@@ -1018,7 +1018,7 @@ void ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::DeleteGradient(
     int64 tensor_id) {
   auto existing = accumulated_gradients_.find(tensor_id);
   if (existing != accumulated_gradients_.end()) {
-    vspace_->DeleteGradient(existing->second);
+    vspace_.DeleteGradient(existing->second);
     accumulated_gradients_.erase(existing);
   }
 }
@@ -1947,7 +1947,13 @@ PyObject* TFE_Py_ForwardAccumulatorNew() {
   if (PyType_Ready(&TFE_Py_ForwardAccumulator_Type) < 0) return nullptr;
   TFE_Py_ForwardAccumulator* accumulator =
       PyObject_NEW(TFE_Py_ForwardAccumulator, &TFE_Py_ForwardAccumulator_Type);
-  accumulator->accumulator = new ForwardAccumulator(py_vspace);
+  if (py_vspace == nullptr) {
+    MaybeRaiseExceptionFromStatus(
+        tensorflow::errors::Internal(
+            "ForwardAccumulator requires a PyVSpace to be registered."),
+        nullptr);
+  }
+  accumulator->accumulator = new ForwardAccumulator(*py_vspace);
   Py_INCREF(accumulator);
   GetAccumulatorSet()->insert(
       reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator));
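
The binding-side hunk above is what the second half of the commit message refers to: instead of constructing the accumulator from a possibly-null `py_vspace` pointer and crashing on first use, the missing registration is detected up front and surfaced as an error. The following is a minimal sketch of that guard-before-dereference pattern with hypothetical names (`registered_vspace`, `MakeAccumulator`); the C++ exception here merely stands in for the Python error raised by the real binding.

// Minimal sketch, not the TensorFlow implementation.
#include <memory>
#include <stdexcept>

struct VSpaceLike {};

// Stand-in for the pointer that a hypothetical registration call would set.
static const VSpaceLike* registered_vspace = nullptr;

struct AccumulatorSketch {
  explicit AccumulatorSketch(const VSpaceLike& vspace) : vspace_(vspace) {}
  const VSpaceLike& vspace_;
};

std::unique_ptr<AccumulatorSketch> MakeAccumulator() {
  if (registered_vspace == nullptr) {
    // Report the misconfiguration instead of dereferencing a null pointer,
    // which would be undefined behavior (typically a crash).
    throw std::runtime_error("A VSpace must be registered first.");
  }
  return std::make_unique<AccumulatorSketch>(*registered_vspace);
}

int main() {
  VSpaceLike vspace;
  registered_vspace = &vspace;            // hypothetical registration step
  auto accumulator = MakeAccumulator();   // succeeds: a VSpace is registered
  return accumulator != nullptr ? 0 : 1;
}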