[3] Review comments handled

ANSHUMAN TRIPATHY 2019-03-29 11:53:22 +05:30
parent 6410f05c96
commit 79122e21c8


@@ -32,6 +32,12 @@ namespace ops {
 namespace builtin {
 namespace concatenation {
 
+// This file has two implementation of Concatenation.
+enum KernelType {
+  kReference,
+  kGenericOptimized,
+};
+
 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   auto* params =
       reinterpret_cast<TfLiteConcatenationParams*>(node->builtin_data);
@@ -94,6 +100,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   return context->ResizeTensor(context, output, output_size);
 }
 
+template <KernelType kernel_type>
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   auto* params =
       reinterpret_cast<TfLiteConcatenationParams*>(node->builtin_data);
@@ -105,47 +112,59 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   // allocate and populate these during Prepare().
   // TODO(ycling): Activation function parameter is ignored. For now we dont have
   // a model with a Concatenation with fused activation function.
-#define TF_LITE_CONCATENATION(type, scalar) \
-  { \
-    VectorOfTensors<scalar> all_inputs(*context, *node->inputs); \
-    tflite::ConcatenationParams op_params; \
-    op_params.axis = axis; \
-    op_params.inputs_count = node->inputs->size; \
-    type::Concatenation(op_params, all_inputs.shapes(), all_inputs.data(), \
-                        GetTensorShape(output), \
-                        GetTensorData<scalar>(output)); \
-  }
-
-#define TF_LITE_CONCATENATION_QUANTIZED(type) \
-  { \
-    VectorOfQuantizedTensors all_inputs(*context, *node->inputs); \
-    tflite::ConcatenationParams op_params; \
-    op_params.axis = axis; \
-    op_params.input_zeropoint = all_inputs.zero_point(); \
-    op_params.input_scale = all_inputs.scale(); \
-    op_params.inputs_count = node->inputs->size; \
-    op_params.output_zeropoint = output->params.zero_point; \
-    op_params.output_scale = output->params.scale; \
-    type::ConcatenationWithScaling(op_params, all_inputs.shapes(), \
-                                   all_inputs.data(), GetTensorShape(output), \
-                                   GetTensorData<uint8>(output)); \
-  }
+#define TF_LITE_CONCATENATION(scalar) \
+  { \
+    VectorOfTensors<scalar> all_inputs(*context, *node->inputs); \
+    tflite::ConcatenationParams op_params; \
+    op_params.axis = axis; \
+    op_params.inputs_count = node->inputs->size; \
+    if (kernel_type == kReference) { \
+      reference_ops::Concatenation(op_params, all_inputs.shapes(), \
+                                   all_inputs.data(), GetTensorShape(output), \
+                                   GetTensorData<scalar>(output)); \
+    } else { \
+      optimized_ops::Concatenation(op_params, all_inputs.shapes(), \
+                                   all_inputs.data(), GetTensorShape(output), \
+                                   GetTensorData<scalar>(output)); \
+    } \
+  }
+
+#define TF_LITE_CONCATENATION_QUANTIZED \
+  { \
+    VectorOfQuantizedTensors all_inputs(*context, *node->inputs); \
+    tflite::ConcatenationParams op_params; \
+    op_params.axis = axis; \
+    op_params.input_zeropoint = all_inputs.zero_point(); \
+    op_params.input_scale = all_inputs.scale(); \
+    op_params.inputs_count = node->inputs->size; \
+    op_params.output_zeropoint = output->params.zero_point; \
+    op_params.output_scale = output->params.scale; \
+    if (kernel_type == kReference) { \
+      reference_ops::ConcatenationWithScaling( \
+          op_params, all_inputs.shapes(), all_inputs.data(), \
+          GetTensorShape(output), GetTensorData<uint8>(output)); \
+    } else { \
+      optimized_ops::ConcatenationWithScaling( \
+          op_params, all_inputs.shapes(), all_inputs.data(), \
+          GetTensorShape(output), GetTensorData<uint8>(output)); \
+    } \
+  }
 
   switch (output->type) {  // Already know in/outtypes are same.
     case kTfLiteFloat32:
-      TF_LITE_CONCATENATION(reference_ops, float);
+      TF_LITE_CONCATENATION(float);
       break;
     case kTfLiteInt32:
-      TF_LITE_CONCATENATION(reference_ops, int32);
+      TF_LITE_CONCATENATION(int32);
       break;
     case kTfLiteUInt8:
-      TF_LITE_CONCATENATION_QUANTIZED(reference_ops);
+      TF_LITE_CONCATENATION_QUANTIZED;
      break;
     case kTfLiteInt8:
-      TF_LITE_CONCATENATION(reference_ops, int8_t);
+      TF_LITE_CONCATENATION(int8_t);
      break;
     case kTfLiteInt64:
-      TF_LITE_CONCATENATION(reference_ops, int64_t);
+      TF_LITE_CONCATENATION(int64_t);
      break;
     default:
@@ -163,13 +182,23 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace concatenation
 
 TfLiteRegistration* Register_CONCATENATION_REF() {
-  static TfLiteRegistration r = {nullptr, nullptr, concatenation::Prepare,
-                                 concatenation::Eval};
+  static TfLiteRegistration r = {
+      nullptr, nullptr, concatenation::Prepare,
+      concatenation::Eval<concatenation::kReference>};
+  return &r;
+}
+
+TfLiteRegistration* Register_CONCATENATION_GENERIC_OPT() {
+  static TfLiteRegistration r = {
+      nullptr, nullptr, concatenation::Prepare,
+      concatenation::Eval<concatenation::kGenericOptimized>};
   return &r;
 }
 
 TfLiteRegistration* Register_CONCATENATION() {
-  return Register_CONCATENATION_REF();
+  // TODO(ahentz): It turns out the two versions of Concatenation are almost
+  // identical, so we should consider removing one.
+  return Register_CONCATENATION_GENERIC_OPT();
 }
 
 }  // namespace builtin
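
The mechanism this commit relies on is that kernel_type is a non-type template parameter of Eval, so each registration instantiates its own specialization and the if (kernel_type == kReference) branch inside the macros is resolved per instantiation rather than at runtime. A minimal standalone sketch of that dispatch pattern follows; the names and output strings are illustrative only and are not part of this file:

#include <cstdio>

enum KernelType { kReference, kGenericOptimized };

// kernel_type is a compile-time constant, so each instantiation of Eval
// keeps only one side of the branch after optimization.
template <KernelType kernel_type>
void Eval() {
  if (kernel_type == kReference) {
    std::printf("reference path\n");
  } else {
    std::printf("optimized path\n");
  }
}

int main() {
  Eval<kReference>();         // resolved at compile time: reference path
  Eval<kGenericOptimized>();  // resolved at compile time: optimized path
  return 0;
}

As the last hunk shows, Register_CONCATENATION() now defaults to the generic optimized variant, while Register_CONCATENATION_REF() keeps the reference path available to callers that still want it.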