Merge pull request #32639 from stnk20:optimize-tflite-concatenation-op

PiperOrigin-RevId: 271579103
This commit is contained in:
TensorFlower Gardener 2019-09-27 10:44:48 -07:00
commit defb20e749

View File

@ -1147,13 +1147,18 @@ inline void Concatenation(const ConcatenationParams& params,
base_inner_size *= output_shape.Dims(i); base_inner_size *= output_shape.Dims(i);
} }
std::vector<int> copy_sizes;
std::vector<Scalar*> input_ptrs;
for (int i = 0; i < inputs_count; ++i) {
copy_sizes.push_back(input_shapes[i]->Dims(axis) * base_inner_size);
input_ptrs.push_back(const_cast<Scalar*>(input_data[i]));
}
Scalar* output_ptr = output_data; Scalar* output_ptr = output_data;
for (int k = 0; k < outer_size; k++) { for (int k = 0; k < outer_size; k++) {
for (int i = 0; i < inputs_count; ++i) { for (int i = 0; i < inputs_count; ++i) {
const int copy_size = input_shapes[i]->Dims(axis) * base_inner_size; memcpy(output_ptr, input_ptrs[i], copy_sizes[i] * sizeof(Scalar));
memcpy(output_ptr, input_data[i] + k * copy_size, output_ptr += copy_sizes[i];
copy_size * sizeof(Scalar)); input_ptrs[i] += copy_sizes[i];
output_ptr += copy_size;
} }
} }
} }