From 0a469be3eccdc9e617c8180dee8afdf61d8f4cc7 Mon Sep 17 00:00:00 2001 From: Advait Jain Date: Wed, 6 Jan 2021 16:43:04 -0800 Subject: [PATCH 1/3] Refactoring fully_connected to share code between reference and optimized kernels. This change is currently for discussion and to figure out what parts of this refactor we like and what we do not. Also note that we have `#if !defined(XTENSA)` in fully_connected_common.cc because the linker appears to be failing to discard the unused symbols. We will discuss this further with the Cadence engineers but having some repro case that can be merged would likely be useful. TODO: make a github issue describing this linker behavior in more detail. Also, this refactor addresses the sign flip in fully_connected: http://b/138810107 --- .../micro/kernels/cmsis-nn/fully_connected.cc | 179 +++------------- .../lite/micro/kernels/fully_connected.cc | 194 +++--------------- .../lite/micro/kernels/fully_connected.h | 66 ++++++ .../micro/kernels/fully_connected_common.cc | 141 +++++++++++++ .../micro/kernels/xtensa/fully_connected.cc | 97 +++------ tensorflow/lite/micro/tools/make/Makefile | 1 + 6 files changed, 300 insertions(+), 378 deletions(-) create mode 100644 tensorflow/lite/micro/kernels/fully_connected_common.cc diff --git a/tensorflow/lite/micro/kernels/cmsis-nn/fully_connected.cc b/tensorflow/lite/micro/kernels/cmsis-nn/fully_connected.cc index 11a0f0bdc23..6274c152543 100644 --- a/tensorflow/lite/micro/kernels/cmsis-nn/fully_connected.cc +++ b/tensorflow/lite/micro/kernels/cmsis-nn/fully_connected.cc @@ -30,30 +30,12 @@ namespace tflite { namespace { struct OpData { - // The scaling factor from input to output (aka the 'real multiplier') can - // be represented as a fixed point multiplier plus a left shift. - int32_t output_multiplier; - int output_shift; - // The range of the fused activation layer. For example for kNone and - // uint8_t these would be 0 and 255. - int32_t output_activation_min; - int32_t output_activation_max; - // The index of the temporary tensor where the quantized inputs are cached. - int input_quantized_index; + OpDataFullyConnectedReference reference_op_data; + // Index to buffer for optimizations if applicable. int buffer_idx; - - // Cached tensor zero point values for quantized operations. - int32_t input_zero_point; - int32_t filter_zero_point; - int32_t output_zero_point; }; -constexpr int kInputTensor = 0; -constexpr int kWeightsTensor = 1; -constexpr int kBiasTensor = 2; -constexpr int kOutputTensor = 0; - // TODO(b/169801227): This global struct is needed for the linker to drop unused // code (for example, by using Register_FULLY_CONNECTED_INT8 instead of // Register_FULLY_CONNECTED). 
@@ -65,24 +47,11 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, const TfLiteTensor* filter, const TfLiteTensor* bias, TfLiteTensor* output, OpData* data) { - TfLiteStatus status = kTfLiteOk; // Set buffer index to a reset value data->buffer_idx = -1; - if (data_type != kTfLiteFloat32) { - double real_multiplier = 0.0; - TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler( - context, input, filter, bias, output, &real_multiplier)); - int exponent; - QuantizeMultiplier(real_multiplier, &data->output_multiplier, &exponent); - data->output_shift = -exponent; - TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized( - context, activation, output, &data->output_activation_min, - &data->output_activation_max)); - data->input_zero_point = input->params.zero_point; - data->filter_zero_point = filter->params.zero_point; - data->output_zero_point = output->params.zero_point; - } - return status; + return CalculateOpDataFullyConnectedReference(context, activation, data_type, + input, filter, bias, output, + &(data->reference_op_data)); } void* Init(TfLiteContext* context, const char* buffer, size_t length) { @@ -98,10 +67,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { const auto params = static_cast(node->builtin_data); - const TfLiteTensor* input = GetInput(context, node, kInputTensor); - const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor); - const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); - TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + const TfLiteTensor* input = + GetInput(context, node, kFullyConnectedInputTensor); + const TfLiteTensor* filter = + GetInput(context, node, kFullyConnectedWeightsTensor); + const TfLiteTensor* bias = + GetOptionalInputTensor(context, node, kFullyConnectedBiasTensor); + TfLiteTensor* output = GetOutput(context, node, kFullyConnectedOutputTensor); TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type); TF_LITE_ENSURE_MSG(context, input->type == filter->type, "Hybrid models are not supported on TFLite Micro."); @@ -153,16 +125,15 @@ TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node, const RuntimeShape input_shape = tflite::micro::GetTensorShape(input); cmsis_nn_fc_params fc_params; - fc_params.input_offset = -data.input_zero_point; - fc_params.output_offset = data.output_zero_point; - fc_params.filter_offset = -data.filter_zero_point; - fc_params.activation.min = data.output_activation_min; - fc_params.activation.max = data.output_activation_max; + fc_params.input_offset = -data.reference_op_data.input_zero_point; + fc_params.output_offset = data.reference_op_data.output_zero_point; + fc_params.filter_offset = -data.reference_op_data.filter_zero_point; + fc_params.activation.min = data.reference_op_data.output_activation_min; + fc_params.activation.max = data.reference_op_data.output_activation_max; cmsis_nn_per_tensor_quant_params quant_params; - quant_params.multiplier = data.output_multiplier; - // TODO(b/138810107): Figure out whether output shift should be inverted - quant_params.shift = -data.output_shift; + quant_params.multiplier = data.reference_op_data.output_multiplier; + quant_params.shift = data.reference_op_data.output_shift; cmsis_nn_dims input_dims; input_dims.n = batches; @@ -206,110 +177,25 @@ TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node, tflite::micro::GetTensorData(output)), ARM_MATH_SUCCESS); } else { - tflite::FullyConnectedParams op_params; - op_params.input_offset = -data.input_zero_point; - 
op_params.weights_offset = -data.filter_zero_point; - op_params.output_offset = data.output_zero_point; - op_params.output_multiplier = data.output_multiplier; - // TODO(b/138810107): Figure out whether output shift should be inverted - op_params.output_shift = -data.output_shift; - op_params.quantized_activation_min = data.output_activation_min; - op_params.quantized_activation_max = data.output_activation_max; - - reference_integer_ops::FullyConnected( - op_params, tflite::micro::GetTensorShape(input), - tflite::micro::GetTensorData(input), - tflite::micro::GetTensorShape(filter), - tflite::micro::GetTensorData(filter), - tflite::micro::GetTensorShape(bias), - tflite::micro::GetTensorData(bias), - tflite::micro::GetTensorShape(output), - tflite::micro::GetTensorData(output)); + return EvalQuantizedInt8FullyConnectedReference( + context, node, data.reference_op_data, input, filter, bias, output); } return kTfLiteOk; } -TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, - const OpData& data, const TfLiteEvalTensor* input, - const TfLiteEvalTensor* filter, - const TfLiteEvalTensor* bias, - TfLiteEvalTensor* output) { - const int32_t input_offset = -data.input_zero_point; - const int32_t filter_offset = -data.filter_zero_point; - const int32_t output_offset = data.output_zero_point; - - tflite::FullyConnectedParams op_params; - op_params.input_offset = input_offset; - op_params.weights_offset = filter_offset; - op_params.output_offset = output_offset; - op_params.output_multiplier = data.output_multiplier; - // Legacy ops used mixed left and right shifts. Now all are +ve-means-left. - op_params.output_shift = -data.output_shift; - op_params.quantized_activation_min = data.output_activation_min; - op_params.quantized_activation_max = data.output_activation_max; - -#define TF_LITE_FULLY_CONNECTED(output_data_type) \ - reference_ops::FullyConnected( \ - op_params, tflite::micro::GetTensorShape(input), \ - tflite::micro::GetTensorData(input), \ - tflite::micro::GetTensorShape(filter), \ - tflite::micro::GetTensorData(filter), \ - tflite::micro::GetTensorShape(bias), \ - tflite::micro::GetTensorData(bias), \ - tflite::micro::GetTensorShape(output), \ - tflite::micro::GetTensorData(output)) - switch (output->type) { - case kTfLiteUInt8: - TF_LITE_FULLY_CONNECTED(uint8_t); - break; - case kTfLiteInt16: - TF_LITE_FULLY_CONNECTED(int16_t); - break; - default: - TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.", - TfLiteTypeGetName(output->type), output->type); - return kTfLiteError; - } - - return kTfLiteOk; -} - -TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node, - TfLiteFusedActivation activation, - const TfLiteEvalTensor* input, - const TfLiteEvalTensor* filter, - const TfLiteEvalTensor* bias, TfLiteEvalTensor* output) { - float output_activation_min, output_activation_max; - CalculateActivationRange(activation, &output_activation_min, - &output_activation_max); - tflite::FullyConnectedParams op_params; - op_params.float_activation_min = output_activation_min; - op_params.float_activation_max = output_activation_max; - tflite::reference_ops::FullyConnected( - op_params, tflite::micro::GetTensorShape(input), - tflite::micro::GetTensorData(input), - tflite::micro::GetTensorShape(filter), - tflite::micro::GetTensorData(filter), - tflite::micro::GetTensorShape(bias), - tflite::micro::GetTensorData(bias), - tflite::micro::GetTensorShape(output), - tflite::micro::GetTensorData(output)); - return kTfLiteOk; -} - TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* 
node) { TFLITE_DCHECK(node->builtin_data != nullptr); const auto* params = static_cast(node->builtin_data); const TfLiteEvalTensor* input = - tflite::micro::GetEvalInput(context, node, kInputTensor); + tflite::micro::GetEvalInput(context, node, kFullyConnectedInputTensor); const TfLiteEvalTensor* filter = - tflite::micro::GetEvalInput(context, node, kWeightsTensor); + tflite::micro::GetEvalInput(context, node, kFullyConnectedWeightsTensor); const TfLiteEvalTensor* bias = - tflite::micro::GetEvalInput(context, node, kBiasTensor); + tflite::micro::GetEvalInput(context, node, kFullyConnectedBiasTensor); TfLiteEvalTensor* output = - tflite::micro::GetEvalOutput(context, node, kOutputTensor); + tflite::micro::GetEvalOutput(context, node, kFullyConnectedOutputTensor); TFLITE_DCHECK(node->user_data != nullptr); const OpData& data = *(static_cast(node->user_data)); @@ -317,14 +203,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { // Checks in Prepare ensure input, output and filter types are all the same. switch (input->type) { case kTfLiteFloat32: - return EvalFloat(context, node, params->activation, input, filter, bias, - output); + return EvalFloatFullyConnectedReference(context, node, params->activation, + input, filter, bias, output); case kTfLiteInt8: return EvalQuantizedInt8(context, node, data, input, filter, bias, output); case kTfLiteUInt8: - return EvalQuantized(context, node, data, input, filter, bias, output); + return EvalQuantizedFullyConnectedReference( + context, node, data.reference_op_data, input, filter, bias, output); default: TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.", @@ -342,13 +229,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { // refactoring. TfLiteStatus EvalInt8(TfLiteContext* context, TfLiteNode* node) { const TfLiteEvalTensor* input = - tflite::micro::GetEvalInput(context, node, kInputTensor); + tflite::micro::GetEvalInput(context, node, kFullyConnectedInputTensor); const TfLiteEvalTensor* filter = - tflite::micro::GetEvalInput(context, node, kWeightsTensor); + tflite::micro::GetEvalInput(context, node, kFullyConnectedWeightsTensor); const TfLiteEvalTensor* bias = - tflite::micro::GetEvalInput(context, node, kBiasTensor); + tflite::micro::GetEvalInput(context, node, kFullyConnectedBiasTensor); TfLiteEvalTensor* output = - tflite::micro::GetEvalOutput(context, node, kOutputTensor); + tflite::micro::GetEvalOutput(context, node, kFullyConnectedOutputTensor); TFLITE_DCHECK(node->user_data != nullptr); const OpData& data = *(static_cast(node->user_data)); diff --git a/tensorflow/lite/micro/kernels/fully_connected.cc b/tensorflow/lite/micro/kernels/fully_connected.cc index d3fdeacb016..a06ba21e427 100644 --- a/tensorflow/lite/micro/kernels/fully_connected.cc +++ b/tensorflow/lite/micro/kernels/fully_connected.cc @@ -13,191 +13,47 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/lite/micro/kernels/fully_connected.h" +#include "tensorflow/lite/kernels/internal/reference/fully_connected.h" #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/kernels/internal/common.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" -#include "tensorflow/lite/kernels/internal/reference/fully_connected.h" #include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/micro/kernels/fully_connected.h" #include "tensorflow/lite/micro/kernels/kernel_util.h" namespace tflite { namespace { -struct OpData { - // The scaling factor from input to output (aka the 'real multiplier') can - // be represented as a fixed point multiplier plus a left shift. - int32_t output_multiplier; - int output_shift; - // The range of the fused activation layer. For example for kNone and - // uint8_t these would be 0 and 255. - int32_t output_activation_min; - int32_t output_activation_max; - // The index of the temporary tensor where the quantized inputs are cached. - int input_quantized_index; - // Cached zero point values of tensors. - int32_t input_zero_point; - int32_t filter_zero_point; - int32_t output_zero_point; -}; - -constexpr int kInputTensor = 0; -constexpr int kWeightsTensor = 1; -constexpr int kBiasTensor = 2; -constexpr int kOutputTensor = 0; - -TfLiteStatus CalculateOpData(TfLiteContext* context, - TfLiteFusedActivation activation, - TfLiteType data_type, const TfLiteTensor* input, - const TfLiteTensor* filter, - const TfLiteTensor* bias, TfLiteTensor* output, - OpData* data) { - TfLiteStatus status = kTfLiteOk; - if (data_type != kTfLiteFloat32) { - double real_multiplier = 0.0; - TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler( - context, input, filter, bias, output, &real_multiplier)); - int exponent; - QuantizeMultiplier(real_multiplier, &data->output_multiplier, &exponent); - data->output_shift = -exponent; - TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized( - context, activation, output, &data->output_activation_min, - &data->output_activation_max)); - - data->input_zero_point = input->params.zero_point; - data->filter_zero_point = filter->params.zero_point; - data->output_zero_point = output->params.zero_point; - } - return status; -} - -void* Init(TfLiteContext* context, const char* buffer, size_t length) { - TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr); - return context->AllocatePersistentBuffer(context, sizeof(OpData)); -} - TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TFLITE_DCHECK(node->user_data != nullptr); TFLITE_DCHECK(node->builtin_data != nullptr); - OpData* data = static_cast(node->user_data); + auto* data = static_cast(node->user_data); const auto params = static_cast(node->builtin_data); - const TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* input = + GetInput(context, node, kFullyConnectedInputTensor); TF_LITE_ENSURE(context, input != nullptr); - const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor); + const TfLiteTensor* filter = + GetInput(context, node, kFullyConnectedWeightsTensor); TF_LITE_ENSURE(context, filter != nullptr); - const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); - TfLiteTensor* output = 
GetOutput(context, node, kOutputTensor); + const TfLiteTensor* bias = + GetOptionalInputTensor(context, node, kFullyConnectedBiasTensor); + TfLiteTensor* output = GetOutput(context, node, kFullyConnectedOutputTensor); TF_LITE_ENSURE(context, output != nullptr); TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type); TF_LITE_ENSURE_MSG(context, input->type == filter->type, "Hybrid models are not supported on TFLite Micro."); - return CalculateOpData(context, params->activation, input->type, input, - filter, bias, output, data); -} - -TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node, - const OpData& data, - const TfLiteEvalTensor* input, - const TfLiteEvalTensor* filter, - const TfLiteEvalTensor* bias, - TfLiteEvalTensor* output) { - tflite::FullyConnectedParams op_params; - op_params.input_offset = -data.input_zero_point; - op_params.weights_offset = -data.filter_zero_point; - op_params.output_offset = data.output_zero_point; - op_params.output_multiplier = data.output_multiplier; - // TODO(b/138810107): Figure out whether output shift should be inverted - op_params.output_shift = -data.output_shift; - op_params.quantized_activation_min = data.output_activation_min; - op_params.quantized_activation_max = data.output_activation_max; - - reference_integer_ops::FullyConnected( - op_params, tflite::micro::GetTensorShape(input), - tflite::micro::GetTensorData(input), - tflite::micro::GetTensorShape(filter), - tflite::micro::GetTensorData(filter), - tflite::micro::GetTensorShape(bias), - tflite::micro::GetTensorData(bias), - tflite::micro::GetTensorShape(output), - tflite::micro::GetTensorData(output)); - return kTfLiteOk; -} - -TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, - const OpData& data, const TfLiteEvalTensor* input, - const TfLiteEvalTensor* filter, - const TfLiteEvalTensor* bias, - TfLiteEvalTensor* output) { - const int32_t input_offset = -data.input_zero_point; - const int32_t filter_offset = -data.filter_zero_point; - const int32_t output_offset = data.output_zero_point; - - tflite::FullyConnectedParams op_params; - op_params.input_offset = input_offset; - op_params.weights_offset = filter_offset; - op_params.output_offset = output_offset; - op_params.output_multiplier = data.output_multiplier; - // Legacy ops used mixed left and right shifts. Now all are +ve-means-left. 
- op_params.output_shift = -data.output_shift; - op_params.quantized_activation_min = data.output_activation_min; - op_params.quantized_activation_max = data.output_activation_max; - -#define TF_LITE_FULLY_CONNECTED(output_data_type) \ - reference_ops::FullyConnected( \ - op_params, tflite::micro::GetTensorShape(input), \ - tflite::micro::GetTensorData(input), \ - tflite::micro::GetTensorShape(filter), \ - tflite::micro::GetTensorData(filter), \ - tflite::micro::GetTensorShape(bias), \ - tflite::micro::GetTensorData(bias), \ - tflite::micro::GetTensorShape(output), \ - tflite::micro::GetTensorData(output)) - switch (output->type) { - case kTfLiteUInt8: - TF_LITE_FULLY_CONNECTED(uint8_t); - break; - case kTfLiteInt16: - TF_LITE_FULLY_CONNECTED(int16_t); - break; - default: - TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.", - TfLiteTypeGetName(output->type), output->type); - return kTfLiteError; - } - - return kTfLiteOk; -} - -TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node, - TfLiteFusedActivation activation, - const TfLiteEvalTensor* input, - const TfLiteEvalTensor* filter, - const TfLiteEvalTensor* bias, TfLiteEvalTensor* output) { - float output_activation_min, output_activation_max; - CalculateActivationRange(activation, &output_activation_min, - &output_activation_max); - tflite::FullyConnectedParams op_params; - op_params.float_activation_min = output_activation_min; - op_params.float_activation_max = output_activation_max; - tflite::reference_ops::FullyConnected( - op_params, tflite::micro::GetTensorShape(input), - tflite::micro::GetTensorData(input), - tflite::micro::GetTensorShape(filter), - tflite::micro::GetTensorData(filter), - tflite::micro::GetTensorShape(bias), - tflite::micro::GetTensorData(bias), - tflite::micro::GetTensorShape(output), - tflite::micro::GetTensorData(output)); - return kTfLiteOk; + return CalculateOpDataFullyConnectedReference(context, params->activation, + input->type, input, filter, + bias, output, data); } TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { @@ -206,28 +62,30 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { static_cast(node->builtin_data); const TfLiteEvalTensor* input = - tflite::micro::GetEvalInput(context, node, kInputTensor); + tflite::micro::GetEvalInput(context, node, kFullyConnectedInputTensor); const TfLiteEvalTensor* filter = - tflite::micro::GetEvalInput(context, node, kWeightsTensor); + tflite::micro::GetEvalInput(context, node, kFullyConnectedWeightsTensor); const TfLiteEvalTensor* bias = - tflite::micro::GetEvalInput(context, node, kBiasTensor); + tflite::micro::GetEvalInput(context, node, kFullyConnectedBiasTensor); TfLiteEvalTensor* output = - tflite::micro::GetEvalOutput(context, node, kOutputTensor); + tflite::micro::GetEvalOutput(context, node, kFullyConnectedOutputTensor); TFLITE_DCHECK(node->user_data != nullptr); - const OpData& data = *(static_cast(node->user_data)); + const auto& data = + *(static_cast(node->user_data)); // Checks in Prepare ensure input, output and filter types are all the same. 
switch (input->type) { case kTfLiteFloat32: - return EvalFloat(context, node, params->activation, input, filter, bias, - output); + return EvalFloatFullyConnectedReference(context, node, params->activation, + input, filter, bias, output); case kTfLiteInt8: - return EvalQuantizedInt8(context, node, data, input, filter, bias, - output); + return EvalQuantizedInt8FullyConnectedReference( + context, node, data, input, filter, bias, output); case kTfLiteUInt8: - return EvalQuantized(context, node, data, input, filter, bias, output); + return EvalQuantizedFullyConnectedReference(context, node, data, input, + filter, bias, output); default: TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.", @@ -240,7 +98,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace TfLiteRegistration Register_FULLY_CONNECTED() { - return {/*init=*/Init, + return {/*init=*/InitFullyConnectedReference, /*free=*/nullptr, /*prepare=*/Prepare, /*invoke=*/Eval, diff --git a/tensorflow/lite/micro/kernels/fully_connected.h b/tensorflow/lite/micro/kernels/fully_connected.h index 3e6467183fe..38704d908f5 100644 --- a/tensorflow/lite/micro/kernels/fully_connected.h +++ b/tensorflow/lite/micro/kernels/fully_connected.h @@ -15,10 +15,76 @@ limitations under the License. #ifndef TENSORFLOW_LITE_MICRO_KERNELS_FULLY_CONNECTED_H_ #define TENSORFLOW_LITE_MICRO_KERNELS_FULLY_CONNECTED_H_ +#include + +#include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/kernels/internal/types.h" namespace tflite { +struct OpDataFullyConnectedReference { + // The scaling factor from input to output (aka the 'real multiplier') can + // be represented as a fixed point multiplier plus a left shift. + int32_t output_multiplier; + int output_shift; + // The range of the fused activation layer. For example for kNone and + // uint8_t these would be 0 and 255. + int32_t output_activation_min; + int32_t output_activation_max; + // The index of the temporary tensor where the quantized inputs are cached. + int input_quantized_index; + // Cached zero point values of tensors. + int32_t input_zero_point; + int32_t filter_zero_point; + int32_t output_zero_point; + + // Returns a FullyConnectedParams struct with all the parameters needed for a + // quantized fully connected computation. 
+ FullyConnectedParams ToQuantizedParams() const { + FullyConnectedParams op_params; + op_params.input_offset = -input_zero_point; + op_params.weights_offset = -filter_zero_point; + op_params.output_offset = output_zero_point; + op_params.output_multiplier = output_multiplier; + op_params.output_shift = output_shift; + op_params.quantized_activation_min = output_activation_min; + op_params.quantized_activation_max = output_activation_max; + return op_params; + } +}; + +extern const int kFullyConnectedInputTensor; +extern const int kFullyConnectedWeightsTensor; +extern const int kFullyConnectedBiasTensor; +extern const int kFullyConnectedOutputTensor; + +TfLiteStatus CalculateOpDataFullyConnectedReference( + TfLiteContext* context, TfLiteFusedActivation activation, + TfLiteType data_type, const TfLiteTensor* input, const TfLiteTensor* filter, + const TfLiteTensor* bias, TfLiteTensor* output, + OpDataFullyConnectedReference* data); + +void* InitFullyConnectedReference(TfLiteContext* context, const char* buffer, + size_t length); + +TfLiteStatus EvalFloatFullyConnectedReference( + TfLiteContext* context, TfLiteNode* node, TfLiteFusedActivation activation, + const TfLiteEvalTensor* input, const TfLiteEvalTensor* filter, + const TfLiteEvalTensor* bias, TfLiteEvalTensor* output); + +TfLiteStatus EvalQuantizedInt8FullyConnectedReference( + TfLiteContext* context, TfLiteNode* node, + const OpDataFullyConnectedReference& data, const TfLiteEvalTensor* input, + const TfLiteEvalTensor* filter, const TfLiteEvalTensor* bias, + TfLiteEvalTensor* output); + +TfLiteStatus EvalQuantizedFullyConnectedReference( + TfLiteContext* context, TfLiteNode* node, + const OpDataFullyConnectedReference& data, const TfLiteEvalTensor* input, + const TfLiteEvalTensor* filter, const TfLiteEvalTensor* bias, + TfLiteEvalTensor* output); + // This is the most generic TfLiteRegistration. The actual supported types may // still be target dependent. The only requirement is that every implementation // (reference or optimized) must define this function. diff --git a/tensorflow/lite/micro/kernels/fully_connected_common.cc b/tensorflow/lite/micro/kernels/fully_connected_common.cc new file mode 100644 index 00000000000..8c550186665 --- /dev/null +++ b/tensorflow/lite/micro/kernels/fully_connected_common.cc @@ -0,0 +1,141 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/reference/fully_connected.h" +#include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/micro/kernels/fully_connected.h" +#include "tensorflow/lite/micro/kernels/kernel_util.h" + +namespace tflite { +namespace { +#if !defined(XTENSA) +FullyConnectedParams ToFloatParams(TfLiteFusedActivation activation) { + FullyConnectedParams op_params; + CalculateActivationRange(activation, &op_params.float_activation_min, + &op_params.float_activation_max); + return op_params; +} +#endif // !defined(XTENSA) +} // namespace + +const int kFullyConnectedInputTensor = 0; +const int kFullyConnectedWeightsTensor = 1; +const int kFullyConnectedBiasTensor = 2; +const int kFullyConnectedOutputTensor = 0; + +void* InitFullyConnectedReference(TfLiteContext* context, const char* buffer, + size_t length) { + TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr); + return context->AllocatePersistentBuffer( + context, sizeof(OpDataFullyConnectedReference)); +} + +#if !defined(XTENSA) +TfLiteStatus CalculateOpDataFullyConnectedReference( + TfLiteContext* context, TfLiteFusedActivation activation, + TfLiteType data_type, const TfLiteTensor* input, const TfLiteTensor* filter, + const TfLiteTensor* bias, TfLiteTensor* output, + OpDataFullyConnectedReference* data) { + if (data_type != kTfLiteFloat32) { + double real_multiplier = 0.0; + TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler( + context, input, filter, bias, output, &real_multiplier)); + QuantizeMultiplier(real_multiplier, &data->output_multiplier, + &data->output_shift); + + data->input_zero_point = input->params.zero_point; + data->filter_zero_point = filter->params.zero_point; + data->output_zero_point = output->params.zero_point; + + return CalculateActivationRangeQuantized(context, activation, output, + &data->output_activation_min, + &data->output_activation_max); + } + return kTfLiteOk; +} + +TfLiteStatus EvalFloatFullyConnectedReference( + TfLiteContext* context, TfLiteNode* node, TfLiteFusedActivation activation, + const TfLiteEvalTensor* input, const TfLiteEvalTensor* filter, + const TfLiteEvalTensor* bias, TfLiteEvalTensor* output) { + tflite::reference_ops::FullyConnected( + ToFloatParams(activation), tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData(input), + tflite::micro::GetTensorShape(filter), + tflite::micro::GetTensorData(filter), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetTensorData(bias), + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output)); + return kTfLiteOk; +} + +TfLiteStatus EvalQuantizedInt8FullyConnectedReference( + TfLiteContext* context, TfLiteNode* node, + const OpDataFullyConnectedReference& data, const TfLiteEvalTensor* input, + const TfLiteEvalTensor* filter, const TfLiteEvalTensor* bias, + TfLiteEvalTensor* output) { + reference_integer_ops::FullyConnected( + data.ToQuantizedParams(), tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData(input), + tflite::micro::GetTensorShape(filter), + tflite::micro::GetTensorData(filter), + 
tflite::micro::GetTensorShape(bias), + tflite::micro::GetTensorData(bias), + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output)); + return kTfLiteOk; +} + +TfLiteStatus EvalQuantizedFullyConnectedReference( + TfLiteContext* context, TfLiteNode* node, + const OpDataFullyConnectedReference& data, const TfLiteEvalTensor* input, + const TfLiteEvalTensor* filter, const TfLiteEvalTensor* bias, + TfLiteEvalTensor* output) { +#define TF_LITE_FULLY_CONNECTED(output_data_type) \ + reference_ops::FullyConnected( \ + data.ToQuantizedParams(), tflite::micro::GetTensorShape(input), \ + tflite::micro::GetTensorData(input), \ + tflite::micro::GetTensorShape(filter), \ + tflite::micro::GetTensorData(filter), \ + tflite::micro::GetTensorShape(bias), \ + tflite::micro::GetTensorData(bias), \ + tflite::micro::GetTensorShape(output), \ + tflite::micro::GetTensorData(output)) + switch (output->type) { + case kTfLiteUInt8: + TF_LITE_FULLY_CONNECTED(uint8_t); + break; + case kTfLiteInt16: + TF_LITE_FULLY_CONNECTED(int16_t); + break; + default: + TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.", + TfLiteTypeGetName(output->type), output->type); + return kTfLiteError; + } + + return kTfLiteOk; +} + +#endif // !defined(XTENSA) + +} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/xtensa/fully_connected.cc b/tensorflow/lite/micro/kernels/xtensa/fully_connected.cc index a169343ace7..ed4737bf411 100644 --- a/tensorflow/lite/micro/kernels/xtensa/fully_connected.cc +++ b/tensorflow/lite/micro/kernels/xtensa/fully_connected.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/micro/kernels/fully_connected.h" #include "tensorflow/lite/micro/kernels/kernel_util.h" #include "tensorflow/lite/micro/kernels/xtensa/fixedpoint_utils.h" #include "tensorflow/lite/micro/kernels/xtensa/xtensa.h" @@ -29,30 +30,6 @@ limitations under the License. namespace tflite { namespace { -struct OpData { - // The scaling factor from input to output (aka the 'real multiplier') can - // be represented as a fixed point multiplier plus a left shift. - int32_t output_multiplier; - int output_shift; - - // Cached tensor zero point values for quantized operations. - int32_t input_zero_point; - int32_t filter_zero_point; - int32_t output_zero_point; - - // The range of the fused activation layer. For example for kNone and - // uint8_t these would be 0 and 255. - int32_t output_activation_min; - int32_t output_activation_max; - // The index of the temporary tensor where the quantized inputs are cached. 
- int input_quantized_index; -}; - -constexpr int kInputTensor = 0; -constexpr int kWeightsTensor = 1; -constexpr int kBiasTensor = 2; -constexpr int kOutputTensor = 0; - #if defined(HIFIMINI) void FullyConnected(const FullyConnectedParams& params, const RuntimeShape& input_shape, const int8_t* input_data, @@ -144,7 +121,7 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteType data_type, const TfLiteTensor* input, const TfLiteTensor* filter, const TfLiteTensor* bias, TfLiteTensor* output, - OpData* data) { + OpDataFullyConnectedReference* data) { double real_multiplier = 0.0; TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler( context, input, filter, bias, output, &real_multiplier)); @@ -155,28 +132,30 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, QuantizeMultiplier(real_multiplier, &data->output_multiplier, &data->output_shift); #endif + data->input_zero_point = input->params.zero_point; + data->filter_zero_point = filter->params.zero_point; + data->output_zero_point = output->params.zero_point; + return CalculateActivationRangeQuantized(context, activation, output, &data->output_activation_min, &data->output_activation_max); } -void* Init(TfLiteContext* context, const char* buffer, size_t length) { - TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr); - return context->AllocatePersistentBuffer(context, sizeof(OpData)); -} - TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TFLITE_DCHECK(node->user_data != nullptr); TFLITE_DCHECK(node->builtin_data != nullptr); - OpData* data = static_cast(node->user_data); + auto* data = static_cast(node->user_data); const auto* params = reinterpret_cast(node->builtin_data); - const TfLiteTensor* input = GetInput(context, node, kInputTensor); - const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor); - const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); - TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + const TfLiteTensor* input = + GetInput(context, node, kFullyConnectedInputTensor); + const TfLiteTensor* filter = + GetInput(context, node, kFullyConnectedWeightsTensor); + const TfLiteTensor* bias = + GetOptionalInputTensor(context, node, kFullyConnectedBiasTensor); + TfLiteTensor* output = GetOutput(context, node, kFullyConnectedOutputTensor); if (input->type != kTfLiteInt8) { TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.", @@ -184,36 +163,25 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { return kTfLiteError; } - data->input_zero_point = input->params.zero_point; - data->filter_zero_point = filter->params.zero_point; - data->output_zero_point = output->params.zero_point; - return CalculateOpData(context, params->activation, input->type, input, filter, bias, output, data); } TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node, - const OpData& data, + const OpDataFullyConnectedReference& data, const TfLiteEvalTensor* input, const TfLiteEvalTensor* filter, const TfLiteEvalTensor* bias, TfLiteEvalTensor* output) { - // TODO(b/154032858): Investigate removing extra copies, and also passing by - // value. TODO(b/155656675): Consider passing OpData by value once it is also - // passed to the FullyConnected function. Until it is copied to a local - // op_param variable, we do not get any latency improvements from passing by - // value. 
- FullyConnectedParams op_params; - op_params.input_offset = -data.input_zero_point; - op_params.weights_offset = -data.filter_zero_point; - op_params.output_offset = data.output_zero_point; - op_params.output_multiplier = data.output_multiplier; - op_params.output_shift = data.output_shift; - op_params.quantized_activation_min = data.output_activation_min; - op_params.quantized_activation_max = data.output_activation_max; - + // TODO(b/154032858): Investigate removing extra copies (i.e. + // data.ToQuantizedParams), and also passing by value. + // + // TODO(b/155656675): Consider passing OpDataFullyConnectedReference by value + // once it is also passed to the FullyConnected function. Until it is copied + // to a local op_param variable, we do not get any latency improvements from + // passing by value. #if defined(HIFIMINI) - FullyConnected(op_params, tflite::micro::GetTensorShape(input), + FullyConnected(data.ToQuantizedParams(), tflite::micro::GetTensorShape(input), tflite::micro::GetTensorData(input), tflite::micro::GetTensorShape(filter), tflite::micro::GetTensorData(filter), @@ -223,7 +191,7 @@ TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node, tflite::micro::GetTensorData(output)); #else reference_integer_ops::FullyConnected( - op_params, tflite::micro::GetTensorShape(input), + data.ToQuantizedParams(), tflite::micro::GetTensorShape(input), tflite::micro::GetTensorData(input), tflite::micro::GetTensorShape(filter), tflite::micro::GetTensorData(filter), @@ -238,19 +206,20 @@ TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node, TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TFLITE_DCHECK(node->user_data != nullptr); - const OpData& data = *(static_cast(node->user_data)); + const auto& data = + *(static_cast(node->user_data)); const TfLiteEvalTensor* input = - tflite::micro::GetEvalInput(context, node, kInputTensor); + tflite::micro::GetEvalInput(context, node, kFullyConnectedInputTensor); const TfLiteEvalTensor* filter = - tflite::micro::GetEvalInput(context, node, kWeightsTensor); + tflite::micro::GetEvalInput(context, node, kFullyConnectedWeightsTensor); const TfLiteEvalTensor* bias = - (NumInputs(node) == 3) - ? tflite::micro::GetEvalInput(context, node, kBiasTensor) - : nullptr; + (NumInputs(node) == 3) ? 
tflite::micro::GetEvalInput( + context, node, kFullyConnectedBiasTensor) + : nullptr; TfLiteEvalTensor* output = - tflite::micro::GetEvalOutput(context, node, kOutputTensor); + tflite::micro::GetEvalOutput(context, node, kFullyConnectedOutputTensor); return EvalQuantizedInt8(context, node, data, input, filter, bias, output); } @@ -258,7 +227,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace TfLiteRegistration Register_FULLY_CONNECTED() { - return {/*init=*/Init, + return {/*init=*/InitFullyConnectedReference, /*free=*/nullptr, /*prepare=*/Prepare, /*invoke=*/Eval, diff --git a/tensorflow/lite/micro/tools/make/Makefile b/tensorflow/lite/micro/tools/make/Makefile index a992e2f4fa2..7de5122e31b 100644 --- a/tensorflow/lite/micro/tools/make/Makefile +++ b/tensorflow/lite/micro/tools/make/Makefile @@ -328,6 +328,7 @@ tensorflow/lite/micro/kernels/ethosu.cc \ tensorflow/lite/micro/kernels/flexbuffers_generated_data.cc \ tensorflow/lite/micro/kernels/floor.cc \ tensorflow/lite/micro/kernels/fully_connected.cc \ +tensorflow/lite/micro/kernels/fully_connected_common.cc \ tensorflow/lite/micro/kernels/hard_swish.cc \ tensorflow/lite/micro/kernels/kernel_runner.cc \ tensorflow/lite/micro/kernels/kernel_util.cc \ From bbd388b7c327e44c4bad48c89bc2194ae1824b6c Mon Sep 17 00:00:00 2001 From: Advait Jain Date: Thu, 7 Jan 2021 13:04:57 -0800 Subject: [PATCH 2/3] Updates based on discussion with Nat. --- tensorflow/lite/micro/kernels/BUILD | 4 +- .../micro/kernels/cmsis-nn/fully_connected.cc | 64 ++++++++--- .../lite/micro/kernels/fully_connected.cc | 72 +++++++++--- .../lite/micro/kernels/fully_connected.h | 51 +++------ .../micro/kernels/fully_connected_common.cc | 105 ++++-------------- .../micro/kernels/xtensa/fully_connected.cc | 27 +++-- 6 files changed, 154 insertions(+), 169 deletions(-) diff --git a/tensorflow/lite/micro/kernels/BUILD b/tensorflow/lite/micro/kernels/BUILD index 965578da0ed..911d2971a7f 100644 --- a/tensorflow/lite/micro/kernels/BUILD +++ b/tensorflow/lite/micro/kernels/BUILD @@ -47,7 +47,9 @@ cc_library( cc_library( name = "fully_connected", - srcs = select({ + srcs = [ + "fully_connected_common.cc", + ] + select({ "//conditions:default": [ "fully_connected.cc", ], diff --git a/tensorflow/lite/micro/kernels/cmsis-nn/fully_connected.cc b/tensorflow/lite/micro/kernels/cmsis-nn/fully_connected.cc index 6274c152543..3db06aa35df 100644 --- a/tensorflow/lite/micro/kernels/cmsis-nn/fully_connected.cc +++ b/tensorflow/lite/micro/kernels/cmsis-nn/fully_connected.cc @@ -13,24 +13,24 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/lite/kernels/internal/reference/fully_connected.h" +#include "tensorflow/lite/micro/kernels/fully_connected.h" #include "CMSIS/NN/Include/arm_nnfunctions.h" #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/kernels/internal/common.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/reference/fully_connected.h" #include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/kernel_util.h" -#include "tensorflow/lite/micro/kernels/fully_connected.h" #include "tensorflow/lite/micro/kernels/kernel_util.h" namespace tflite { namespace { struct OpData { - OpDataFullyConnectedReference reference_op_data; + OpDataFullyConnected reference_op_data; // Index to buffer for optimizations if applicable. int buffer_idx; @@ -49,9 +49,9 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, OpData* data) { // Set buffer index to a reset value data->buffer_idx = -1; - return CalculateOpDataFullyConnectedReference(context, activation, data_type, - input, filter, bias, output, - &(data->reference_op_data)); + return CalculateOpDataFullyConnected(context, activation, data_type, input, + filter, bias, output, + &(data->reference_op_data)); } void* Init(TfLiteContext* context, const char* buffer, size_t length) { @@ -177,8 +177,16 @@ TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node, tflite::micro::GetTensorData(output)), ARM_MATH_SUCCESS); } else { - return EvalQuantizedInt8FullyConnectedReference( - context, node, data.reference_op_data, input, filter, bias, output); + tflite::reference_integer_ops::FullyConnected( + FullyConnectedParamsQuantized(data.reference_op_data), + tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData(input), + tflite::micro::GetTensorShape(filter), + tflite::micro::GetTensorData(filter), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetTensorData(bias), + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output)); } return kTfLiteOk; } @@ -202,21 +210,41 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { // Checks in Prepare ensure input, output and filter types are all the same. 
switch (input->type) { - case kTfLiteFloat32: - return EvalFloatFullyConnectedReference(context, node, params->activation, - input, filter, bias, output); - case kTfLiteInt8: + case kTfLiteFloat32: { + tflite::reference_ops::FullyConnected( + FullyConnectedParamsFloat(params->activation), + tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData(input), + tflite::micro::GetTensorShape(filter), + tflite::micro::GetTensorData(filter), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetTensorData(bias), + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output)); + break; + } + case kTfLiteInt8: { return EvalQuantizedInt8(context, node, data, input, filter, bias, output); - - case kTfLiteUInt8: - return EvalQuantizedFullyConnectedReference( - context, node, data.reference_op_data, input, filter, bias, output); - - default: + } + case kTfLiteUInt8: { + tflite::reference_ops::FullyConnected( + FullyConnectedParamsQuantized(data.reference_op_data), + tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData(input), + tflite::micro::GetTensorShape(filter), + tflite::micro::GetTensorData(filter), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetTensorData(bias), + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output)); + break; + } + default: { TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.", TfLiteTypeGetName(input->type), input->type); return kTfLiteError; + } } return kTfLiteOk; } diff --git a/tensorflow/lite/micro/kernels/fully_connected.cc b/tensorflow/lite/micro/kernels/fully_connected.cc index a06ba21e427..28fbd4860fb 100644 --- a/tensorflow/lite/micro/kernels/fully_connected.cc +++ b/tensorflow/lite/micro/kernels/fully_connected.cc @@ -13,26 +13,32 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/lite/kernels/internal/reference/fully_connected.h" +#include "tensorflow/lite/micro/kernels/fully_connected.h" #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/kernels/internal/common.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/reference/fully_connected.h" #include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/kernel_util.h" -#include "tensorflow/lite/micro/kernels/fully_connected.h" #include "tensorflow/lite/micro/kernels/kernel_util.h" namespace tflite { namespace { +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr); + return context->AllocatePersistentBuffer(context, + sizeof(OpDataFullyConnected)); +} + TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TFLITE_DCHECK(node->user_data != nullptr); TFLITE_DCHECK(node->builtin_data != nullptr); - auto* data = static_cast(node->user_data); + auto* data = static_cast(node->user_data); const auto params = static_cast(node->builtin_data); @@ -51,9 +57,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_MSG(context, input->type == filter->type, "Hybrid models are not supported on TFLite Micro."); - return CalculateOpDataFullyConnectedReference(context, params->activation, - input->type, input, filter, - bias, output, data); + return CalculateOpDataFullyConnected(context, params->activation, input->type, + input, filter, bias, output, data); } TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { @@ -72,25 +77,56 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TFLITE_DCHECK(node->user_data != nullptr); const auto& data = - *(static_cast(node->user_data)); + *(static_cast(node->user_data)); // Checks in Prepare ensure input, output and filter types are all the same. 
switch (input->type) { - case kTfLiteFloat32: - return EvalFloatFullyConnectedReference(context, node, params->activation, - input, filter, bias, output); - case kTfLiteInt8: - return EvalQuantizedInt8FullyConnectedReference( - context, node, data, input, filter, bias, output); + case kTfLiteFloat32: { + tflite::reference_ops::FullyConnected( + FullyConnectedParamsFloat(params->activation), + tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData(input), + tflite::micro::GetTensorShape(filter), + tflite::micro::GetTensorData(filter), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetTensorData(bias), + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output)); + break; + } - case kTfLiteUInt8: - return EvalQuantizedFullyConnectedReference(context, node, data, input, - filter, bias, output); + case kTfLiteInt8: { + tflite::reference_integer_ops::FullyConnected( + FullyConnectedParamsQuantized(data), + tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData(input), + tflite::micro::GetTensorShape(filter), + tflite::micro::GetTensorData(filter), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetTensorData(bias), + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output)); + break; + } - default: + case kTfLiteUInt8: { + tflite::reference_ops::FullyConnected( + FullyConnectedParamsQuantized(data), + tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData(input), + tflite::micro::GetTensorShape(filter), + tflite::micro::GetTensorData(filter), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetTensorData(bias), + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output)); + break; + } + default: { TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.", TfLiteTypeGetName(input->type), input->type); return kTfLiteError; + } } return kTfLiteOk; } @@ -98,7 +134,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace TfLiteRegistration Register_FULLY_CONNECTED() { - return {/*init=*/InitFullyConnectedReference, + return {/*init=*/Init, /*free=*/nullptr, /*prepare=*/Prepare, /*invoke=*/Eval, diff --git a/tensorflow/lite/micro/kernels/fully_connected.h b/tensorflow/lite/micro/kernels/fully_connected.h index 38704d908f5..d5e22e51377 100644 --- a/tensorflow/lite/micro/kernels/fully_connected.h +++ b/tensorflow/lite/micro/kernels/fully_connected.h @@ -23,7 +23,7 @@ limitations under the License. namespace tflite { -struct OpDataFullyConnectedReference { +struct OpDataFullyConnected { // The scaling factor from input to output (aka the 'real multiplier') can // be represented as a fixed point multiplier plus a left shift. int32_t output_multiplier; @@ -38,20 +38,6 @@ struct OpDataFullyConnectedReference { int32_t input_zero_point; int32_t filter_zero_point; int32_t output_zero_point; - - // Returns a FullyConnectedParams struct with all the parameters needed for a - // quantized fully connected computation. 
-  FullyConnectedParams ToQuantizedParams() const {
-    FullyConnectedParams op_params;
-    op_params.input_offset = -input_zero_point;
-    op_params.weights_offset = -filter_zero_point;
-    op_params.output_offset = output_zero_point;
-    op_params.output_multiplier = output_multiplier;
-    op_params.output_shift = output_shift;
-    op_params.quantized_activation_min = output_activation_min;
-    op_params.quantized_activation_max = output_activation_max;
-    return op_params;
-  }
 };
 
 extern const int kFullyConnectedInputTensor;
@@ -59,31 +45,20 @@ extern const int kFullyConnectedWeightsTensor;
 extern const int kFullyConnectedBiasTensor;
 extern const int kFullyConnectedOutputTensor;
 
-TfLiteStatus CalculateOpDataFullyConnectedReference(
+// Returns a FullyConnectedParams struct with all the parameters needed for a
+// float computation.
+FullyConnectedParams FullyConnectedParamsFloat(
+    TfLiteFusedActivation activation);
+
+// Returns a FullyConnectedParams struct with all the parameters needed for a
+// quantized computation.
+FullyConnectedParams FullyConnectedParamsQuantized(
+    const OpDataFullyConnected& op_data);
+
+TfLiteStatus CalculateOpDataFullyConnected(
     TfLiteContext* context, TfLiteFusedActivation activation,
     TfLiteType data_type, const TfLiteTensor* input, const TfLiteTensor* filter,
-    const TfLiteTensor* bias, TfLiteTensor* output,
-    OpDataFullyConnectedReference* data);
-
-void* InitFullyConnectedReference(TfLiteContext* context, const char* buffer,
-                                  size_t length);
-
-TfLiteStatus EvalFloatFullyConnectedReference(
-    TfLiteContext* context, TfLiteNode* node, TfLiteFusedActivation activation,
-    const TfLiteEvalTensor* input, const TfLiteEvalTensor* filter,
-    const TfLiteEvalTensor* bias, TfLiteEvalTensor* output);
-
-TfLiteStatus EvalQuantizedInt8FullyConnectedReference(
-    TfLiteContext* context, TfLiteNode* node,
-    const OpDataFullyConnectedReference& data, const TfLiteEvalTensor* input,
-    const TfLiteEvalTensor* filter, const TfLiteEvalTensor* bias,
-    TfLiteEvalTensor* output);
-
-TfLiteStatus EvalQuantizedFullyConnectedReference(
-    TfLiteContext* context, TfLiteNode* node,
-    const OpDataFullyConnectedReference& data, const TfLiteEvalTensor* input,
-    const TfLiteEvalTensor* filter, const TfLiteEvalTensor* bias,
-    TfLiteEvalTensor* output);
+    const TfLiteTensor* bias, TfLiteTensor* output, OpDataFullyConnected* data);
 
 // This is the most generic TfLiteRegistration. The actual supported types may
 // still be target dependent. The only requirement is that every implementation
diff --git a/tensorflow/lite/micro/kernels/fully_connected_common.cc b/tensorflow/lite/micro/kernels/fully_connected_common.cc
index 8c550186665..64046a9cec3 100644
--- a/tensorflow/lite/micro/kernels/fully_connected_common.cc
+++ b/tensorflow/lite/micro/kernels/fully_connected_common.cc
@@ -25,35 +25,38 @@ limitations under the License.
 #include "tensorflow/lite/micro/kernels/kernel_util.h"
 
 namespace tflite {
-namespace {
-#if !defined(XTENSA)
-FullyConnectedParams ToFloatParams(TfLiteFusedActivation activation) {
-  FullyConnectedParams op_params;
-  CalculateActivationRange(activation, &op_params.float_activation_min,
-                           &op_params.float_activation_max);
-  return op_params;
-}
-#endif  // !defined(XTENSA)
-}  // namespace
 
 const int kFullyConnectedInputTensor = 0;
 const int kFullyConnectedWeightsTensor = 1;
 const int kFullyConnectedBiasTensor = 2;
 const int kFullyConnectedOutputTensor = 0;
 
-void* InitFullyConnectedReference(TfLiteContext* context, const char* buffer,
-                                  size_t length) {
-  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
-  return context->AllocatePersistentBuffer(
-      context, sizeof(OpDataFullyConnectedReference));
+FullyConnectedParams FullyConnectedParamsQuantized(
+    const OpDataFullyConnected& op_data) {
+  FullyConnectedParams op_params;
+  op_params.input_offset = -op_data.input_zero_point;
+  op_params.weights_offset = -op_data.filter_zero_point;
+  op_params.output_offset = op_data.output_zero_point;
+  op_params.output_multiplier = op_data.output_multiplier;
+  op_params.output_shift = op_data.output_shift;
+  op_params.quantized_activation_min = op_data.output_activation_min;
+  op_params.quantized_activation_max = op_data.output_activation_max;
+  return op_params;
 }
 
-#if !defined(XTENSA)
-TfLiteStatus CalculateOpDataFullyConnectedReference(
+FullyConnectedParams FullyConnectedParamsFloat(
+    TfLiteFusedActivation activation) {
+  FullyConnectedParams op_params;
+  CalculateActivationRange(activation, &op_params.float_activation_min,
+                           &op_params.float_activation_max);
+  return op_params;
+}
+
+TfLiteStatus CalculateOpDataFullyConnected(
     TfLiteContext* context, TfLiteFusedActivation activation,
     TfLiteType data_type, const TfLiteTensor* input,
     const TfLiteTensor* filter, const TfLiteTensor* bias, TfLiteTensor* output,
-    OpDataFullyConnectedReference* data) {
+    OpDataFullyConnected* data) {
   if (data_type != kTfLiteFloat32) {
     double real_multiplier = 0.0;
     TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
@@ -72,70 +75,4 @@ TfLiteStatus CalculateOpDataFullyConnectedReference(
   return kTfLiteOk;
 }
 
-TfLiteStatus EvalFloatFullyConnectedReference(
-    TfLiteContext* context, TfLiteNode* node, TfLiteFusedActivation activation,
-    const TfLiteEvalTensor* input, const TfLiteEvalTensor* filter,
-    const TfLiteEvalTensor* bias, TfLiteEvalTensor* output) {
-  tflite::reference_ops::FullyConnected(
-      ToFloatParams(activation), tflite::micro::GetTensorShape(input),
-      tflite::micro::GetTensorData<float>(input),
-      tflite::micro::GetTensorShape(filter),
-      tflite::micro::GetTensorData<float>(filter),
-      tflite::micro::GetTensorShape(bias),
-      tflite::micro::GetTensorData<float>(bias),
-      tflite::micro::GetTensorShape(output),
-      tflite::micro::GetTensorData<float>(output));
-  return kTfLiteOk;
-}
-
-TfLiteStatus EvalQuantizedInt8FullyConnectedReference(
-    TfLiteContext* context, TfLiteNode* node,
-    const OpDataFullyConnectedReference& data, const TfLiteEvalTensor* input,
-    const TfLiteEvalTensor* filter, const TfLiteEvalTensor* bias,
-    TfLiteEvalTensor* output) {
-  reference_integer_ops::FullyConnected(
-      data.ToQuantizedParams(), tflite::micro::GetTensorShape(input),
-      tflite::micro::GetTensorData<int8_t>(input),
-      tflite::micro::GetTensorShape(filter),
-      tflite::micro::GetTensorData<int8_t>(filter),
-      tflite::micro::GetTensorShape(bias),
-      tflite::micro::GetTensorData<int32_t>(bias),
-      tflite::micro::GetTensorShape(output),
-      tflite::micro::GetTensorData<int8_t>(output));
-  return kTfLiteOk;
-}
-
-TfLiteStatus EvalQuantizedFullyConnectedReference(
-    TfLiteContext* context, TfLiteNode* node,
-    const OpDataFullyConnectedReference& data, const TfLiteEvalTensor* input,
-    const TfLiteEvalTensor* filter, const TfLiteEvalTensor* bias,
-    TfLiteEvalTensor* output) {
-#define TF_LITE_FULLY_CONNECTED(output_data_type)                      \
-  reference_ops::FullyConnected(                                       \
-      data.ToQuantizedParams(), tflite::micro::GetTensorShape(input),  \
-      tflite::micro::GetTensorData<uint8_t>(input),                    \
-      tflite::micro::GetTensorShape(filter),                           \
-      tflite::micro::GetTensorData<uint8_t>(filter),                   \
-      tflite::micro::GetTensorShape(bias),                             \
-      tflite::micro::GetTensorData<int32_t>(bias),                     \
-      tflite::micro::GetTensorShape(output),                           \
-      tflite::micro::GetTensorData<output_data_type>(output))
-  switch (output->type) {
-    case kTfLiteUInt8:
-      TF_LITE_FULLY_CONNECTED(uint8_t);
-      break;
-    case kTfLiteInt16:
-      TF_LITE_FULLY_CONNECTED(int16_t);
-      break;
-    default:
-      TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
-                         TfLiteTypeGetName(output->type), output->type);
-      return kTfLiteError;
-  }
-
-  return kTfLiteOk;
-}
-
-#endif  // !defined(XTENSA)
-
 }  // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/xtensa/fully_connected.cc b/tensorflow/lite/micro/kernels/xtensa/fully_connected.cc
index ed4737bf411..5a2520ccec6 100644
--- a/tensorflow/lite/micro/kernels/xtensa/fully_connected.cc
+++ b/tensorflow/lite/micro/kernels/xtensa/fully_connected.cc
@@ -13,16 +13,16 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
+#include "tensorflow/lite/micro/kernels/fully_connected.h"
 
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/kernels/internal/common.h"
 #include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
 #include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h"
 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
-#include "tensorflow/lite/micro/kernels/fully_connected.h"
 #include "tensorflow/lite/micro/kernels/kernel_util.h"
 #include "tensorflow/lite/micro/kernels/xtensa/fixedpoint_utils.h"
 #include "tensorflow/lite/micro/kernels/xtensa/xtensa.h"
@@ -121,7 +121,7 @@ TfLiteStatus CalculateOpData(TfLiteContext* context,
                              TfLiteType data_type, const TfLiteTensor* input,
                              const TfLiteTensor* filter,
                              const TfLiteTensor* bias, TfLiteTensor* output,
-                             OpDataFullyConnectedReference* data) {
+                             OpDataFullyConnected* data) {
   double real_multiplier = 0.0;
   TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
       context, input, filter, bias, output, &real_multiplier));
@@ -141,11 +141,17 @@ TfLiteStatus CalculateOpData(TfLiteContext* context,
                                            &data->output_activation_max);
 }
 
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
+  return context->AllocatePersistentBuffer(context,
+                                           sizeof(OpDataFullyConnected));
+}
+
 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TFLITE_DCHECK(node->user_data != nullptr);
   TFLITE_DCHECK(node->builtin_data != nullptr);
 
-  auto* data = static_cast<OpDataFullyConnectedReference*>(node->user_data);
+  auto* data = static_cast<OpDataFullyConnected*>(node->user_data);
 
   const auto* params =
       reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
@@ -168,7 +174,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 }
 
 TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node,
-                               const OpDataFullyConnectedReference& data,
+                               const OpDataFullyConnected& data,
                                const TfLiteEvalTensor* input,
                                const TfLiteEvalTensor* filter,
                                const TfLiteEvalTensor* bias,
@@ -176,12 +182,13 @@ TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node,
   // TODO(b/154032858): Investigate removing extra copies (i.e.
   // data.ToQuantizedParams), and also passing by value.
   //
-  // TODO(b/155656675): Consider passing OpDataFullyConnectedReference by value
+  // TODO(b/155656675): Consider passing OpDataFullyConnected by value
   // once it is also passed to the FullyConnected function. Until it is copied
   // to a local op_param variable, we do not get any latency improvements from
   // passing by value.
 #if defined(HIFIMINI)
-  FullyConnected(data.ToQuantizedParams(), tflite::micro::GetTensorShape(input),
+  FullyConnected(FullyConnectedParamsQuantized(data),
+                 tflite::micro::GetTensorShape(input),
                  tflite::micro::GetTensorData<int8_t>(input),
                  tflite::micro::GetTensorShape(filter),
                  tflite::micro::GetTensorData<int8_t>(filter),
@@ -191,7 +198,7 @@ TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node,
                  tflite::micro::GetTensorData<int8_t>(output));
 #else
   reference_integer_ops::FullyConnected(
-      data.ToQuantizedParams(), tflite::micro::GetTensorShape(input),
+      FullyConnectedParamsQuantized(data), tflite::micro::GetTensorShape(input),
       tflite::micro::GetTensorData<int8_t>(input),
       tflite::micro::GetTensorShape(filter),
       tflite::micro::GetTensorData<int8_t>(filter),
@@ -207,7 +214,7 @@ TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node,
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   TFLITE_DCHECK(node->user_data != nullptr);
   const auto& data =
-      *(static_cast<const OpDataFullyConnectedReference*>(node->user_data));
+      *(static_cast<const OpDataFullyConnected*>(node->user_data));
 
   const TfLiteEvalTensor* input =
       tflite::micro::GetEvalInput(context, node, kFullyConnectedInputTensor);
@@ -227,7 +234,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace
 
 TfLiteRegistration Register_FULLY_CONNECTED() {
-  return {/*init=*/InitFullyConnectedReference,
+  return {/*init=*/Init,
           /*free=*/nullptr,
          /*prepare=*/Prepare,
          /*invoke=*/Eval,

From c8bc2530ebaad3792e4abee8355db8f3cc45ad8d Mon Sep 17 00:00:00 2001
From: Advait Jain
Date: Mon, 11 Jan 2021 20:51:50 -0800
Subject: [PATCH 3/3] address review comments.

---
 .../micro/kernels/cmsis-nn/fully_connected.cc | 22 +++++--------------
 1 file changed, 6 insertions(+), 16 deletions(-)

diff --git a/tensorflow/lite/micro/kernels/cmsis-nn/fully_connected.cc b/tensorflow/lite/micro/kernels/cmsis-nn/fully_connected.cc
index 3db06aa35df..6e2a2980952 100644
--- a/tensorflow/lite/micro/kernels/cmsis-nn/fully_connected.cc
+++ b/tensorflow/lite/micro/kernels/cmsis-nn/fully_connected.cc
@@ -41,19 +41,6 @@ struct OpData {
 // Register_FULLY_CONNECTED).
 TfLiteRegistration fully_connected_registration;
 
-TfLiteStatus CalculateOpData(TfLiteContext* context,
-                             TfLiteFusedActivation activation,
-                             TfLiteType data_type, const TfLiteTensor* input,
-                             const TfLiteTensor* filter,
-                             const TfLiteTensor* bias, TfLiteTensor* output,
-                             OpData* data) {
-  // Set buffer index to a reset value
-  data->buffer_idx = -1;
-  return CalculateOpDataFullyConnected(context, activation, data_type, input,
-                                       filter, bias, output,
-                                       &(data->reference_op_data));
-}
-
 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
   TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
   return context->AllocatePersistentBuffer(context, sizeof(OpData));
@@ -77,9 +64,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
   TF_LITE_ENSURE_MSG(context, input->type == filter->type,
                      "Hybrid models are not supported on TFLite Micro.");
-  TF_LITE_ENSURE_STATUS(CalculateOpData(context, params->activation,
-                                        input->type, input, filter, bias,
-                                        output, data));
+
+  // Set buffer index to a reset value
+  data->buffer_idx = -1;
+  TF_LITE_ENSURE_STATUS(CalculateOpDataFullyConnected(
+      context, params->activation, input->type, input, filter, bias, output,
+      &(data->reference_op_data)));
 
   if (input->type == kTfLiteInt8 && nullptr != GetTensorData<void>(bias)) {
     RuntimeShape filter_shape = GetTensorShape(filter);