Fixed usage calculation for input buffers.

PiperOrigin-RevId: 299376328
Change-Id: Idcc28a7fdf827ce17bbf46168a56424a0a835933
This commit is contained in:
Raman Sarokin 2020-03-06 09:45:25 -08:00 committed by TensorFlower Gardener
parent 4e5be62116
commit 8685cf97d5

View File

@ -400,6 +400,12 @@ void InferenceContext::Merge() {
void InferenceContext::GetUsages(
const std::function<bool(const TensorDescriptor&)>& functor,
std::map<ValueId, int2>* usages) {
for (ValueId in_id : input_ids_) {
const auto& desc = tensor_reserver_.Get(in_id).descriptor;
if (functor(desc)) {
AddUsage(in_id, 0, usages);
}
}
for (int op_index = 0; op_index < nodes_.size(); ++op_index) {
auto tensors = GetCLNodeTensors(nodes_[op_index]);
for (auto& tensor : tensors) {
@ -408,7 +414,7 @@ void InferenceContext::GetUsages(
}
}
}
for (auto& out_id : output_ids_) {
for (ValueId out_id : output_ids_) {
const auto& desc = tensor_reserver_.Get(out_id).descriptor;
if (functor(desc)) {
AddUsage(out_id, nodes_.size(), usages);