Merge pull request #32639 from stnk20:optimize-tflite-concatenation-op
PiperOrigin-RevId: 271579103
This commit is contained in:
commit
defb20e749
@ -1147,13 +1147,18 @@ inline void Concatenation(const ConcatenationParams& params,
|
|||||||
base_inner_size *= output_shape.Dims(i);
|
base_inner_size *= output_shape.Dims(i);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<int> copy_sizes;
|
||||||
|
std::vector<Scalar*> input_ptrs;
|
||||||
|
for (int i = 0; i < inputs_count; ++i) {
|
||||||
|
copy_sizes.push_back(input_shapes[i]->Dims(axis) * base_inner_size);
|
||||||
|
input_ptrs.push_back(const_cast<Scalar*>(input_data[i]));
|
||||||
|
}
|
||||||
Scalar* output_ptr = output_data;
|
Scalar* output_ptr = output_data;
|
||||||
for (int k = 0; k < outer_size; k++) {
|
for (int k = 0; k < outer_size; k++) {
|
||||||
for (int i = 0; i < inputs_count; ++i) {
|
for (int i = 0; i < inputs_count; ++i) {
|
||||||
const int copy_size = input_shapes[i]->Dims(axis) * base_inner_size;
|
memcpy(output_ptr, input_ptrs[i], copy_sizes[i] * sizeof(Scalar));
|
||||||
memcpy(output_ptr, input_data[i] + k * copy_size,
|
output_ptr += copy_sizes[i];
|
||||||
copy_size * sizeof(Scalar));
|
input_ptrs[i] += copy_sizes[i];
|
||||||
output_ptr += copy_size;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user