diff --git a/tensorflow/lite/micro/kernels/BUILD b/tensorflow/lite/micro/kernels/BUILD
index a35732ab316..856f9ecfe32 100644
--- a/tensorflow/lite/micro/kernels/BUILD
+++ b/tensorflow/lite/micro/kernels/BUILD
@@ -50,7 +50,9 @@ cc_library(
 cc_library(
     name = "fully_connected",
-    srcs = select({
+    srcs = [
+        "fully_connected_common.cc",
+    ] + select({
        "//conditions:default": [
            "fully_connected.cc",
        ],
diff --git a/tensorflow/lite/micro/kernels/cmsis-nn/fully_connected.cc b/tensorflow/lite/micro/kernels/cmsis-nn/fully_connected.cc
index 11a0f0bdc23..6e2a2980952 100644
--- a/tensorflow/lite/micro/kernels/cmsis-nn/fully_connected.cc
+++ b/tensorflow/lite/micro/kernels/cmsis-nn/fully_connected.cc
@@ -13,78 +13,34 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
+#include "tensorflow/lite/micro/kernels/fully_connected.h"
 
 #include "CMSIS/NN/Include/arm_nnfunctions.h"
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/kernels/internal/common.h"
 #include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
 #include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h"
 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
-#include "tensorflow/lite/micro/kernels/fully_connected.h"
 #include "tensorflow/lite/micro/kernels/kernel_util.h"
 
 namespace tflite {
 namespace {
 
 struct OpData {
-  // The scaling factor from input to output (aka the 'real multiplier') can
-  // be represented as a fixed point multiplier plus a left shift.
-  int32_t output_multiplier;
-  int output_shift;
-  // The range of the fused activation layer. For example for kNone and
-  // uint8_t these would be 0 and 255.
-  int32_t output_activation_min;
-  int32_t output_activation_max;
-  // The index of the temporary tensor where the quantized inputs are cached.
-  int input_quantized_index;
+  OpDataFullyConnected reference_op_data;
+  // Index to buffer for optimizations if applicable.
   int buffer_idx;
-
-  // Cached tensor zero point values for quantized operations.
-  int32_t input_zero_point;
-  int32_t filter_zero_point;
-  int32_t output_zero_point;
 };
 
-constexpr int kInputTensor = 0;
-constexpr int kWeightsTensor = 1;
-constexpr int kBiasTensor = 2;
-constexpr int kOutputTensor = 0;
-
 // TODO(b/169801227): This global struct is needed for the linker to drop unused
 // code (for example, by using Register_FULLY_CONNECTED_INT8 instead of
 // Register_FULLY_CONNECTED).
 TfLiteRegistration fully_connected_registration;
 
-TfLiteStatus CalculateOpData(TfLiteContext* context,
-                             TfLiteFusedActivation activation,
-                             TfLiteType data_type, const TfLiteTensor* input,
-                             const TfLiteTensor* filter,
-                             const TfLiteTensor* bias, TfLiteTensor* output,
-                             OpData* data) {
-  TfLiteStatus status = kTfLiteOk;
-  // Set buffer index to a reset value
-  data->buffer_idx = -1;
-  if (data_type != kTfLiteFloat32) {
-    double real_multiplier = 0.0;
-    TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
-        context, input, filter, bias, output, &real_multiplier));
-    int exponent;
-    QuantizeMultiplier(real_multiplier, &data->output_multiplier, &exponent);
-    data->output_shift = -exponent;
-    TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized(
-        context, activation, output, &data->output_activation_min,
-        &data->output_activation_max));
-    data->input_zero_point = input->params.zero_point;
-    data->filter_zero_point = filter->params.zero_point;
-    data->output_zero_point = output->params.zero_point;
-  }
-  return status;
-}
-
 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
   TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
   return context->AllocatePersistentBuffer(context, sizeof(OpData));
@@ -98,16 +54,22 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   const auto params =
       static_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
 
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
-  const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input =
+      GetInput(context, node, kFullyConnectedInputTensor);
+  const TfLiteTensor* filter =
+      GetInput(context, node, kFullyConnectedWeightsTensor);
+  const TfLiteTensor* bias =
+      GetOptionalInputTensor(context, node, kFullyConnectedBiasTensor);
+  TfLiteTensor* output = GetOutput(context, node, kFullyConnectedOutputTensor);
 
   TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
   TF_LITE_ENSURE_MSG(context, input->type == filter->type,
                      "Hybrid models are not supported on TFLite Micro.");
-  TF_LITE_ENSURE_STATUS(CalculateOpData(context, params->activation,
-                                        input->type, input, filter, bias,
-                                        output, data));
+
+  // Set buffer index to a reset value
+  data->buffer_idx = -1;
+  TF_LITE_ENSURE_STATUS(CalculateOpDataFullyConnected(
+      context, params->activation, input->type, input, filter, bias, output,
+      &(data->reference_op_data)));
 
   if (input->type == kTfLiteInt8 && nullptr != GetTensorData<int32_t>(bias)) {
     RuntimeShape filter_shape = GetTensorShape(filter);
@@ -153,16 +115,15 @@ TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node,
   const RuntimeShape input_shape = tflite::micro::GetTensorShape(input);
 
   cmsis_nn_fc_params fc_params;
-  fc_params.input_offset = -data.input_zero_point;
-  fc_params.output_offset = data.output_zero_point;
-  fc_params.filter_offset = -data.filter_zero_point;
-  fc_params.activation.min = data.output_activation_min;
-  fc_params.activation.max = data.output_activation_max;
+  fc_params.input_offset = -data.reference_op_data.input_zero_point;
+  fc_params.output_offset = data.reference_op_data.output_zero_point;
+  fc_params.filter_offset = -data.reference_op_data.filter_zero_point;
+  fc_params.activation.min = data.reference_op_data.output_activation_min;
+  fc_params.activation.max = data.reference_op_data.output_activation_max;
 
   cmsis_nn_per_tensor_quant_params quant_params;
-  quant_params.multiplier = data.output_multiplier;
-  // TODO(b/138810107): Figure out whether output shift should be inverted
-  quant_params.shift = -data.output_shift;
+  quant_params.multiplier = data.reference_op_data.output_multiplier;
+  quant_params.shift = data.reference_op_data.output_shift;
 
   cmsis_nn_dims input_dims;
   input_dims.n = batches;
@@ -206,18 +167,9 @@ TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node,
                            tflite::micro::GetTensorData<int8_t>(output)),
                        ARM_MATH_SUCCESS);
   } else {
-    tflite::FullyConnectedParams op_params;
-    op_params.input_offset = -data.input_zero_point;
-    op_params.weights_offset = -data.filter_zero_point;
-    op_params.output_offset = data.output_zero_point;
-    op_params.output_multiplier = data.output_multiplier;
-    // TODO(b/138810107): Figure out whether output shift should be inverted
-    op_params.output_shift = -data.output_shift;
-    op_params.quantized_activation_min = data.output_activation_min;
-    op_params.quantized_activation_max = data.output_activation_max;
-
-    reference_integer_ops::FullyConnected(
-        op_params, tflite::micro::GetTensorShape(input),
+    tflite::reference_integer_ops::FullyConnected(
+        FullyConnectedParamsQuantized(data.reference_op_data),
+        tflite::micro::GetTensorShape(input),
         tflite::micro::GetTensorData<int8_t>(input),
         tflite::micro::GetTensorShape(filter),
         tflite::micro::GetTensorData<int8_t>(filter),
@@ -229,107 +181,60 @@ TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node,
   return kTfLiteOk;
 }
 
-TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
-                           const OpData& data, const TfLiteEvalTensor* input,
-                           const TfLiteEvalTensor* filter,
-                           const TfLiteEvalTensor* bias,
-                           TfLiteEvalTensor* output) {
-  const int32_t input_offset = -data.input_zero_point;
-  const int32_t filter_offset = -data.filter_zero_point;
-  const int32_t output_offset = data.output_zero_point;
-
-  tflite::FullyConnectedParams op_params;
-  op_params.input_offset = input_offset;
-  op_params.weights_offset = filter_offset;
-  op_params.output_offset = output_offset;
-  op_params.output_multiplier = data.output_multiplier;
-  // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
-  op_params.output_shift = -data.output_shift;
-  op_params.quantized_activation_min = data.output_activation_min;
-  op_params.quantized_activation_max = data.output_activation_max;
-
-#define TF_LITE_FULLY_CONNECTED(output_data_type)      \
-  reference_ops::FullyConnected(                       \
-      op_params, tflite::micro::GetTensorShape(input), \
-      tflite::micro::GetTensorData<uint8_t>(input),    \
-      tflite::micro::GetTensorShape(filter),           \
-      tflite::micro::GetTensorData<uint8_t>(filter),   \
-      tflite::micro::GetTensorShape(bias),             \
-      tflite::micro::GetTensorData<int32_t>(bias),     \
-      tflite::micro::GetTensorShape(output),           \
-      tflite::micro::GetTensorData<output_data_type>(output))
-  switch (output->type) {
-    case kTfLiteUInt8:
-      TF_LITE_FULLY_CONNECTED(uint8_t);
-      break;
-    case kTfLiteInt16:
-      TF_LITE_FULLY_CONNECTED(int16_t);
-      break;
-    default:
-      TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
-                         TfLiteTypeGetName(output->type), output->type);
-      return kTfLiteError;
-  }
-
-  return kTfLiteOk;
-}
-
-TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
-                       TfLiteFusedActivation activation,
-                       const TfLiteEvalTensor* input,
-                       const TfLiteEvalTensor* filter,
-                       const TfLiteEvalTensor* bias, TfLiteEvalTensor* output) {
-  float output_activation_min, output_activation_max;
-  CalculateActivationRange(activation, &output_activation_min,
-                           &output_activation_max);
-  tflite::FullyConnectedParams op_params;
-  op_params.float_activation_min = output_activation_min;
-  op_params.float_activation_max = output_activation_max;
-  tflite::reference_ops::FullyConnected(
-      op_params, tflite::micro::GetTensorShape(input),
-      tflite::micro::GetTensorData<float>(input),
-      tflite::micro::GetTensorShape(filter),
-      tflite::micro::GetTensorData<float>(filter),
-      tflite::micro::GetTensorShape(bias),
-      tflite::micro::GetTensorData<float>(bias),
-      tflite::micro::GetTensorShape(output),
-      tflite::micro::GetTensorData<float>(output));
-  return kTfLiteOk;
-}
-
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   TFLITE_DCHECK(node->builtin_data != nullptr);
   const auto* params =
       static_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
 
   const TfLiteEvalTensor* input =
-      tflite::micro::GetEvalInput(context, node, kInputTensor);
+      tflite::micro::GetEvalInput(context, node, kFullyConnectedInputTensor);
   const TfLiteEvalTensor* filter =
-      tflite::micro::GetEvalInput(context, node, kWeightsTensor);
+      tflite::micro::GetEvalInput(context, node, kFullyConnectedWeightsTensor);
   const TfLiteEvalTensor* bias =
-      tflite::micro::GetEvalInput(context, node, kBiasTensor);
+      tflite::micro::GetEvalInput(context, node, kFullyConnectedBiasTensor);
   TfLiteEvalTensor* output =
-      tflite::micro::GetEvalOutput(context, node, kOutputTensor);
+      tflite::micro::GetEvalOutput(context, node, kFullyConnectedOutputTensor);
 
   TFLITE_DCHECK(node->user_data != nullptr);
   const OpData& data = *(static_cast<OpData*>(node->user_data));
 
   // Checks in Prepare ensure input, output and filter types are all the same.
   switch (input->type) {
-    case kTfLiteFloat32:
-      return EvalFloat(context, node, params->activation, input, filter, bias,
-                       output);
-    case kTfLiteInt8:
+    case kTfLiteFloat32: {
+      tflite::reference_ops::FullyConnected(
+          FullyConnectedParamsFloat(params->activation),
+          tflite::micro::GetTensorShape(input),
+          tflite::micro::GetTensorData<float>(input),
+          tflite::micro::GetTensorShape(filter),
+          tflite::micro::GetTensorData<float>(filter),
+          tflite::micro::GetTensorShape(bias),
+          tflite::micro::GetTensorData<float>(bias),
+          tflite::micro::GetTensorShape(output),
+          tflite::micro::GetTensorData<float>(output));
+      break;
+    }
+    case kTfLiteInt8: {
       return EvalQuantizedInt8(context, node, data, input, filter, bias,
                                output);
-
-    case kTfLiteUInt8:
-      return EvalQuantized(context, node, data, input, filter, bias, output);
-
-    default:
+    }
+    case kTfLiteUInt8: {
+      tflite::reference_ops::FullyConnected(
+          FullyConnectedParamsQuantized(data.reference_op_data),
+          tflite::micro::GetTensorShape(input),
+          tflite::micro::GetTensorData<uint8_t>(input),
+          tflite::micro::GetTensorShape(filter),
+          tflite::micro::GetTensorData<uint8_t>(filter),
+          tflite::micro::GetTensorShape(bias),
+          tflite::micro::GetTensorData<int32_t>(bias),
+          tflite::micro::GetTensorShape(output),
+          tflite::micro::GetTensorData<uint8_t>(output));
+      break;
+    }
+    default: {
       TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
                          TfLiteTypeGetName(input->type), input->type);
       return kTfLiteError;
+    }
   }
   return kTfLiteOk;
 }
@@ -342,13 +247,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 // refactoring.
 TfLiteStatus EvalInt8(TfLiteContext* context, TfLiteNode* node) {
   const TfLiteEvalTensor* input =
-      tflite::micro::GetEvalInput(context, node, kInputTensor);
+      tflite::micro::GetEvalInput(context, node, kFullyConnectedInputTensor);
   const TfLiteEvalTensor* filter =
-      tflite::micro::GetEvalInput(context, node, kWeightsTensor);
+      tflite::micro::GetEvalInput(context, node, kFullyConnectedWeightsTensor);
   const TfLiteEvalTensor* bias =
-      tflite::micro::GetEvalInput(context, node, kBiasTensor);
+      tflite::micro::GetEvalInput(context, node, kFullyConnectedBiasTensor);
   TfLiteEvalTensor* output =
-      tflite::micro::GetEvalOutput(context, node, kOutputTensor);
+      tflite::micro::GetEvalOutput(context, node, kFullyConnectedOutputTensor);
 
   TFLITE_DCHECK(node->user_data != nullptr);
   const OpData& data = *(static_cast<OpData*>(node->user_data));
diff --git a/tensorflow/lite/micro/kernels/fully_connected.cc b/tensorflow/lite/micro/kernels/fully_connected.cc
index d3fdeacb016..28fbd4860fb 100644
--- a/tensorflow/lite/micro/kernels/fully_connected.cc
+++ b/tensorflow/lite/micro/kernels/fully_connected.cc
@@ -28,176 +28,37 @@ limitations under the License.
 namespace tflite {
 namespace {
 
-struct OpData {
-  // The scaling factor from input to output (aka the 'real multiplier') can
-  // be represented as a fixed point multiplier plus a left shift.
-  int32_t output_multiplier;
-  int output_shift;
-  // The range of the fused activation layer. For example for kNone and
-  // uint8_t these would be 0 and 255.
-  int32_t output_activation_min;
-  int32_t output_activation_max;
-  // The index of the temporary tensor where the quantized inputs are cached.
-  int input_quantized_index;
-  // Cached zero point values of tensors.
-  int32_t input_zero_point;
-  int32_t filter_zero_point;
-  int32_t output_zero_point;
-};
-
-constexpr int kInputTensor = 0;
-constexpr int kWeightsTensor = 1;
-constexpr int kBiasTensor = 2;
-constexpr int kOutputTensor = 0;
-
-TfLiteStatus CalculateOpData(TfLiteContext* context,
-                             TfLiteFusedActivation activation,
-                             TfLiteType data_type, const TfLiteTensor* input,
-                             const TfLiteTensor* filter,
-                             const TfLiteTensor* bias, TfLiteTensor* output,
-                             OpData* data) {
-  TfLiteStatus status = kTfLiteOk;
-  if (data_type != kTfLiteFloat32) {
-    double real_multiplier = 0.0;
-    TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
-        context, input, filter, bias, output, &real_multiplier));
-    int exponent;
-    QuantizeMultiplier(real_multiplier, &data->output_multiplier, &exponent);
-    data->output_shift = -exponent;
-    TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized(
-        context, activation, output, &data->output_activation_min,
-        &data->output_activation_max));
-
-    data->input_zero_point = input->params.zero_point;
-    data->filter_zero_point = filter->params.zero_point;
-    data->output_zero_point = output->params.zero_point;
-  }
-  return status;
-}
-
 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
   TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
-  return context->AllocatePersistentBuffer(context, sizeof(OpData));
+  return context->AllocatePersistentBuffer(context,
+                                           sizeof(OpDataFullyConnected));
 }
 
 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TFLITE_DCHECK(node->user_data != nullptr);
   TFLITE_DCHECK(node->builtin_data != nullptr);
 
-  OpData* data = static_cast<OpData*>(node->user_data);
+  auto* data = static_cast<OpDataFullyConnected*>(node->user_data);
   const auto params =
       static_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
 
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  const TfLiteTensor* input =
+      GetInput(context, node, kFullyConnectedInputTensor);
   TF_LITE_ENSURE(context, input != nullptr);
-  const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
+  const TfLiteTensor* filter =
+      GetInput(context, node, kFullyConnectedWeightsTensor);
   TF_LITE_ENSURE(context, filter != nullptr);
-  const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* bias =
+      GetOptionalInputTensor(context, node, kFullyConnectedBiasTensor);
+  TfLiteTensor* output = GetOutput(context, node, kFullyConnectedOutputTensor);
   TF_LITE_ENSURE(context, output != nullptr);
   TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
   TF_LITE_ENSURE_MSG(context, input->type == filter->type,
                      "Hybrid models are not supported on TFLite Micro.");
 
-  return CalculateOpData(context, params->activation, input->type, input,
-                         filter, bias, output, data);
-}
-
-TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node,
-                               const OpData& data,
-                               const TfLiteEvalTensor* input,
-                               const TfLiteEvalTensor* filter,
-                               const TfLiteEvalTensor* bias,
-                               TfLiteEvalTensor* output) {
-  tflite::FullyConnectedParams op_params;
-  op_params.input_offset = -data.input_zero_point;
-  op_params.weights_offset = -data.filter_zero_point;
-  op_params.output_offset = data.output_zero_point;
-  op_params.output_multiplier = data.output_multiplier;
-  // TODO(b/138810107): Figure out whether output shift should be inverted
-  op_params.output_shift = -data.output_shift;
-  op_params.quantized_activation_min = data.output_activation_min;
-  op_params.quantized_activation_max = data.output_activation_max;
-
-  reference_integer_ops::FullyConnected(
-      op_params, tflite::micro::GetTensorShape(input),
-      tflite::micro::GetTensorData<int8_t>(input),
-      tflite::micro::GetTensorShape(filter),
-      tflite::micro::GetTensorData<int8_t>(filter),
-      tflite::micro::GetTensorShape(bias),
-      tflite::micro::GetTensorData<int32_t>(bias),
-      tflite::micro::GetTensorShape(output),
-      tflite::micro::GetTensorData<int8_t>(output));
-  return kTfLiteOk;
-}
-
-TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
-                           const OpData& data, const TfLiteEvalTensor* input,
-                           const TfLiteEvalTensor* filter,
-                           const TfLiteEvalTensor* bias,
-                           TfLiteEvalTensor* output) {
-  const int32_t input_offset = -data.input_zero_point;
-  const int32_t filter_offset = -data.filter_zero_point;
-  const int32_t output_offset = data.output_zero_point;
-
-  tflite::FullyConnectedParams op_params;
-  op_params.input_offset = input_offset;
-  op_params.weights_offset = filter_offset;
-  op_params.output_offset = output_offset;
-  op_params.output_multiplier = data.output_multiplier;
-  // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
-  op_params.output_shift = -data.output_shift;
-  op_params.quantized_activation_min = data.output_activation_min;
-  op_params.quantized_activation_max = data.output_activation_max;
-
-#define TF_LITE_FULLY_CONNECTED(output_data_type)      \
-  reference_ops::FullyConnected(                       \
-      op_params, tflite::micro::GetTensorShape(input), \
-      tflite::micro::GetTensorData<uint8_t>(input),    \
-      tflite::micro::GetTensorShape(filter),           \
-      tflite::micro::GetTensorData<uint8_t>(filter),   \
-      tflite::micro::GetTensorShape(bias),             \
-      tflite::micro::GetTensorData<int32_t>(bias),     \
-      tflite::micro::GetTensorShape(output),           \
-      tflite::micro::GetTensorData<output_data_type>(output))
-  switch (output->type) {
-    case kTfLiteUInt8:
-      TF_LITE_FULLY_CONNECTED(uint8_t);
-      break;
-    case kTfLiteInt16:
-      TF_LITE_FULLY_CONNECTED(int16_t);
-      break;
-    default:
-      TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
-                         TfLiteTypeGetName(output->type), output->type);
-      return kTfLiteError;
-  }
-
-  return kTfLiteOk;
-}
-
-TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
-                       TfLiteFusedActivation activation,
-                       const TfLiteEvalTensor* input,
-                       const TfLiteEvalTensor* filter,
-                       const TfLiteEvalTensor* bias, TfLiteEvalTensor* output) {
-  float output_activation_min, output_activation_max;
-  CalculateActivationRange(activation, &output_activation_min,
-                           &output_activation_max);
-  tflite::FullyConnectedParams op_params;
-  op_params.float_activation_min = output_activation_min;
-  op_params.float_activation_max = output_activation_max;
-  tflite::reference_ops::FullyConnected(
-      op_params, tflite::micro::GetTensorShape(input),
-      tflite::micro::GetTensorData<float>(input),
-      tflite::micro::GetTensorShape(filter),
-      tflite::micro::GetTensorData<float>(filter),
-      tflite::micro::GetTensorShape(bias),
-      tflite::micro::GetTensorData<float>(bias),
-      tflite::micro::GetTensorShape(output),
-      tflite::micro::GetTensorData<float>(output));
-  return kTfLiteOk;
+  return CalculateOpDataFullyConnected(context, params->activation, input->type,
+                                       input, filter, bias, output, data);
 }
 
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
@@ -206,33 +67,66 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
       static_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
 
   const TfLiteEvalTensor* input =
-      tflite::micro::GetEvalInput(context, node, kInputTensor);
+      tflite::micro::GetEvalInput(context, node, kFullyConnectedInputTensor);
   const TfLiteEvalTensor* filter =
-      tflite::micro::GetEvalInput(context, node, kWeightsTensor);
+      tflite::micro::GetEvalInput(context, node, kFullyConnectedWeightsTensor);
   const TfLiteEvalTensor* bias =
-      tflite::micro::GetEvalInput(context, node, kBiasTensor);
+      tflite::micro::GetEvalInput(context, node, kFullyConnectedBiasTensor);
   TfLiteEvalTensor* output =
-      tflite::micro::GetEvalOutput(context, node, kOutputTensor);
+      tflite::micro::GetEvalOutput(context, node, kFullyConnectedOutputTensor);
 
   TFLITE_DCHECK(node->user_data != nullptr);
-  const OpData& data = *(static_cast<OpData*>(node->user_data));
+  const auto& data =
+      *(static_cast<OpDataFullyConnected*>(node->user_data));
 
   // Checks in Prepare ensure input, output and filter types are all the same.
   switch (input->type) {
-    case kTfLiteFloat32:
-      return EvalFloat(context, node, params->activation, input, filter, bias,
-                       output);
-    case kTfLiteInt8:
-      return EvalQuantizedInt8(context, node, data, input, filter, bias,
-                               output);
+    case kTfLiteFloat32: {
+      tflite::reference_ops::FullyConnected(
+          FullyConnectedParamsFloat(params->activation),
+          tflite::micro::GetTensorShape(input),
+          tflite::micro::GetTensorData<float>(input),
+          tflite::micro::GetTensorShape(filter),
+          tflite::micro::GetTensorData<float>(filter),
+          tflite::micro::GetTensorShape(bias),
+          tflite::micro::GetTensorData<float>(bias),
+          tflite::micro::GetTensorShape(output),
+          tflite::micro::GetTensorData<float>(output));
+      break;
+    }
 
-    case kTfLiteUInt8:
-      return EvalQuantized(context, node, data, input, filter, bias, output);
+    case kTfLiteInt8: {
+      tflite::reference_integer_ops::FullyConnected(
+          FullyConnectedParamsQuantized(data),
+          tflite::micro::GetTensorShape(input),
+          tflite::micro::GetTensorData<int8_t>(input),
+          tflite::micro::GetTensorShape(filter),
+          tflite::micro::GetTensorData<int8_t>(filter),
+          tflite::micro::GetTensorShape(bias),
+          tflite::micro::GetTensorData<int32_t>(bias),
+          tflite::micro::GetTensorShape(output),
+          tflite::micro::GetTensorData<int8_t>(output));
+      break;
+    }
 
-    default:
+    case kTfLiteUInt8: {
+      tflite::reference_ops::FullyConnected(
+          FullyConnectedParamsQuantized(data),
+          tflite::micro::GetTensorShape(input),
+          tflite::micro::GetTensorData<uint8_t>(input),
+          tflite::micro::GetTensorShape(filter),
+          tflite::micro::GetTensorData<uint8_t>(filter),
+          tflite::micro::GetTensorShape(bias),
+          tflite::micro::GetTensorData<int32_t>(bias),
+          tflite::micro::GetTensorShape(output),
+          tflite::micro::GetTensorData<uint8_t>(output));
+      break;
+    }
+    default: {
       TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
                          TfLiteTypeGetName(input->type), input->type);
       return kTfLiteError;
+    }
   }
   return kTfLiteOk;
 }
diff --git a/tensorflow/lite/micro/kernels/fully_connected.h b/tensorflow/lite/micro/kernels/fully_connected.h
index 3e6467183fe..d5e22e51377 100644
--- a/tensorflow/lite/micro/kernels/fully_connected.h
+++ b/tensorflow/lite/micro/kernels/fully_connected.h
@@ -15,10 +15,51 @@ limitations under the License.
 #ifndef TENSORFLOW_LITE_MICRO_KERNELS_FULLY_CONNECTED_H_
 #define TENSORFLOW_LITE_MICRO_KERNELS_FULLY_CONNECTED_H_
 
+#include <cstdint>
+
+#include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/types.h"
 
 namespace tflite {
 
+struct OpDataFullyConnected {
+  // The scaling factor from input to output (aka the 'real multiplier') can
+  // be represented as a fixed point multiplier plus a left shift.
+  int32_t output_multiplier;
+  int output_shift;
+  // The range of the fused activation layer. For example for kNone and
+  // uint8_t these would be 0 and 255.
+  int32_t output_activation_min;
+  int32_t output_activation_max;
+  // The index of the temporary tensor where the quantized inputs are cached.
+  int input_quantized_index;
+  // Cached zero point values of tensors.
+  int32_t input_zero_point;
+  int32_t filter_zero_point;
+  int32_t output_zero_point;
+};
+
+extern const int kFullyConnectedInputTensor;
+extern const int kFullyConnectedWeightsTensor;
+extern const int kFullyConnectedBiasTensor;
+extern const int kFullyConnectedOutputTensor;
+
+// Returns a FullyConnectedParams struct with all the parameters needed for a
+// float computation.
+FullyConnectedParams FullyConnectedParamsFloat(
+    TfLiteFusedActivation activation);
+
+// Returns a FullyConnectedParams struct with all the parameters needed for a
+// quantized computation.
+FullyConnectedParams FullyConnectedParamsQuantized(
+    const OpDataFullyConnected& op_data);
+
+TfLiteStatus CalculateOpDataFullyConnected(
+    TfLiteContext* context, TfLiteFusedActivation activation,
+    TfLiteType data_type, const TfLiteTensor* input, const TfLiteTensor* filter,
+    const TfLiteTensor* bias, TfLiteTensor* output, OpDataFullyConnected* data);
+
 // This is the most generic TfLiteRegistration. The actual supported types may
 // still be target dependent. The only requirement is that every implementation
 // (reference or optimized) must define this function.
diff --git a/tensorflow/lite/micro/kernels/fully_connected_common.cc b/tensorflow/lite/micro/kernels/fully_connected_common.cc
new file mode 100644
index 00000000000..64046a9cec3
--- /dev/null
+++ b/tensorflow/lite/micro/kernels/fully_connected_common.cc
@@ -0,0 +1,78 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
+#include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/kernels/fully_connected.h"
+#include "tensorflow/lite/micro/kernels/kernel_util.h"
+
+namespace tflite {
+
+const int kFullyConnectedInputTensor = 0;
+const int kFullyConnectedWeightsTensor = 1;
+const int kFullyConnectedBiasTensor = 2;
+const int kFullyConnectedOutputTensor = 0;
+
+FullyConnectedParams FullyConnectedParamsQuantized(
+    const OpDataFullyConnected& op_data) {
+  FullyConnectedParams op_params;
+  op_params.input_offset = -op_data.input_zero_point;
+  op_params.weights_offset = -op_data.filter_zero_point;
+  op_params.output_offset = op_data.output_zero_point;
+  op_params.output_multiplier = op_data.output_multiplier;
+  op_params.output_shift = op_data.output_shift;
+  op_params.quantized_activation_min = op_data.output_activation_min;
+  op_params.quantized_activation_max = op_data.output_activation_max;
+  return op_params;
+}
+
+FullyConnectedParams FullyConnectedParamsFloat(
+    TfLiteFusedActivation activation) {
+  FullyConnectedParams op_params;
+  CalculateActivationRange(activation, &op_params.float_activation_min,
+                           &op_params.float_activation_max);
+  return op_params;
+}
+
+TfLiteStatus CalculateOpDataFullyConnected(
+    TfLiteContext* context, TfLiteFusedActivation activation,
+    TfLiteType data_type, const TfLiteTensor* input, const TfLiteTensor* filter,
+    const TfLiteTensor* bias, TfLiteTensor* output,
+    OpDataFullyConnected* data) {
+  if (data_type != kTfLiteFloat32) {
+    double real_multiplier = 0.0;
+    TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
+        context, input, filter, bias, output, &real_multiplier));
+    QuantizeMultiplier(real_multiplier, &data->output_multiplier,
+                       &data->output_shift);
+
+    data->input_zero_point = input->params.zero_point;
+    data->filter_zero_point = filter->params.zero_point;
+    data->output_zero_point = output->params.zero_point;
+
+    return CalculateActivationRangeQuantized(context, activation, output,
+                                             &data->output_activation_min,
+                                             &data->output_activation_max);
+  }
+  return kTfLiteOk;
+}
+
+}  // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/xtensa/fully_connected.cc b/tensorflow/lite/micro/kernels/xtensa/fully_connected.cc
index a169343ace7..5a2520ccec6 100644
--- a/tensorflow/lite/micro/kernels/xtensa/fully_connected.cc
+++ b/tensorflow/lite/micro/kernels/xtensa/fully_connected.cc
@@ -13,12 +13,13 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
+#include "tensorflow/lite/micro/kernels/fully_connected.h"
 
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/kernels/internal/common.h"
 #include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
 #include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h"
 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
@@ -29,30 +30,6 @@ limitations under the License.
 namespace tflite {
 namespace {
 
-struct OpData {
-  // The scaling factor from input to output (aka the 'real multiplier') can
-  // be represented as a fixed point multiplier plus a left shift.
-  int32_t output_multiplier;
-  int output_shift;
-
-  // Cached tensor zero point values for quantized operations.
-  int32_t input_zero_point;
-  int32_t filter_zero_point;
-  int32_t output_zero_point;
-
-  // The range of the fused activation layer. For example for kNone and
-  // uint8_t these would be 0 and 255.
-  int32_t output_activation_min;
-  int32_t output_activation_max;
-  // The index of the temporary tensor where the quantized inputs are cached.
-  int input_quantized_index;
-};
-
-constexpr int kInputTensor = 0;
-constexpr int kWeightsTensor = 1;
-constexpr int kBiasTensor = 2;
-constexpr int kOutputTensor = 0;
-
 #if defined(HIFIMINI)
 void FullyConnected(const FullyConnectedParams& params,
                     const RuntimeShape& input_shape, const int8_t* input_data,
@@ -144,7 +121,7 @@ TfLiteStatus CalculateOpData(TfLiteContext* context,
                              TfLiteType data_type, const TfLiteTensor* input,
                              const TfLiteTensor* filter,
                              const TfLiteTensor* bias, TfLiteTensor* output,
-                             OpData* data) {
+                             OpDataFullyConnected* data) {
   double real_multiplier = 0.0;
   TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
       context, input, filter, bias, output, &real_multiplier));
@@ -155,6 +132,10 @@ TfLiteStatus CalculateOpData(TfLiteContext* context,
   QuantizeMultiplier(real_multiplier, &data->output_multiplier,
                      &data->output_shift);
 #endif
+  data->input_zero_point = input->params.zero_point;
+  data->filter_zero_point = filter->params.zero_point;
+  data->output_zero_point = output->params.zero_point;
+
   return CalculateActivationRangeQuantized(context, activation, output,
                                            &data->output_activation_min,
                                            &data->output_activation_max);
@@ -162,21 +143,25 @@ TfLiteStatus CalculateOpData(TfLiteContext* context,
 
 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
   TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
-  return context->AllocatePersistentBuffer(context, sizeof(OpData));
+  return context->AllocatePersistentBuffer(context,
+                                           sizeof(OpDataFullyConnected));
 }
 
 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TFLITE_DCHECK(node->user_data != nullptr);
   TFLITE_DCHECK(node->builtin_data != nullptr);
 
-  OpData* data = static_cast<OpData*>(node->user_data);
+  auto* data = static_cast<OpDataFullyConnected*>(node->user_data);
   const auto* params =
       reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
 
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
-  const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input =
+      GetInput(context, node, kFullyConnectedInputTensor);
+  const TfLiteTensor* filter =
+      GetInput(context, node, kFullyConnectedWeightsTensor);
+  const TfLiteTensor* bias =
+      GetOptionalInputTensor(context, node, kFullyConnectedBiasTensor);
+  TfLiteTensor* output = GetOutput(context, node, kFullyConnectedOutputTensor);
 
   if (input->type != kTfLiteInt8) {
     TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
@@ -184,36 +169,26 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     return kTfLiteError;
   }
 
-  data->input_zero_point = input->params.zero_point;
-  data->filter_zero_point = filter->params.zero_point;
-  data->output_zero_point = output->params.zero_point;
-
   return CalculateOpData(context, params->activation, input->type, input,
                          filter, bias, output, data);
 }
 
 TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node,
-                               const OpData& data,
+                               const OpDataFullyConnected& data,
                                const TfLiteEvalTensor* input,
                                const TfLiteEvalTensor* filter,
                                const TfLiteEvalTensor* bias,
                                TfLiteEvalTensor* output) {
-  // TODO(b/154032858): Investigate removing extra copies, and also passing by
-  // value. TODO(b/155656675): Consider passing OpData by value once it is also
-  // passed to the FullyConnected function. Until it is copied to a local
-  // op_param variable, we do not get any latency improvements from passing by
-  // value.
-  FullyConnectedParams op_params;
-  op_params.input_offset = -data.input_zero_point;
-  op_params.weights_offset = -data.filter_zero_point;
-  op_params.output_offset = data.output_zero_point;
-  op_params.output_multiplier = data.output_multiplier;
-  op_params.output_shift = data.output_shift;
-  op_params.quantized_activation_min = data.output_activation_min;
-  op_params.quantized_activation_max = data.output_activation_max;
-
+  // TODO(b/154032858): Investigate removing extra copies (i.e.
+  // data.ToQuantizedParams), and also passing by value.
+  //
+  // TODO(b/155656675): Consider passing OpDataFullyConnected by value
+  // once it is also passed to the FullyConnected function. Until it is copied
+  // to a local op_param variable, we do not get any latency improvements from
+  // passing by value.
 #if defined(HIFIMINI)
-  FullyConnected(op_params, tflite::micro::GetTensorShape(input),
+  FullyConnected(FullyConnectedParamsQuantized(data),
+                 tflite::micro::GetTensorShape(input),
                  tflite::micro::GetTensorData<int8_t>(input),
                  tflite::micro::GetTensorShape(filter),
                  tflite::micro::GetTensorData<int8_t>(filter),
@@ -223,7 +198,7 @@ TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node,
                  tflite::micro::GetTensorData<int8_t>(output));
 #else
   reference_integer_ops::FullyConnected(
-      op_params, tflite::micro::GetTensorShape(input),
+      FullyConnectedParamsQuantized(data), tflite::micro::GetTensorShape(input),
       tflite::micro::GetTensorData<int8_t>(input),
       tflite::micro::GetTensorShape(filter),
      tflite::micro::GetTensorData<int8_t>(filter),
@@ -238,19 +213,20 @@ TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node,
 
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   TFLITE_DCHECK(node->user_data != nullptr);
-  const OpData& data = *(static_cast<OpData*>(node->user_data));
+  const auto& data =
+      *(static_cast<OpDataFullyConnected*>(node->user_data));
 
   const TfLiteEvalTensor* input =
-      tflite::micro::GetEvalInput(context, node, kInputTensor);
+      tflite::micro::GetEvalInput(context, node, kFullyConnectedInputTensor);
   const TfLiteEvalTensor* filter =
-      tflite::micro::GetEvalInput(context, node, kWeightsTensor);
+      tflite::micro::GetEvalInput(context, node, kFullyConnectedWeightsTensor);
   const TfLiteEvalTensor* bias =
-      (NumInputs(node) == 3)
-          ? tflite::micro::GetEvalInput(context, node, kBiasTensor)
-          : nullptr;
+      (NumInputs(node) == 3) ? tflite::micro::GetEvalInput(
+                                   context, node, kFullyConnectedBiasTensor)
+                             : nullptr;
   TfLiteEvalTensor* output =
-      tflite::micro::GetEvalOutput(context, node, kOutputTensor);
+      tflite::micro::GetEvalOutput(context, node, kFullyConnectedOutputTensor);
 
   return EvalQuantizedInt8(context, node, data, input, filter, bias, output);
 }
diff --git a/tensorflow/lite/micro/tools/make/Makefile b/tensorflow/lite/micro/tools/make/Makefile
index 036102766d3..03b8e0d88f3 100644
--- a/tensorflow/lite/micro/tools/make/Makefile
+++ b/tensorflow/lite/micro/tools/make/Makefile
@@ -329,6 +329,7 @@ tensorflow/lite/micro/kernels/ethosu.cc \
 tensorflow/lite/micro/kernels/flexbuffers_generated_data.cc \
 tensorflow/lite/micro/kernels/floor.cc \
 tensorflow/lite/micro/kernels/fully_connected.cc \
+tensorflow/lite/micro/kernels/fully_connected_common.cc \
 tensorflow/lite/micro/kernels/hard_swish.cc \
 tensorflow/lite/micro/kernels/kernel_runner.cc \
 tensorflow/lite/micro/kernels/kernel_util.cc \
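
Reviewer note (not part of the diff): for other ports that want to adopt the new shared helpers, the intended division of labor looks roughly like the sketch below. All `MyPort*` names are hypothetical placeholders; only `OpDataFullyConnected`, `CalculateOpDataFullyConnected`, `FullyConnectedParamsQuantized`, and the `kFullyConnected*Tensor` constants come from this change, and the sketch only compiles inside the TFLM source tree.

```cpp
// Minimal sketch of a hypothetical port-specific FULLY_CONNECTED kernel that
// reuses the shared code introduced by this PR.
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/fully_connected.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"

namespace tflite {
namespace {

void* MyPortInit(TfLiteContext* context, const char* buffer, size_t length) {
  // The port only needs the shared OpData struct; ports with scratch buffers
  // (like cmsis-nn) wrap it in their own struct instead.
  return context->AllocatePersistentBuffer(context,
                                           sizeof(OpDataFullyConnected));
}

TfLiteStatus MyPortPrepare(TfLiteContext* context, TfLiteNode* node) {
  auto* data = static_cast<OpDataFullyConnected*>(node->user_data);
  const auto* params =
      static_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
  const TfLiteTensor* input =
      GetInput(context, node, kFullyConnectedInputTensor);
  const TfLiteTensor* filter =
      GetInput(context, node, kFullyConnectedWeightsTensor);
  const TfLiteTensor* bias =
      GetOptionalInputTensor(context, node, kFullyConnectedBiasTensor);
  TfLiteTensor* output = GetOutput(context, node, kFullyConnectedOutputTensor);
  // All quantization bookkeeping (multiplier, shift, activation range, zero
  // points) is delegated to the shared helper.
  return CalculateOpDataFullyConnected(context, params->activation,
                                       input->type, input, filter, bias,
                                       output, data);
}

TfLiteStatus MyPortEval(TfLiteContext* context, TfLiteNode* node) {
  const auto& data =
      *(static_cast<OpDataFullyConnected*>(node->user_data));
  // Produces the same FullyConnectedParams the reference kernel consumes; a
  // real port would hand this to its optimized int8 routine.
  FullyConnectedParams op_params = FullyConnectedParamsQuantized(data);
  (void)op_params;  // ... dispatch to the port's implementation here.
  return kTfLiteOk;
}

}  // namespace
}  // namespace tflite
```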
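Reviewer note (not part of the diff): the shift-sign cleanup is easy to miss. The old `CalculateOpData` stored `output_shift = -exponent` and every eval site negated it again (`-data.output_shift`, and `quant_params.shift = -data.output_shift` in the CMSIS-NN path), whereas `CalculateOpDataFullyConnected` stores the exponent from `QuantizeMultiplier` directly, so `FullyConnectedParamsQuantized` copies it through unchanged; the double negation cancels and the effective scale is identical. A standalone illustration of the decomposition, using `std::frexp` instead of TFLite's `QuantizeMultiplier` (assumed behavior: Q31 multiplier with `real_multiplier ≈ multiplier × 2^shift / 2^31`; TFLite additionally handles the rounding edge case where the mantissa rounds up to 2^31):

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  // Example "real multiplier": input_scale * filter_scale / output_scale.
  const double real_multiplier = 0.0072;
  int shift;
  // frexp: real_multiplier == q * 2^shift with q in [0.5, 1).
  const double q = std::frexp(real_multiplier, &shift);
  const auto q31 = static_cast<int32_t>(std::llround(q * (1ll << 31)));
  // Here shift == -7 and q31 ~= 1979120930; reconstructing gives ~0.0072.
  // The shared helper stores `shift` as-is (negative means right shift);
  // the old code stored -shift and flipped the sign back at eval time.
  std::printf("multiplier=%ld shift=%d reconstructed=%.6f\n",
              static_cast<long>(q31), shift,
              (q31 / 2147483648.0) * std::ldexp(1.0, shift));
  return 0;
}
```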