[3] Review comments handled

This commit is contained in:
parent 6410f05c96
commit 79122e21c8
@@ -32,6 +32,12 @@ namespace ops {
 namespace builtin {
 namespace concatenation {
 
+// This file has two implementations of Concatenation.
+enum KernelType {
+  kReference,
+  kGenericOptimized,
+};
+
 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   auto* params =
       reinterpret_cast<TfLiteConcatenationParams*>(node->builtin_data);
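The enum above, together with the `template <KernelType kernel_type>` parameter added to Eval in the next hunk, is a standard compile-time dispatch pattern: each instantiation bakes in one branch, so the unused kernel costs nothing at run time. A minimal, self-contained sketch of the pattern (toy names, not the TFLite code):

#include <cstdio>

enum KernelType { kReference, kGenericOptimized };

namespace reference_ops { inline void Run() { std::puts("reference"); } }
namespace optimized_ops { inline void Run() { std::puts("optimized"); } }

// kernel_type is a compile-time constant, so the branch below is resolved
// per instantiation and the dead arm is eliminated by the compiler.
template <KernelType kernel_type>
void Eval() {
  if (kernel_type == kReference) {
    reference_ops::Run();
  } else {
    optimized_ops::Run();
  }
}

int main() {
  Eval<kReference>();         // prints "reference"
  Eval<kGenericOptimized>();  // prints "optimized"
}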
@@ -94,6 +100,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   return context->ResizeTensor(context, output, output_size);
 }
 
+template <KernelType kernel_type>
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   auto* params =
       reinterpret_cast<TfLiteConcatenationParams*>(node->builtin_data);
@@ -105,18 +112,24 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   // allocate and populate these during Prepare().
   // TODO(ycling): Activation function parameter is ignored. For now we don't
   // have a model with a Concatenation with fused activation function.
-#define TF_LITE_CONCATENATION(type, scalar)                                   \
+#define TF_LITE_CONCATENATION(scalar)                                         \
   {                                                                           \
     VectorOfTensors<scalar> all_inputs(*context, *node->inputs);              \
     tflite::ConcatenationParams op_params;                                    \
     op_params.axis = axis;                                                    \
     op_params.inputs_count = node->inputs->size;                              \
-    type::Concatenation(op_params, all_inputs.shapes(), all_inputs.data(),    \
-                        GetTensorShape(output),                               \
+    if (kernel_type == kReference) {                                          \
+      reference_ops::Concatenation(op_params, all_inputs.shapes(),            \
+                                   all_inputs.data(), GetTensorShape(output), \
                                    GetTensorData<scalar>(output));            \
+    } else {                                                                  \
+      optimized_ops::Concatenation(op_params, all_inputs.shapes(),            \
+                                   all_inputs.data(), GetTensorShape(output), \
+                                   GetTensorData<scalar>(output));            \
+    }                                                                         \
   }
 
-#define TF_LITE_CONCATENATION_QUANTIZED(type)                                 \
+#define TF_LITE_CONCATENATION_QUANTIZED                                       \
   {                                                                           \
     VectorOfQuantizedTensors all_inputs(*context, *node->inputs);             \
     tflite::ConcatenationParams op_params;                                    \
@@ -126,26 +139,32 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
     op_params.inputs_count = node->inputs->size;                              \
     op_params.output_zeropoint = output->params.zero_point;                   \
     op_params.output_scale = output->params.scale;                            \
-    type::ConcatenationWithScaling(op_params, all_inputs.shapes(),            \
-                                   all_inputs.data(), GetTensorShape(output), \
-                                   GetTensorData<uint8>(output));             \
+    if (kernel_type == kReference) {                                          \
+      reference_ops::ConcatenationWithScaling(                                \
+          op_params, all_inputs.shapes(), all_inputs.data(),                  \
+          GetTensorShape(output), GetTensorData<uint8>(output));              \
+    } else {                                                                  \
+      optimized_ops::ConcatenationWithScaling(                                \
+          op_params, all_inputs.shapes(), all_inputs.data(),                  \
+          GetTensorShape(output), GetTensorData<uint8>(output));              \
+    }                                                                         \
   }
 
   switch (output->type) {  // Already know in/out types are the same.
     case kTfLiteFloat32:
-      TF_LITE_CONCATENATION(reference_ops, float);
+      TF_LITE_CONCATENATION(float);
       break;
     case kTfLiteInt32:
-      TF_LITE_CONCATENATION(reference_ops, int32);
+      TF_LITE_CONCATENATION(int32);
       break;
     case kTfLiteUInt8:
-      TF_LITE_CONCATENATION_QUANTIZED(reference_ops);
+      TF_LITE_CONCATENATION_QUANTIZED;
       break;
     case kTfLiteInt8:
-      TF_LITE_CONCATENATION(reference_ops, int8_t);
+      TF_LITE_CONCATENATION(int8_t);
       break;
     case kTfLiteInt64:
-      TF_LITE_CONCATENATION(reference_ops, int64_t);
+      TF_LITE_CONCATENATION(int64_t);
       break;
 
     default:
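To make the dispatch concrete: inside the templated Eval, `TF_LITE_CONCATENATION(float)` now expands to roughly the following (reformatted for readability, trailing backslashes dropped):

// kernel_type is the Eval template parameter, so each instantiation
// keeps only one arm of this if/else.
{
  VectorOfTensors<float> all_inputs(*context, *node->inputs);
  tflite::ConcatenationParams op_params;
  op_params.axis = axis;
  op_params.inputs_count = node->inputs->size;
  if (kernel_type == kReference) {
    reference_ops::Concatenation(op_params, all_inputs.shapes(),
                                 all_inputs.data(), GetTensorShape(output),
                                 GetTensorData<float>(output));
  } else {
    optimized_ops::Concatenation(op_params, all_inputs.shapes(),
                                 all_inputs.data(), GetTensorShape(output),
                                 GetTensorData<float>(output));
  }
}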
@@ -163,13 +182,23 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace concatenation
 
 TfLiteRegistration* Register_CONCATENATION_REF() {
-  static TfLiteRegistration r = {nullptr, nullptr, concatenation::Prepare,
-                                 concatenation::Eval};
+  static TfLiteRegistration r = {
+      nullptr, nullptr, concatenation::Prepare,
+      concatenation::Eval<concatenation::kReference>};
   return &r;
 }
 
+TfLiteRegistration* Register_CONCATENATION_GENERIC_OPT() {
+  static TfLiteRegistration r = {
+      nullptr, nullptr, concatenation::Prepare,
+      concatenation::Eval<concatenation::kGenericOptimized>};
+  return &r;
+}
+
 TfLiteRegistration* Register_CONCATENATION() {
-  return Register_CONCATENATION_REF();
+  // TODO(ahentz): It turns out the two versions of Concatenation are almost
+  // identical, so we should consider removing one.
+  return Register_CONCATENATION_GENERIC_OPT();
 }
 
 }  // namespace builtin
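A usage sketch, assuming the standard `tflite::MutableOpResolver` API (not part of this commit): a client that still wants the reference kernel after the default switches to the optimized variant can register it explicitly.

// Hypothetical client code, not part of this commit. Assumes the usual
// TFLite headers for MutableOpResolver and the builtin op registrations
// are included; Register_CONCATENATION() itself now forwards to the
// generic-optimized registration.
tflite::MutableOpResolver resolver;
resolver.AddBuiltin(tflite::BuiltinOperator_CONCATENATION,
                    tflite::ops::builtin::Register_CONCATENATION_REF());

Keeping both entry points is also a common way to let kernel tests run the same graph against the reference and optimized paths and compare outputs.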