Skip calling back into python if only 1 gradient to aggregate
PiperOrigin-RevId: 203786896
This commit is contained in:
parent
3b7edb3ced
commit
6732ec3dff
@ -520,7 +520,12 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
|
||||
}
|
||||
} else {
|
||||
any_gradient_nonzero = true;
|
||||
auto new_gradients = vspace.AggregateGradients(grad_it->second);
|
||||
Gradient* new_gradients = nullptr;
|
||||
if (grad_it->second.size() == 1) {
|
||||
new_gradients = grad_it->second.at(0);
|
||||
} else {
|
||||
new_gradients = vspace.AggregateGradients(grad_it->second);
|
||||
}
|
||||
if (sources_set.find(grad_it->first) == sources_set.end()) {
|
||||
gradients.erase(grad_it);
|
||||
} else {
|
||||
|
Loading…
Reference in New Issue
Block a user