Skip calling back into python if only 1 gradient to aggregate

PiperOrigin-RevId: 203786896
This commit is contained in:
Akshay Modi 2018-07-09 10:23:16 -07:00 committed by TensorFlower Gardener
parent 3b7edb3ced
commit 6732ec3dff

View File

@ -520,7 +520,12 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
}
} else {
any_gradient_nonzero = true;
auto new_gradients = vspace.AggregateGradients(grad_it->second);
Gradient* new_gradients = nullptr;
if (grad_it->second.size() == 1) {
new_gradients = grad_it->second.at(0);
} else {
new_gradients = vspace.AggregateGradients(grad_it->second);
}
if (sources_set.find(grad_it->first) == sources_set.end()) {
gradients.erase(grad_it);
} else {