[3] Review comments handled
This commit is contained in:
parent
6410f05c96
commit
79122e21c8
@ -32,6 +32,12 @@ namespace ops {
|
|||||||
namespace builtin {
|
namespace builtin {
|
||||||
namespace concatenation {
|
namespace concatenation {
|
||||||
|
|
||||||
|
// This file has two implementation of Concatenation.
|
||||||
|
enum KernelType {
|
||||||
|
kReference,
|
||||||
|
kGenericOptimized,
|
||||||
|
};
|
||||||
|
|
||||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||||
auto* params =
|
auto* params =
|
||||||
reinterpret_cast<TfLiteConcatenationParams*>(node->builtin_data);
|
reinterpret_cast<TfLiteConcatenationParams*>(node->builtin_data);
|
||||||
@ -94,6 +100,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
return context->ResizeTensor(context, output, output_size);
|
return context->ResizeTensor(context, output, output_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <KernelType kernel_type>
|
||||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||||
auto* params =
|
auto* params =
|
||||||
reinterpret_cast<TfLiteConcatenationParams*>(node->builtin_data);
|
reinterpret_cast<TfLiteConcatenationParams*>(node->builtin_data);
|
||||||
@ -105,18 +112,24 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
// allocate and populate these during Prepare().
|
// allocate and populate these during Prepare().
|
||||||
// TODO(ycling): Activation function parameter is ignored. For now we dont have
|
// TODO(ycling): Activation function parameter is ignored. For now we dont have
|
||||||
// a model with a Concatenation with fused activation function.
|
// a model with a Concatenation with fused activation function.
|
||||||
#define TF_LITE_CONCATENATION(type, scalar) \
|
#define TF_LITE_CONCATENATION(scalar) \
|
||||||
{ \
|
{ \
|
||||||
VectorOfTensors<scalar> all_inputs(*context, *node->inputs); \
|
VectorOfTensors<scalar> all_inputs(*context, *node->inputs); \
|
||||||
tflite::ConcatenationParams op_params; \
|
tflite::ConcatenationParams op_params; \
|
||||||
op_params.axis = axis; \
|
op_params.axis = axis; \
|
||||||
op_params.inputs_count = node->inputs->size; \
|
op_params.inputs_count = node->inputs->size; \
|
||||||
type::Concatenation(op_params, all_inputs.shapes(), all_inputs.data(), \
|
if (kernel_type == kReference) { \
|
||||||
GetTensorShape(output), \
|
reference_ops::Concatenation(op_params, all_inputs.shapes(), \
|
||||||
|
all_inputs.data(), GetTensorShape(output), \
|
||||||
GetTensorData<scalar>(output)); \
|
GetTensorData<scalar>(output)); \
|
||||||
|
} else { \
|
||||||
|
optimized_ops::Concatenation(op_params, all_inputs.shapes(), \
|
||||||
|
all_inputs.data(), GetTensorShape(output), \
|
||||||
|
GetTensorData<scalar>(output)); \
|
||||||
|
} \
|
||||||
}
|
}
|
||||||
|
|
||||||
#define TF_LITE_CONCATENATION_QUANTIZED(type) \
|
#define TF_LITE_CONCATENATION_QUANTIZED \
|
||||||
{ \
|
{ \
|
||||||
VectorOfQuantizedTensors all_inputs(*context, *node->inputs); \
|
VectorOfQuantizedTensors all_inputs(*context, *node->inputs); \
|
||||||
tflite::ConcatenationParams op_params; \
|
tflite::ConcatenationParams op_params; \
|
||||||
@ -126,26 +139,32 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
op_params.inputs_count = node->inputs->size; \
|
op_params.inputs_count = node->inputs->size; \
|
||||||
op_params.output_zeropoint = output->params.zero_point; \
|
op_params.output_zeropoint = output->params.zero_point; \
|
||||||
op_params.output_scale = output->params.scale; \
|
op_params.output_scale = output->params.scale; \
|
||||||
type::ConcatenationWithScaling(op_params, all_inputs.shapes(), \
|
if (kernel_type == kReference) { \
|
||||||
all_inputs.data(), GetTensorShape(output), \
|
reference_ops::ConcatenationWithScaling( \
|
||||||
GetTensorData<uint8>(output)); \
|
op_params, all_inputs.shapes(), all_inputs.data(), \
|
||||||
|
GetTensorShape(output), GetTensorData<uint8>(output)); \
|
||||||
|
} else { \
|
||||||
|
optimized_ops::ConcatenationWithScaling( \
|
||||||
|
op_params, all_inputs.shapes(), all_inputs.data(), \
|
||||||
|
GetTensorShape(output), GetTensorData<uint8>(output)); \
|
||||||
|
} \
|
||||||
}
|
}
|
||||||
|
|
||||||
switch (output->type) { // Already know in/outtypes are same.
|
switch (output->type) { // Already know in/outtypes are same.
|
||||||
case kTfLiteFloat32:
|
case kTfLiteFloat32:
|
||||||
TF_LITE_CONCATENATION(reference_ops, float);
|
TF_LITE_CONCATENATION(float);
|
||||||
break;
|
break;
|
||||||
case kTfLiteInt32:
|
case kTfLiteInt32:
|
||||||
TF_LITE_CONCATENATION(reference_ops, int32);
|
TF_LITE_CONCATENATION(int32);
|
||||||
break;
|
break;
|
||||||
case kTfLiteUInt8:
|
case kTfLiteUInt8:
|
||||||
TF_LITE_CONCATENATION_QUANTIZED(reference_ops);
|
TF_LITE_CONCATENATION_QUANTIZED;
|
||||||
break;
|
break;
|
||||||
case kTfLiteInt8:
|
case kTfLiteInt8:
|
||||||
TF_LITE_CONCATENATION(reference_ops, int8_t);
|
TF_LITE_CONCATENATION(int8_t);
|
||||||
break;
|
break;
|
||||||
case kTfLiteInt64:
|
case kTfLiteInt64:
|
||||||
TF_LITE_CONCATENATION(reference_ops, int64_t);
|
TF_LITE_CONCATENATION(int64_t);
|
||||||
break;
|
break;
|
||||||
|
|
||||||
default:
|
default:
|
||||||
@ -163,13 +182,23 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
} // namespace concatenation
|
} // namespace concatenation
|
||||||
|
|
||||||
TfLiteRegistration* Register_CONCATENATION_REF() {
|
TfLiteRegistration* Register_CONCATENATION_REF() {
|
||||||
static TfLiteRegistration r = {nullptr, nullptr, concatenation::Prepare,
|
static TfLiteRegistration r = {
|
||||||
concatenation::Eval};
|
nullptr, nullptr, concatenation::Prepare,
|
||||||
|
concatenation::Eval<concatenation::kReference>};
|
||||||
|
return &r;
|
||||||
|
}
|
||||||
|
|
||||||
|
TfLiteRegistration* Register_CONCATENATION_GENERIC_OPT() {
|
||||||
|
static TfLiteRegistration r = {
|
||||||
|
nullptr, nullptr, concatenation::Prepare,
|
||||||
|
concatenation::Eval<concatenation::kGenericOptimized>};
|
||||||
return &r;
|
return &r;
|
||||||
}
|
}
|
||||||
|
|
||||||
TfLiteRegistration* Register_CONCATENATION() {
|
TfLiteRegistration* Register_CONCATENATION() {
|
||||||
return Register_CONCATENATION_REF();
|
// TODO(ahentz): It turns out the two versions of Concatenation are almost
|
||||||
|
// identical, so we should consider removing one.
|
||||||
|
return Register_CONCATENATION_GENERIC_OPT();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace builtin
|
} // namespace builtin
|
||||||
|
Loading…
Reference in New Issue
Block a user