From 79122e21c808d1afca6be3e6c6e84fa7f2da9820 Mon Sep 17 00:00:00 2001
From: ANSHUMAN TRIPATHY
Date: Fri, 29 Mar 2019 11:53:22 +0530
Subject: [PATCH] [3] Review comments handled

---
 tensorflow/lite/kernels/concatenation.cc | 83 ++++++++++++++++--------
 1 file changed, 56 insertions(+), 27 deletions(-)

diff --git a/tensorflow/lite/kernels/concatenation.cc b/tensorflow/lite/kernels/concatenation.cc
index ac118c7c514..4e5a2d956d1 100644
--- a/tensorflow/lite/kernels/concatenation.cc
+++ b/tensorflow/lite/kernels/concatenation.cc
@@ -32,6 +32,12 @@ namespace ops {
 namespace builtin {
 namespace concatenation {
 
+// This file has two implementation of Concatenation.
+enum KernelType {
+  kReference,
+  kGenericOptimized,
+};
+
 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   auto* params =
       reinterpret_cast<TfLiteConcatenationParams*>(node->builtin_data);
@@ -94,6 +100,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   return context->ResizeTensor(context, output, output_size);
 }
 
+template <KernelType kernel_type>
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   auto* params =
       reinterpret_cast<TfLiteConcatenationParams*>(node->builtin_data);
@@ -105,47 +112,59 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 // allocate and populate these during Prepare().
 // TODO(ycling): Activation function parameter is ignored. For now we dont have
 // a model with a Concatenation with fused activation function.
-#define TF_LITE_CONCATENATION(type, scalar)                                 \
-  {                                                                         \
-    VectorOfTensors<scalar> all_inputs(*context, *node->inputs);            \
-    tflite::ConcatenationParams op_params;                                  \
-    op_params.axis = axis;                                                  \
-    op_params.inputs_count = node->inputs->size;                            \
-    type::Concatenation(op_params, all_inputs.shapes(), all_inputs.data(),  \
-                        GetTensorShape(output),                             \
-                        GetTensorData<scalar>(output));                     \
-  }
-
-#define TF_LITE_CONCATENATION_QUANTIZED(type)                             \
+#define TF_LITE_CONCATENATION(scalar)                                     \
   {                                                                       \
-    VectorOfQuantizedTensors all_inputs(*context, *node->inputs);         \
+    VectorOfTensors<scalar> all_inputs(*context, *node->inputs);          \
     tflite::ConcatenationParams op_params;                                \
     op_params.axis = axis;                                                \
-    op_params.input_zeropoint = all_inputs.zero_point();                  \
-    op_params.input_scale = all_inputs.scale();                           \
     op_params.inputs_count = node->inputs->size;                          \
-    op_params.output_zeropoint = output->params.zero_point;               \
-    op_params.output_scale = output->params.scale;                        \
-    type::ConcatenationWithScaling(op_params, all_inputs.shapes(),        \
+    if (kernel_type == kReference) {                                      \
+      reference_ops::Concatenation(op_params, all_inputs.shapes(),        \
                                    all_inputs.data(), GetTensorShape(output), \
-                                   GetTensorData<uint8>(output));         \
+                                   GetTensorData<scalar>(output));        \
+    } else {                                                              \
+      optimized_ops::Concatenation(op_params, all_inputs.shapes(),        \
+                                   all_inputs.data(), GetTensorShape(output), \
+                                   GetTensorData<scalar>(output));        \
+    }                                                                     \
+  }
+
+#define TF_LITE_CONCATENATION_QUANTIZED                                   \
+  {                                                                       \
+    VectorOfQuantizedTensors all_inputs(*context, *node->inputs);         \
+    tflite::ConcatenationParams op_params;                                \
+    op_params.axis = axis;                                                \
+    op_params.input_zeropoint = all_inputs.zero_point();                  \
+    op_params.input_scale = all_inputs.scale();                           \
+    op_params.inputs_count = node->inputs->size;                          \
+    op_params.output_zeropoint = output->params.zero_point;               \
+    op_params.output_scale = output->params.scale;                        \
+    if (kernel_type == kReference) {                                      \
+      reference_ops::ConcatenationWithScaling(                            \
+          op_params, all_inputs.shapes(), all_inputs.data(),              \
+          GetTensorShape(output), GetTensorData<uint8>(output));          \
+    } else {                                                              \
+      optimized_ops::ConcatenationWithScaling(                            \
+          op_params, all_inputs.shapes(), all_inputs.data(),              \
+          GetTensorShape(output), GetTensorData<uint8>(output));          \
+    }                                                                     \
+  }
 
   switch (output->type) {  // Already know in/outtypes are same.
     case kTfLiteFloat32:
-      TF_LITE_CONCATENATION(reference_ops, float);
+      TF_LITE_CONCATENATION(float);
       break;
     case kTfLiteInt32:
-      TF_LITE_CONCATENATION(reference_ops, int32);
+      TF_LITE_CONCATENATION(int32);
       break;
     case kTfLiteUInt8:
-      TF_LITE_CONCATENATION_QUANTIZED(reference_ops);
+      TF_LITE_CONCATENATION_QUANTIZED;
       break;
     case kTfLiteInt8:
-      TF_LITE_CONCATENATION(reference_ops, int8_t);
+      TF_LITE_CONCATENATION(int8_t);
       break;
     case kTfLiteInt64:
-      TF_LITE_CONCATENATION(reference_ops, int64_t);
+      TF_LITE_CONCATENATION(int64_t);
       break;
     default:
@@ -163,13 +182,23 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace concatenation
 
 TfLiteRegistration* Register_CONCATENATION_REF() {
-  static TfLiteRegistration r = {nullptr, nullptr, concatenation::Prepare,
-                                 concatenation::Eval};
+  static TfLiteRegistration r = {
+      nullptr, nullptr, concatenation::Prepare,
+      concatenation::Eval<concatenation::kReference>};
+  return &r;
+}
+
+TfLiteRegistration* Register_CONCATENATION_GENERIC_OPT() {
+  static TfLiteRegistration r = {
+      nullptr, nullptr, concatenation::Prepare,
+      concatenation::Eval<concatenation::kGenericOptimized>};
   return &r;
 }
 
 TfLiteRegistration* Register_CONCATENATION() {
-  return Register_CONCATENATION_REF();
+  // TODO(ahentz): It turns out the two versions of Concatenation are almost
+  // identical, so we should consider removing one.
+  return Register_CONCATENATION_GENERIC_OPT();
 }
 
 }  // namespace builtin
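Note (reviewer sketch, not part of the patch): with this change, the same templated Eval is instantiated once per KernelType, and the default CONCATENATION registration now resolves to the generic-optimized variant. The sketch below shows how a client could still select the reference kernel through TF Lite's existing MutableOpResolver API; the include paths reflect the tree layout of this era, and the MakeConcatResolver helper is a hypothetical name introduced only for illustration.

// Sketch: map the builtin CONCATENATION op to either of the two kernel
// variants registered by this patch.
#include "tensorflow/lite/mutable_op_resolver.h"
#include "tensorflow/lite/schema/schema_generated.h"

namespace tflite {
namespace ops {
namespace builtin {
// Defined in concatenation.cc; normally wired up via register.cc.
TfLiteRegistration* Register_CONCATENATION_REF();
TfLiteRegistration* Register_CONCATENATION_GENERIC_OPT();
}  // namespace builtin
}  // namespace ops
}  // namespace tflite

// Hypothetical helper: build a resolver whose CONCATENATION entry points at
// the requested kernel variant.
tflite::MutableOpResolver MakeConcatResolver(bool use_reference) {
  tflite::MutableOpResolver resolver;
  resolver.AddBuiltin(
      tflite::BuiltinOperator_CONCATENATION,
      use_reference ? tflite::ops::builtin::Register_CONCATENATION_REF()
                    : tflite::ops::builtin::Register_CONCATENATION_GENERIC_OPT());
  return resolver;
}

Passing such a resolver to tflite::InterpreterBuilder would then pick the corresponding Eval instantiation at graph construction time, without touching the default Register_CONCATENATION() behavior introduced here.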