Keep a constant reference to VSpace in ForwardAccumulator instead of a constant pointer.

Also check that a VSpace was already registered before constructing the accumulator, raising an error instead of segfaulting.

PiperOrigin-RevId: 253879165
commit e985e958b3 (parent 26b14a84d3)
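
For context, here is a minimal standalone sketch of the first change: holding the collaborator as a const reference member instead of a const pointer. VSpaceLike, AccumulatorBefore, and AccumulatorAfter are invented names for illustration only, not the TensorFlow types; the real classes are templated on Gradient, BackwardFunction, and TapeTensor.

// Minimal standalone sketch (hypothetical types, not the TensorFlow classes)
// of holding a collaborator by const reference instead of const pointer.
#include <iostream>

struct VSpaceLike {  // stand-in for VSpace<Gradient, BackwardFunction, TapeTensor>
  void DeleteGradient(int id) const { std::cout << "delete " << id << "\n"; }
};

class AccumulatorBefore {
 public:
  // Pointer member: callers may pass nullptr, so every use risks a segfault.
  explicit AccumulatorBefore(const VSpaceLike* vspace) : vspace_(vspace) {}
  void Cleanup(int id) { vspace_->DeleteGradient(id); }

 private:
  const VSpaceLike* vspace_;
};

class AccumulatorAfter {
 public:
  // Reference member: a valid VSpaceLike must exist at construction time and
  // must outlive the accumulator; there is no null state to defend against.
  explicit AccumulatorAfter(const VSpaceLike& vspace) : vspace_(vspace) {}
  void Cleanup(int id) { vspace_.DeleteGradient(id); }

 private:
  const VSpaceLike& vspace_;
};

int main() {
  VSpaceLike vspace;
  AccumulatorBefore before(&vspace);  // old style: pass an address
  AccumulatorAfter after(vspace);     // new style: pass the object itself
  before.Cleanup(1);
  after.Cleanup(2);
  return 0;
}

Because the reference can never be null, every `vspace_->` call site in the diff below simply becomes `vspace_.`.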
@@ -193,12 +193,12 @@ class ForwardAccumulator {
   // Does not take ownership of `vspace`, which must outlive the
   // ForwardAccumulator.
   explicit ForwardAccumulator(
-      const VSpace<Gradient, BackwardFunction, TapeTensor>* vspace)
+      const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace)
       : vspace_(vspace), backward_tape_(nullptr), accumulating_(false) {}
 
   virtual ~ForwardAccumulator() {
     for (auto accumulated : accumulated_gradients_) {
-      vspace_->DeleteGradient(accumulated.second);
+      vspace_.DeleteGradient(accumulated.second);
     }
   }
 
@@ -276,7 +276,7 @@ class ForwardAccumulator {
   std::unordered_map<int64, Gradient*> accumulated_gradients_;
   // Not owned; provides operations on Tensors which are currently only
   // available in language bindings (e.g. Python).
-  const VSpace<Gradient, BackwardFunction, TapeTensor>* vspace_;
+  const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace_;
   // Set temporarily while in the Accumulate method; if backward_tape_ is not
   // nullptr then we forward op executions to it so Accumulate can compute a
   // backward pass on its backward function.
@@ -865,9 +865,9 @@ ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ForwardpropFromTape(
   sources.reserve(output_tensors.size());
   for (const TapeTensor& output_tensor : output_tensors) {
     // Ownership of `aid` transferred to CallBackwardFunction below.
-    Gradient* aid = vspace_->Ones(output_tensor);
+    Gradient* aid = vspace_.Ones(output_tensor);
     forwardprop_aids.push_back(aid);
-    int64 aid_id = vspace_->TensorId(aid);
+    int64 aid_id = vspace_.TensorId(aid);
     sources.push_back(aid_id);
     sources_set.insert(aid_id);
     tape->Watch(aid_id);
@@ -875,7 +875,7 @@ ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ForwardpropFromTape(
   std::vector<Gradient*> grad;
   auto delete_grad = gtl::MakeCleanup([&grad, this] {
     for (Gradient* tensor : grad) {
-      this->vspace_->DeleteGradient(tensor);
+      this->vspace_.DeleteGradient(tensor);
     }
   });
   {
@@ -883,7 +883,7 @@ ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ForwardpropFromTape(
     std::unique_ptr<BackwardFunction, std::function<void(BackwardFunction*)>>
         backward_function(backward_function_getter(),
                           backward_function_deleter);
-    TF_RETURN_IF_ERROR(vspace_->CallBackwardFunction(
+    TF_RETURN_IF_ERROR(vspace_.CallBackwardFunction(
         backward_function.get(), unneeded_gradients, forwardprop_aids, &grad));
   }
 
@@ -894,11 +894,11 @@ ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ForwardpropFromTape(
   std::unordered_map<int64, TapeTensor> sources_that_are_targets;
   for (Gradient* grad_tensor : grad) {
     if (grad_tensor != nullptr) {
-      int64 tensor_id = vspace_->TensorId(grad_tensor);
+      int64 tensor_id = vspace_.TensorId(grad_tensor);
       targets.push_back(tensor_id);
       if (sources_set.find(tensor_id) != sources_set.end()) {
         sources_that_are_targets.emplace(
-            tensor_id, vspace_->TapeTensorFromGradient(grad_tensor));
+            tensor_id, vspace_.TapeTensorFromGradient(grad_tensor));
       }
     }
   }
@@ -911,11 +911,11 @@ ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ForwardpropFromTape(
     Gradient* grad_tensor = grad[target_index];
     if (grad_tensor != nullptr && in_grad != nullptr) {
       // ComputeGradient steals a reference
-      vspace_->MarkAsResult(in_grad);
+      vspace_.MarkAsResult(in_grad);
     }
   }
 
-  return tape->ComputeGradient(*vspace_, targets, sources,
+  return tape->ComputeGradient(vspace_, targets, sources,
                                sources_that_are_targets, in_grads, out_grads);
 }
 
@@ -952,7 +952,7 @@ Status ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::Accumulate(
   std::vector<Gradient*> new_zeros;
   auto delete_new_zeros = gtl::MakeCleanup([&new_zeros, this] {
     for (Gradient* tensor : new_zeros) {
-      this->vspace_->DeleteGradient(tensor);
+      this->vspace_.DeleteGradient(tensor);
     }
   });
   std::vector<Gradient*> in_grads;
@@ -965,7 +965,7 @@ Status ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::Accumulate(
     if (IsDtypeTrainable(input_tensors[target_index].GetDType())) {
       // ForwardAccumulator defaults to zeros for unwatched Tensors, unlike
       // GradientTape which uses ones.
-      Gradient* zero = vspace_->Zeros(input_tensors[target_index]);
+      Gradient* zero = vspace_.Zeros(input_tensors[target_index]);
       new_zeros.push_back(zero);
       in_grads.push_back(zero);
     } else {
@@ -990,7 +990,7 @@ Status ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::Accumulate(
     int64 tensor_id = output_tensors[i].GetID();
     auto existing = accumulated_gradients_.find(tensor_id);
     if (existing != accumulated_gradients_.end()) {
-      vspace_->DeleteGradient(existing->second);
+      vspace_.DeleteGradient(existing->second);
     }
     accumulated_gradients_[output_tensors[i].GetID()] = forward_grads[i];
   }
@@ -1002,14 +1002,14 @@ template <typename Gradient, typename BackwardFunction, typename TapeTensor>
 void ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::Watch(
     int64 tensor_id, Gradient* tangent) {
   auto existing = accumulated_gradients_.find(tensor_id);
-  vspace_->MarkAsResult(tangent);
+  vspace_.MarkAsResult(tangent);
   if (existing == accumulated_gradients_.end()) {
     accumulated_gradients_.emplace(tensor_id, tangent);
   } else {
     std::array<Gradient*, 2> to_aggregate({tangent, existing->second});
     // AggregateGradients steals a reference to each of its arguments. We
     // MarkAsResult on `tangent` above so we don't steal a reference to it.
-    existing->second = vspace_->AggregateGradients(to_aggregate);
+    existing->second = vspace_.AggregateGradients(to_aggregate);
   }
 }
 
@@ -1018,7 +1018,7 @@ void ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::DeleteGradient(
     int64 tensor_id) {
   auto existing = accumulated_gradients_.find(tensor_id);
   if (existing != accumulated_gradients_.end()) {
-    vspace_->DeleteGradient(existing->second);
+    vspace_.DeleteGradient(existing->second);
     accumulated_gradients_.erase(existing);
   }
 }
@@ -1947,7 +1947,13 @@ PyObject* TFE_Py_ForwardAccumulatorNew() {
   if (PyType_Ready(&TFE_Py_ForwardAccumulator_Type) < 0) return nullptr;
   TFE_Py_ForwardAccumulator* accumulator =
       PyObject_NEW(TFE_Py_ForwardAccumulator, &TFE_Py_ForwardAccumulator_Type);
-  accumulator->accumulator = new ForwardAccumulator(py_vspace);
+  if (py_vspace == nullptr) {
+    MaybeRaiseExceptionFromStatus(
+        tensorflow::errors::Internal(
+            "ForwardAccumulator requires a PyVSpace to be registered."),
+        nullptr);
+  }
+  accumulator->accumulator = new ForwardAccumulator(*py_vspace);
   Py_INCREF(accumulator);
   GetAccumulatorSet()->insert(
       reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator));
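
The last hunk is the registration check mentioned in the commit message. Below is a rough standalone illustration of the idea, under hypothetical names rather than the real CPython binding code: refuse to construct the accumulator when no VSpace has been registered, and report an error instead of dereferencing a null pointer.

// Minimal standalone sketch (hypothetical names, not the real binding code)
// of checking for a registered vspace before constructing an accumulator.
#include <iostream>

struct VSpaceLike {};  // stand-in for the registered Python VSpace

// Stand-in for the module-level pointer that registration would fill in.
static const VSpaceLike* g_py_vspace = nullptr;

bool CanCreateForwardAccumulator() {
  if (g_py_vspace == nullptr) {
    // The real binding raises a Python exception from a Status here;
    // this sketch just reports the condition.
    std::cerr << "ForwardAccumulator requires a VSpace to be registered.\n";
    return false;
  }
  // Only now is it safe to bind a reference, e.g. ForwardAccumulator(*g_py_vspace).
  return true;
}

int main() {
  std::cout << CanCreateForwardAccumulator() << "\n";  // 0: nothing registered
  static const VSpaceLike vspace;
  g_py_vspace = &vspace;
  std::cout << CanCreateForwardAccumulator() << "\n";  // 1: safe to construct
  return 0;
}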