Slightly optimize the TFLite Concatenation operation

This commit is contained in:
Satoshi Tanaka 2019-09-19 10:22:26 +09:00
parent 74c5253184
commit 56dbacebc9

View File

@@ -1147,13 +1147,18 @@ inline void Concatenation(const ConcatenationParams& params,
base_inner_size *= output_shape.Dims(i);
}
std::vector<int> copy_sizes;
std::vector<Scalar*> input_ptrs;
for (int i = 0; i < inputs_count; ++i) {
copy_sizes.push_back(input_shapes[i]->Dims(axis) * base_inner_size);
input_ptrs.push_back(const_cast<Scalar*>(input_data[i]));
}
Scalar* output_ptr = output_data;
for (int k = 0; k < outer_size; k++) {
for (int i = 0; i < inputs_count; ++i) {
const int copy_size = input_shapes[i]->Dims(axis) * base_inner_size;
memcpy(output_ptr, input_data[i] + k * copy_size,
copy_size * sizeof(Scalar));
output_ptr += copy_size;
memcpy(output_ptr, input_ptrs[i], copy_sizes[i] * sizeof(Scalar));
output_ptr += copy_sizes[i];
input_ptrs[i] += copy_sizes[i];
}
}
}