Merge pull request #46242 from advaitjain:refactor-fully_connected
PiperOrigin-RevId: 351435546 Change-Id: I0dad3dde99541fab9d4ed92fec1e25b78e2b3f9d
This commit is contained in:
commit
88e57640cd
@ -50,7 +50,9 @@ cc_library(
|
||||
|
||||
cc_library(
|
||||
name = "fully_connected",
|
||||
srcs = select({
|
||||
srcs = [
|
||||
"fully_connected_common.cc",
|
||||
] + select({
|
||||
"//conditions:default": [
|
||||
"fully_connected.cc",
|
||||
],
|
||||
|
||||
@ -13,78 +13,34 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
|
||||
#include "tensorflow/lite/micro/kernels/fully_connected.h"
|
||||
|
||||
#include "CMSIS/NN/Include/arm_nnfunctions.h"
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/quantization_util.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/micro/kernels/fully_connected.h"
|
||||
#include "tensorflow/lite/micro/kernels/kernel_util.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace {
|
||||
|
||||
struct OpData {
|
||||
// The scaling factor from input to output (aka the 'real multiplier') can
|
||||
// be represented as a fixed point multiplier plus a left shift.
|
||||
int32_t output_multiplier;
|
||||
int output_shift;
|
||||
// The range of the fused activation layer. For example for kNone and
|
||||
// uint8_t these would be 0 and 255.
|
||||
int32_t output_activation_min;
|
||||
int32_t output_activation_max;
|
||||
// The index of the temporary tensor where the quantized inputs are cached.
|
||||
int input_quantized_index;
|
||||
OpDataFullyConnected reference_op_data;
|
||||
|
||||
// Index to buffer for optimizations if applicable.
|
||||
int buffer_idx;
|
||||
|
||||
// Cached tensor zero point values for quantized operations.
|
||||
int32_t input_zero_point;
|
||||
int32_t filter_zero_point;
|
||||
int32_t output_zero_point;
|
||||
};
|
||||
|
||||
constexpr int kInputTensor = 0;
|
||||
constexpr int kWeightsTensor = 1;
|
||||
constexpr int kBiasTensor = 2;
|
||||
constexpr int kOutputTensor = 0;
|
||||
|
||||
// TODO(b/169801227): This global struct is needed for the linker to drop unused
|
||||
// code (for example, by using Register_FULLY_CONNECTED_INT8 instead of
|
||||
// Register_FULLY_CONNECTED).
|
||||
TfLiteRegistration fully_connected_registration;
|
||||
|
||||
TfLiteStatus CalculateOpData(TfLiteContext* context,
|
||||
TfLiteFusedActivation activation,
|
||||
TfLiteType data_type, const TfLiteTensor* input,
|
||||
const TfLiteTensor* filter,
|
||||
const TfLiteTensor* bias, TfLiteTensor* output,
|
||||
OpData* data) {
|
||||
TfLiteStatus status = kTfLiteOk;
|
||||
// Set buffer index to a reset value
|
||||
data->buffer_idx = -1;
|
||||
if (data_type != kTfLiteFloat32) {
|
||||
double real_multiplier = 0.0;
|
||||
TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
|
||||
context, input, filter, bias, output, &real_multiplier));
|
||||
int exponent;
|
||||
QuantizeMultiplier(real_multiplier, &data->output_multiplier, &exponent);
|
||||
data->output_shift = -exponent;
|
||||
TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized(
|
||||
context, activation, output, &data->output_activation_min,
|
||||
&data->output_activation_max));
|
||||
data->input_zero_point = input->params.zero_point;
|
||||
data->filter_zero_point = filter->params.zero_point;
|
||||
data->output_zero_point = output->params.zero_point;
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
|
||||
return context->AllocatePersistentBuffer(context, sizeof(OpData));
|
||||
@ -98,16 +54,22 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
const auto params =
|
||||
static_cast<const TfLiteFullyConnectedParams*>(node->builtin_data);
|
||||
|
||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||
const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
|
||||
const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
const TfLiteTensor* input =
|
||||
GetInput(context, node, kFullyConnectedInputTensor);
|
||||
const TfLiteTensor* filter =
|
||||
GetInput(context, node, kFullyConnectedWeightsTensor);
|
||||
const TfLiteTensor* bias =
|
||||
GetOptionalInputTensor(context, node, kFullyConnectedBiasTensor);
|
||||
TfLiteTensor* output = GetOutput(context, node, kFullyConnectedOutputTensor);
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
|
||||
TF_LITE_ENSURE_MSG(context, input->type == filter->type,
|
||||
"Hybrid models are not supported on TFLite Micro.");
|
||||
TF_LITE_ENSURE_STATUS(CalculateOpData(context, params->activation,
|
||||
input->type, input, filter, bias,
|
||||
output, data));
|
||||
|
||||
// Set buffer index to a reset value
|
||||
data->buffer_idx = -1;
|
||||
TF_LITE_ENSURE_STATUS(CalculateOpDataFullyConnected(
|
||||
context, params->activation, input->type, input, filter, bias, output,
|
||||
&(data->reference_op_data)));
|
||||
|
||||
if (input->type == kTfLiteInt8 && nullptr != GetTensorData<int32_t>(bias)) {
|
||||
RuntimeShape filter_shape = GetTensorShape(filter);
|
||||
@ -153,16 +115,15 @@ TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node,
|
||||
const RuntimeShape input_shape = tflite::micro::GetTensorShape(input);
|
||||
|
||||
cmsis_nn_fc_params fc_params;
|
||||
fc_params.input_offset = -data.input_zero_point;
|
||||
fc_params.output_offset = data.output_zero_point;
|
||||
fc_params.filter_offset = -data.filter_zero_point;
|
||||
fc_params.activation.min = data.output_activation_min;
|
||||
fc_params.activation.max = data.output_activation_max;
|
||||
fc_params.input_offset = -data.reference_op_data.input_zero_point;
|
||||
fc_params.output_offset = data.reference_op_data.output_zero_point;
|
||||
fc_params.filter_offset = -data.reference_op_data.filter_zero_point;
|
||||
fc_params.activation.min = data.reference_op_data.output_activation_min;
|
||||
fc_params.activation.max = data.reference_op_data.output_activation_max;
|
||||
|
||||
cmsis_nn_per_tensor_quant_params quant_params;
|
||||
quant_params.multiplier = data.output_multiplier;
|
||||
// TODO(b/138810107): Figure out whether output shift should be inverted
|
||||
quant_params.shift = -data.output_shift;
|
||||
quant_params.multiplier = data.reference_op_data.output_multiplier;
|
||||
quant_params.shift = data.reference_op_data.output_shift;
|
||||
|
||||
cmsis_nn_dims input_dims;
|
||||
input_dims.n = batches;
|
||||
@ -206,18 +167,9 @@ TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node,
|
||||
tflite::micro::GetTensorData<int8_t>(output)),
|
||||
ARM_MATH_SUCCESS);
|
||||
} else {
|
||||
tflite::FullyConnectedParams op_params;
|
||||
op_params.input_offset = -data.input_zero_point;
|
||||
op_params.weights_offset = -data.filter_zero_point;
|
||||
op_params.output_offset = data.output_zero_point;
|
||||
op_params.output_multiplier = data.output_multiplier;
|
||||
// TODO(b/138810107): Figure out whether output shift should be inverted
|
||||
op_params.output_shift = -data.output_shift;
|
||||
op_params.quantized_activation_min = data.output_activation_min;
|
||||
op_params.quantized_activation_max = data.output_activation_max;
|
||||
|
||||
reference_integer_ops::FullyConnected(
|
||||
op_params, tflite::micro::GetTensorShape(input),
|
||||
tflite::reference_integer_ops::FullyConnected(
|
||||
FullyConnectedParamsQuantized(data.reference_op_data),
|
||||
tflite::micro::GetTensorShape(input),
|
||||
tflite::micro::GetTensorData<int8_t>(input),
|
||||
tflite::micro::GetTensorShape(filter),
|
||||
tflite::micro::GetTensorData<int8_t>(filter),
|
||||
@ -229,107 +181,60 @@ TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node,
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
|
||||
const OpData& data, const TfLiteEvalTensor* input,
|
||||
const TfLiteEvalTensor* filter,
|
||||
const TfLiteEvalTensor* bias,
|
||||
TfLiteEvalTensor* output) {
|
||||
const int32_t input_offset = -data.input_zero_point;
|
||||
const int32_t filter_offset = -data.filter_zero_point;
|
||||
const int32_t output_offset = data.output_zero_point;
|
||||
|
||||
tflite::FullyConnectedParams op_params;
|
||||
op_params.input_offset = input_offset;
|
||||
op_params.weights_offset = filter_offset;
|
||||
op_params.output_offset = output_offset;
|
||||
op_params.output_multiplier = data.output_multiplier;
|
||||
// Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
|
||||
op_params.output_shift = -data.output_shift;
|
||||
op_params.quantized_activation_min = data.output_activation_min;
|
||||
op_params.quantized_activation_max = data.output_activation_max;
|
||||
|
||||
#define TF_LITE_FULLY_CONNECTED(output_data_type) \
|
||||
reference_ops::FullyConnected( \
|
||||
op_params, tflite::micro::GetTensorShape(input), \
|
||||
tflite::micro::GetTensorData<uint8_t>(input), \
|
||||
tflite::micro::GetTensorShape(filter), \
|
||||
tflite::micro::GetTensorData<uint8_t>(filter), \
|
||||
tflite::micro::GetTensorShape(bias), \
|
||||
tflite::micro::GetTensorData<int32_t>(bias), \
|
||||
tflite::micro::GetTensorShape(output), \
|
||||
tflite::micro::GetTensorData<output_data_type>(output))
|
||||
switch (output->type) {
|
||||
case kTfLiteUInt8:
|
||||
TF_LITE_FULLY_CONNECTED(uint8_t);
|
||||
break;
|
||||
case kTfLiteInt16:
|
||||
TF_LITE_FULLY_CONNECTED(int16_t);
|
||||
break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
|
||||
TfLiteTypeGetName(output->type), output->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
|
||||
TfLiteFusedActivation activation,
|
||||
const TfLiteEvalTensor* input,
|
||||
const TfLiteEvalTensor* filter,
|
||||
const TfLiteEvalTensor* bias, TfLiteEvalTensor* output) {
|
||||
float output_activation_min, output_activation_max;
|
||||
CalculateActivationRange(activation, &output_activation_min,
|
||||
&output_activation_max);
|
||||
tflite::FullyConnectedParams op_params;
|
||||
op_params.float_activation_min = output_activation_min;
|
||||
op_params.float_activation_max = output_activation_max;
|
||||
tflite::reference_ops::FullyConnected(
|
||||
op_params, tflite::micro::GetTensorShape(input),
|
||||
tflite::micro::GetTensorData<float>(input),
|
||||
tflite::micro::GetTensorShape(filter),
|
||||
tflite::micro::GetTensorData<float>(filter),
|
||||
tflite::micro::GetTensorShape(bias),
|
||||
tflite::micro::GetTensorData<float>(bias),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<float>(output));
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
TFLITE_DCHECK(node->builtin_data != nullptr);
|
||||
const auto* params =
|
||||
static_cast<const TfLiteFullyConnectedParams*>(node->builtin_data);
|
||||
|
||||
const TfLiteEvalTensor* input =
|
||||
tflite::micro::GetEvalInput(context, node, kInputTensor);
|
||||
tflite::micro::GetEvalInput(context, node, kFullyConnectedInputTensor);
|
||||
const TfLiteEvalTensor* filter =
|
||||
tflite::micro::GetEvalInput(context, node, kWeightsTensor);
|
||||
tflite::micro::GetEvalInput(context, node, kFullyConnectedWeightsTensor);
|
||||
const TfLiteEvalTensor* bias =
|
||||
tflite::micro::GetEvalInput(context, node, kBiasTensor);
|
||||
tflite::micro::GetEvalInput(context, node, kFullyConnectedBiasTensor);
|
||||
TfLiteEvalTensor* output =
|
||||
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
|
||||
tflite::micro::GetEvalOutput(context, node, kFullyConnectedOutputTensor);
|
||||
|
||||
TFLITE_DCHECK(node->user_data != nullptr);
|
||||
const OpData& data = *(static_cast<const OpData*>(node->user_data));
|
||||
|
||||
// Checks in Prepare ensure input, output and filter types are all the same.
|
||||
switch (input->type) {
|
||||
case kTfLiteFloat32:
|
||||
return EvalFloat(context, node, params->activation, input, filter, bias,
|
||||
output);
|
||||
case kTfLiteInt8:
|
||||
case kTfLiteFloat32: {
|
||||
tflite::reference_ops::FullyConnected(
|
||||
FullyConnectedParamsFloat(params->activation),
|
||||
tflite::micro::GetTensorShape(input),
|
||||
tflite::micro::GetTensorData<float>(input),
|
||||
tflite::micro::GetTensorShape(filter),
|
||||
tflite::micro::GetTensorData<float>(filter),
|
||||
tflite::micro::GetTensorShape(bias),
|
||||
tflite::micro::GetTensorData<float>(bias),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<float>(output));
|
||||
break;
|
||||
}
|
||||
case kTfLiteInt8: {
|
||||
return EvalQuantizedInt8(context, node, data, input, filter, bias,
|
||||
output);
|
||||
|
||||
case kTfLiteUInt8:
|
||||
return EvalQuantized(context, node, data, input, filter, bias, output);
|
||||
|
||||
default:
|
||||
}
|
||||
case kTfLiteUInt8: {
|
||||
tflite::reference_ops::FullyConnected(
|
||||
FullyConnectedParamsQuantized(data.reference_op_data),
|
||||
tflite::micro::GetTensorShape(input),
|
||||
tflite::micro::GetTensorData<uint8_t>(input),
|
||||
tflite::micro::GetTensorShape(filter),
|
||||
tflite::micro::GetTensorData<uint8_t>(filter),
|
||||
tflite::micro::GetTensorShape(bias),
|
||||
tflite::micro::GetTensorData<int32_t>(bias),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<uint8_t>(output));
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
|
||||
TfLiteTypeGetName(input->type), input->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
}
|
||||
return kTfLiteOk;
|
||||
}
|
||||
@ -342,13 +247,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
// refactoring.
|
||||
TfLiteStatus EvalInt8(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLiteEvalTensor* input =
|
||||
tflite::micro::GetEvalInput(context, node, kInputTensor);
|
||||
tflite::micro::GetEvalInput(context, node, kFullyConnectedInputTensor);
|
||||
const TfLiteEvalTensor* filter =
|
||||
tflite::micro::GetEvalInput(context, node, kWeightsTensor);
|
||||
tflite::micro::GetEvalInput(context, node, kFullyConnectedWeightsTensor);
|
||||
const TfLiteEvalTensor* bias =
|
||||
tflite::micro::GetEvalInput(context, node, kBiasTensor);
|
||||
tflite::micro::GetEvalInput(context, node, kFullyConnectedBiasTensor);
|
||||
TfLiteEvalTensor* output =
|
||||
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
|
||||
tflite::micro::GetEvalOutput(context, node, kFullyConnectedOutputTensor);
|
||||
|
||||
TFLITE_DCHECK(node->user_data != nullptr);
|
||||
const OpData& data = *(static_cast<const OpData*>(node->user_data));
|
||||
|
||||
@ -28,176 +28,37 @@ limitations under the License.
|
||||
namespace tflite {
|
||||
namespace {
|
||||
|
||||
struct OpData {
|
||||
// The scaling factor from input to output (aka the 'real multiplier') can
|
||||
// be represented as a fixed point multiplier plus a left shift.
|
||||
int32_t output_multiplier;
|
||||
int output_shift;
|
||||
// The range of the fused activation layer. For example for kNone and
|
||||
// uint8_t these would be 0 and 255.
|
||||
int32_t output_activation_min;
|
||||
int32_t output_activation_max;
|
||||
// The index of the temporary tensor where the quantized inputs are cached.
|
||||
int input_quantized_index;
|
||||
// Cached zero point values of tensors.
|
||||
int32_t input_zero_point;
|
||||
int32_t filter_zero_point;
|
||||
int32_t output_zero_point;
|
||||
};
|
||||
|
||||
constexpr int kInputTensor = 0;
|
||||
constexpr int kWeightsTensor = 1;
|
||||
constexpr int kBiasTensor = 2;
|
||||
constexpr int kOutputTensor = 0;
|
||||
|
||||
TfLiteStatus CalculateOpData(TfLiteContext* context,
|
||||
TfLiteFusedActivation activation,
|
||||
TfLiteType data_type, const TfLiteTensor* input,
|
||||
const TfLiteTensor* filter,
|
||||
const TfLiteTensor* bias, TfLiteTensor* output,
|
||||
OpData* data) {
|
||||
TfLiteStatus status = kTfLiteOk;
|
||||
if (data_type != kTfLiteFloat32) {
|
||||
double real_multiplier = 0.0;
|
||||
TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
|
||||
context, input, filter, bias, output, &real_multiplier));
|
||||
int exponent;
|
||||
QuantizeMultiplier(real_multiplier, &data->output_multiplier, &exponent);
|
||||
data->output_shift = -exponent;
|
||||
TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized(
|
||||
context, activation, output, &data->output_activation_min,
|
||||
&data->output_activation_max));
|
||||
|
||||
data->input_zero_point = input->params.zero_point;
|
||||
data->filter_zero_point = filter->params.zero_point;
|
||||
data->output_zero_point = output->params.zero_point;
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
|
||||
return context->AllocatePersistentBuffer(context, sizeof(OpData));
|
||||
return context->AllocatePersistentBuffer(context,
|
||||
sizeof(OpDataFullyConnected));
|
||||
}
|
||||
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TFLITE_DCHECK(node->user_data != nullptr);
|
||||
TFLITE_DCHECK(node->builtin_data != nullptr);
|
||||
|
||||
OpData* data = static_cast<OpData*>(node->user_data);
|
||||
auto* data = static_cast<OpDataFullyConnected*>(node->user_data);
|
||||
const auto params =
|
||||
static_cast<const TfLiteFullyConnectedParams*>(node->builtin_data);
|
||||
|
||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||
const TfLiteTensor* input =
|
||||
GetInput(context, node, kFullyConnectedInputTensor);
|
||||
TF_LITE_ENSURE(context, input != nullptr);
|
||||
const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
|
||||
const TfLiteTensor* filter =
|
||||
GetInput(context, node, kFullyConnectedWeightsTensor);
|
||||
TF_LITE_ENSURE(context, filter != nullptr);
|
||||
const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
const TfLiteTensor* bias =
|
||||
GetOptionalInputTensor(context, node, kFullyConnectedBiasTensor);
|
||||
TfLiteTensor* output = GetOutput(context, node, kFullyConnectedOutputTensor);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
|
||||
TF_LITE_ENSURE_MSG(context, input->type == filter->type,
|
||||
"Hybrid models are not supported on TFLite Micro.");
|
||||
|
||||
return CalculateOpData(context, params->activation, input->type, input,
|
||||
filter, bias, output, data);
|
||||
}
|
||||
|
||||
TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node,
|
||||
const OpData& data,
|
||||
const TfLiteEvalTensor* input,
|
||||
const TfLiteEvalTensor* filter,
|
||||
const TfLiteEvalTensor* bias,
|
||||
TfLiteEvalTensor* output) {
|
||||
tflite::FullyConnectedParams op_params;
|
||||
op_params.input_offset = -data.input_zero_point;
|
||||
op_params.weights_offset = -data.filter_zero_point;
|
||||
op_params.output_offset = data.output_zero_point;
|
||||
op_params.output_multiplier = data.output_multiplier;
|
||||
// TODO(b/138810107): Figure out whether output shift should be inverted
|
||||
op_params.output_shift = -data.output_shift;
|
||||
op_params.quantized_activation_min = data.output_activation_min;
|
||||
op_params.quantized_activation_max = data.output_activation_max;
|
||||
|
||||
reference_integer_ops::FullyConnected(
|
||||
op_params, tflite::micro::GetTensorShape(input),
|
||||
tflite::micro::GetTensorData<int8_t>(input),
|
||||
tflite::micro::GetTensorShape(filter),
|
||||
tflite::micro::GetTensorData<int8_t>(filter),
|
||||
tflite::micro::GetTensorShape(bias),
|
||||
tflite::micro::GetTensorData<int32_t>(bias),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<int8_t>(output));
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
|
||||
const OpData& data, const TfLiteEvalTensor* input,
|
||||
const TfLiteEvalTensor* filter,
|
||||
const TfLiteEvalTensor* bias,
|
||||
TfLiteEvalTensor* output) {
|
||||
const int32_t input_offset = -data.input_zero_point;
|
||||
const int32_t filter_offset = -data.filter_zero_point;
|
||||
const int32_t output_offset = data.output_zero_point;
|
||||
|
||||
tflite::FullyConnectedParams op_params;
|
||||
op_params.input_offset = input_offset;
|
||||
op_params.weights_offset = filter_offset;
|
||||
op_params.output_offset = output_offset;
|
||||
op_params.output_multiplier = data.output_multiplier;
|
||||
// Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
|
||||
op_params.output_shift = -data.output_shift;
|
||||
op_params.quantized_activation_min = data.output_activation_min;
|
||||
op_params.quantized_activation_max = data.output_activation_max;
|
||||
|
||||
#define TF_LITE_FULLY_CONNECTED(output_data_type) \
|
||||
reference_ops::FullyConnected( \
|
||||
op_params, tflite::micro::GetTensorShape(input), \
|
||||
tflite::micro::GetTensorData<uint8_t>(input), \
|
||||
tflite::micro::GetTensorShape(filter), \
|
||||
tflite::micro::GetTensorData<uint8_t>(filter), \
|
||||
tflite::micro::GetTensorShape(bias), \
|
||||
tflite::micro::GetTensorData<int32_t>(bias), \
|
||||
tflite::micro::GetTensorShape(output), \
|
||||
tflite::micro::GetTensorData<output_data_type>(output))
|
||||
switch (output->type) {
|
||||
case kTfLiteUInt8:
|
||||
TF_LITE_FULLY_CONNECTED(uint8_t);
|
||||
break;
|
||||
case kTfLiteInt16:
|
||||
TF_LITE_FULLY_CONNECTED(int16_t);
|
||||
break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
|
||||
TfLiteTypeGetName(output->type), output->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
|
||||
TfLiteFusedActivation activation,
|
||||
const TfLiteEvalTensor* input,
|
||||
const TfLiteEvalTensor* filter,
|
||||
const TfLiteEvalTensor* bias, TfLiteEvalTensor* output) {
|
||||
float output_activation_min, output_activation_max;
|
||||
CalculateActivationRange(activation, &output_activation_min,
|
||||
&output_activation_max);
|
||||
tflite::FullyConnectedParams op_params;
|
||||
op_params.float_activation_min = output_activation_min;
|
||||
op_params.float_activation_max = output_activation_max;
|
||||
tflite::reference_ops::FullyConnected(
|
||||
op_params, tflite::micro::GetTensorShape(input),
|
||||
tflite::micro::GetTensorData<float>(input),
|
||||
tflite::micro::GetTensorShape(filter),
|
||||
tflite::micro::GetTensorData<float>(filter),
|
||||
tflite::micro::GetTensorShape(bias),
|
||||
tflite::micro::GetTensorData<float>(bias),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<float>(output));
|
||||
return kTfLiteOk;
|
||||
return CalculateOpDataFullyConnected(context, params->activation, input->type,
|
||||
input, filter, bias, output, data);
|
||||
}
|
||||
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
@ -206,33 +67,66 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
static_cast<const TfLiteFullyConnectedParams*>(node->builtin_data);
|
||||
|
||||
const TfLiteEvalTensor* input =
|
||||
tflite::micro::GetEvalInput(context, node, kInputTensor);
|
||||
tflite::micro::GetEvalInput(context, node, kFullyConnectedInputTensor);
|
||||
const TfLiteEvalTensor* filter =
|
||||
tflite::micro::GetEvalInput(context, node, kWeightsTensor);
|
||||
tflite::micro::GetEvalInput(context, node, kFullyConnectedWeightsTensor);
|
||||
const TfLiteEvalTensor* bias =
|
||||
tflite::micro::GetEvalInput(context, node, kBiasTensor);
|
||||
tflite::micro::GetEvalInput(context, node, kFullyConnectedBiasTensor);
|
||||
TfLiteEvalTensor* output =
|
||||
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
|
||||
tflite::micro::GetEvalOutput(context, node, kFullyConnectedOutputTensor);
|
||||
|
||||
TFLITE_DCHECK(node->user_data != nullptr);
|
||||
const OpData& data = *(static_cast<const OpData*>(node->user_data));
|
||||
const auto& data =
|
||||
*(static_cast<const OpDataFullyConnected*>(node->user_data));
|
||||
|
||||
// Checks in Prepare ensure input, output and filter types are all the same.
|
||||
switch (input->type) {
|
||||
case kTfLiteFloat32:
|
||||
return EvalFloat(context, node, params->activation, input, filter, bias,
|
||||
output);
|
||||
case kTfLiteInt8:
|
||||
return EvalQuantizedInt8(context, node, data, input, filter, bias,
|
||||
output);
|
||||
case kTfLiteFloat32: {
|
||||
tflite::reference_ops::FullyConnected(
|
||||
FullyConnectedParamsFloat(params->activation),
|
||||
tflite::micro::GetTensorShape(input),
|
||||
tflite::micro::GetTensorData<float>(input),
|
||||
tflite::micro::GetTensorShape(filter),
|
||||
tflite::micro::GetTensorData<float>(filter),
|
||||
tflite::micro::GetTensorShape(bias),
|
||||
tflite::micro::GetTensorData<float>(bias),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<float>(output));
|
||||
break;
|
||||
}
|
||||
|
||||
case kTfLiteUInt8:
|
||||
return EvalQuantized(context, node, data, input, filter, bias, output);
|
||||
case kTfLiteInt8: {
|
||||
tflite::reference_integer_ops::FullyConnected(
|
||||
FullyConnectedParamsQuantized(data),
|
||||
tflite::micro::GetTensorShape(input),
|
||||
tflite::micro::GetTensorData<int8_t>(input),
|
||||
tflite::micro::GetTensorShape(filter),
|
||||
tflite::micro::GetTensorData<int8_t>(filter),
|
||||
tflite::micro::GetTensorShape(bias),
|
||||
tflite::micro::GetTensorData<int32_t>(bias),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<int8_t>(output));
|
||||
break;
|
||||
}
|
||||
|
||||
default:
|
||||
case kTfLiteUInt8: {
|
||||
tflite::reference_ops::FullyConnected(
|
||||
FullyConnectedParamsQuantized(data),
|
||||
tflite::micro::GetTensorShape(input),
|
||||
tflite::micro::GetTensorData<uint8_t>(input),
|
||||
tflite::micro::GetTensorShape(filter),
|
||||
tflite::micro::GetTensorData<uint8_t>(filter),
|
||||
tflite::micro::GetTensorShape(bias),
|
||||
tflite::micro::GetTensorData<int32_t>(bias),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<uint8_t>(output));
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
|
||||
TfLiteTypeGetName(input->type), input->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
}
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
@ -15,10 +15,51 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_LITE_MICRO_KERNELS_FULLY_CONNECTED_H_
|
||||
#define TENSORFLOW_LITE_MICRO_KERNELS_FULLY_CONNECTED_H_
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/types.h"
|
||||
|
||||
namespace tflite {
|
||||
|
||||
struct OpDataFullyConnected {
|
||||
// The scaling factor from input to output (aka the 'real multiplier') can
|
||||
// be represented as a fixed point multiplier plus a left shift.
|
||||
int32_t output_multiplier;
|
||||
int output_shift;
|
||||
// The range of the fused activation layer. For example for kNone and
|
||||
// uint8_t these would be 0 and 255.
|
||||
int32_t output_activation_min;
|
||||
int32_t output_activation_max;
|
||||
// The index of the temporary tensor where the quantized inputs are cached.
|
||||
int input_quantized_index;
|
||||
// Cached zero point values of tensors.
|
||||
int32_t input_zero_point;
|
||||
int32_t filter_zero_point;
|
||||
int32_t output_zero_point;
|
||||
};
|
||||
|
||||
extern const int kFullyConnectedInputTensor;
|
||||
extern const int kFullyConnectedWeightsTensor;
|
||||
extern const int kFullyConnectedBiasTensor;
|
||||
extern const int kFullyConnectedOutputTensor;
|
||||
|
||||
// Returns a FullyConnectedParams struct with all the parameters needed for a
|
||||
// float computation.
|
||||
FullyConnectedParams FullyConnectedParamsFloat(
|
||||
TfLiteFusedActivation activation);
|
||||
|
||||
// Returns a FullyConnectedParams struct with all the parameters needed for a
|
||||
// quantized computation.
|
||||
FullyConnectedParams FullyConnectedParamsQuantized(
|
||||
const OpDataFullyConnected& op_data);
|
||||
|
||||
TfLiteStatus CalculateOpDataFullyConnected(
|
||||
TfLiteContext* context, TfLiteFusedActivation activation,
|
||||
TfLiteType data_type, const TfLiteTensor* input, const TfLiteTensor* filter,
|
||||
const TfLiteTensor* bias, TfLiteTensor* output, OpDataFullyConnected* data);
|
||||
|
||||
// This is the most generic TfLiteRegistration. The actual supported types may
|
||||
// still be target dependent. The only requirement is that every implementation
|
||||
// (reference or optimized) must define this function.
|
||||
|
||||
78
tensorflow/lite/micro/kernels/fully_connected_common.cc
Normal file
78
tensorflow/lite/micro/kernels/fully_connected_common.cc
Normal file
@ -0,0 +1,78 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/quantization_util.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/micro/kernels/fully_connected.h"
|
||||
#include "tensorflow/lite/micro/kernels/kernel_util.h"
|
||||
|
||||
namespace tflite {
|
||||
|
||||
const int kFullyConnectedInputTensor = 0;
|
||||
const int kFullyConnectedWeightsTensor = 1;
|
||||
const int kFullyConnectedBiasTensor = 2;
|
||||
const int kFullyConnectedOutputTensor = 0;
|
||||
|
||||
FullyConnectedParams FullyConnectedParamsQuantized(
|
||||
const OpDataFullyConnected& op_data) {
|
||||
FullyConnectedParams op_params;
|
||||
op_params.input_offset = -op_data.input_zero_point;
|
||||
op_params.weights_offset = -op_data.filter_zero_point;
|
||||
op_params.output_offset = op_data.output_zero_point;
|
||||
op_params.output_multiplier = op_data.output_multiplier;
|
||||
op_params.output_shift = op_data.output_shift;
|
||||
op_params.quantized_activation_min = op_data.output_activation_min;
|
||||
op_params.quantized_activation_max = op_data.output_activation_max;
|
||||
return op_params;
|
||||
}
|
||||
|
||||
FullyConnectedParams FullyConnectedParamsFloat(
|
||||
TfLiteFusedActivation activation) {
|
||||
FullyConnectedParams op_params;
|
||||
CalculateActivationRange(activation, &op_params.float_activation_min,
|
||||
&op_params.float_activation_max);
|
||||
return op_params;
|
||||
}
|
||||
|
||||
TfLiteStatus CalculateOpDataFullyConnected(
|
||||
TfLiteContext* context, TfLiteFusedActivation activation,
|
||||
TfLiteType data_type, const TfLiteTensor* input, const TfLiteTensor* filter,
|
||||
const TfLiteTensor* bias, TfLiteTensor* output,
|
||||
OpDataFullyConnected* data) {
|
||||
if (data_type != kTfLiteFloat32) {
|
||||
double real_multiplier = 0.0;
|
||||
TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
|
||||
context, input, filter, bias, output, &real_multiplier));
|
||||
QuantizeMultiplier(real_multiplier, &data->output_multiplier,
|
||||
&data->output_shift);
|
||||
|
||||
data->input_zero_point = input->params.zero_point;
|
||||
data->filter_zero_point = filter->params.zero_point;
|
||||
data->output_zero_point = output->params.zero_point;
|
||||
|
||||
return CalculateActivationRangeQuantized(context, activation, output,
|
||||
&data->output_activation_min,
|
||||
&data->output_activation_max);
|
||||
}
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
@ -13,12 +13,13 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
|
||||
#include "tensorflow/lite/micro/kernels/fully_connected.h"
|
||||
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/quantization_util.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
@ -29,30 +30,6 @@ limitations under the License.
|
||||
namespace tflite {
|
||||
namespace {
|
||||
|
||||
struct OpData {
|
||||
// The scaling factor from input to output (aka the 'real multiplier') can
|
||||
// be represented as a fixed point multiplier plus a left shift.
|
||||
int32_t output_multiplier;
|
||||
int output_shift;
|
||||
|
||||
// Cached tensor zero point values for quantized operations.
|
||||
int32_t input_zero_point;
|
||||
int32_t filter_zero_point;
|
||||
int32_t output_zero_point;
|
||||
|
||||
// The range of the fused activation layer. For example for kNone and
|
||||
// uint8_t these would be 0 and 255.
|
||||
int32_t output_activation_min;
|
||||
int32_t output_activation_max;
|
||||
// The index of the temporary tensor where the quantized inputs are cached.
|
||||
int input_quantized_index;
|
||||
};
|
||||
|
||||
constexpr int kInputTensor = 0;
|
||||
constexpr int kWeightsTensor = 1;
|
||||
constexpr int kBiasTensor = 2;
|
||||
constexpr int kOutputTensor = 0;
|
||||
|
||||
#if defined(HIFIMINI)
|
||||
void FullyConnected(const FullyConnectedParams& params,
|
||||
const RuntimeShape& input_shape, const int8_t* input_data,
|
||||
@ -144,7 +121,7 @@ TfLiteStatus CalculateOpData(TfLiteContext* context,
|
||||
TfLiteType data_type, const TfLiteTensor* input,
|
||||
const TfLiteTensor* filter,
|
||||
const TfLiteTensor* bias, TfLiteTensor* output,
|
||||
OpData* data) {
|
||||
OpDataFullyConnected* data) {
|
||||
double real_multiplier = 0.0;
|
||||
TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
|
||||
context, input, filter, bias, output, &real_multiplier));
|
||||
@ -155,6 +132,10 @@ TfLiteStatus CalculateOpData(TfLiteContext* context,
|
||||
QuantizeMultiplier(real_multiplier, &data->output_multiplier,
|
||||
&data->output_shift);
|
||||
#endif
|
||||
data->input_zero_point = input->params.zero_point;
|
||||
data->filter_zero_point = filter->params.zero_point;
|
||||
data->output_zero_point = output->params.zero_point;
|
||||
|
||||
return CalculateActivationRangeQuantized(context, activation, output,
|
||||
&data->output_activation_min,
|
||||
&data->output_activation_max);
|
||||
@ -162,21 +143,25 @@ TfLiteStatus CalculateOpData(TfLiteContext* context,
|
||||
|
||||
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
|
||||
return context->AllocatePersistentBuffer(context, sizeof(OpData));
|
||||
return context->AllocatePersistentBuffer(context,
|
||||
sizeof(OpDataFullyConnected));
|
||||
}
|
||||
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TFLITE_DCHECK(node->user_data != nullptr);
|
||||
TFLITE_DCHECK(node->builtin_data != nullptr);
|
||||
|
||||
OpData* data = static_cast<OpData*>(node->user_data);
|
||||
auto* data = static_cast<OpDataFullyConnected*>(node->user_data);
|
||||
const auto* params =
|
||||
reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
|
||||
|
||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||
const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
|
||||
const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
const TfLiteTensor* input =
|
||||
GetInput(context, node, kFullyConnectedInputTensor);
|
||||
const TfLiteTensor* filter =
|
||||
GetInput(context, node, kFullyConnectedWeightsTensor);
|
||||
const TfLiteTensor* bias =
|
||||
GetOptionalInputTensor(context, node, kFullyConnectedBiasTensor);
|
||||
TfLiteTensor* output = GetOutput(context, node, kFullyConnectedOutputTensor);
|
||||
|
||||
if (input->type != kTfLiteInt8) {
|
||||
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
|
||||
@ -184,36 +169,26 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
data->input_zero_point = input->params.zero_point;
|
||||
data->filter_zero_point = filter->params.zero_point;
|
||||
data->output_zero_point = output->params.zero_point;
|
||||
|
||||
return CalculateOpData(context, params->activation, input->type, input,
|
||||
filter, bias, output, data);
|
||||
}
|
||||
|
||||
TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node,
|
||||
const OpData& data,
|
||||
const OpDataFullyConnected& data,
|
||||
const TfLiteEvalTensor* input,
|
||||
const TfLiteEvalTensor* filter,
|
||||
const TfLiteEvalTensor* bias,
|
||||
TfLiteEvalTensor* output) {
|
||||
// TODO(b/154032858): Investigate removing extra copies, and also passing by
|
||||
// value. TODO(b/155656675): Consider passing OpData by value once it is also
|
||||
// passed to the FullyConnected function. Until it is copied to a local
|
||||
// op_param variable, we do not get any latency improvements from passing by
|
||||
// value.
|
||||
FullyConnectedParams op_params;
|
||||
op_params.input_offset = -data.input_zero_point;
|
||||
op_params.weights_offset = -data.filter_zero_point;
|
||||
op_params.output_offset = data.output_zero_point;
|
||||
op_params.output_multiplier = data.output_multiplier;
|
||||
op_params.output_shift = data.output_shift;
|
||||
op_params.quantized_activation_min = data.output_activation_min;
|
||||
op_params.quantized_activation_max = data.output_activation_max;
|
||||
|
||||
// TODO(b/154032858): Investigate removing extra copies (i.e.
|
||||
// data.ToQuantizedParams), and also passing by value.
|
||||
//
|
||||
// TODO(b/155656675): Consider passing OpDataFullyConnected by value
|
||||
// once it is also passed to the FullyConnected function. Until it is copied
|
||||
// to a local op_param variable, we do not get any latency improvements from
|
||||
// passing by value.
|
||||
#if defined(HIFIMINI)
|
||||
FullyConnected(op_params, tflite::micro::GetTensorShape(input),
|
||||
FullyConnected(FullyConnectedParamsQuantized(data),
|
||||
tflite::micro::GetTensorShape(input),
|
||||
tflite::micro::GetTensorData<int8_t>(input),
|
||||
tflite::micro::GetTensorShape(filter),
|
||||
tflite::micro::GetTensorData<int8_t>(filter),
|
||||
@ -223,7 +198,7 @@ TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node,
|
||||
tflite::micro::GetTensorData<int8_t>(output));
|
||||
#else
|
||||
reference_integer_ops::FullyConnected(
|
||||
op_params, tflite::micro::GetTensorShape(input),
|
||||
FullyConnectedParamsQuantized(data), tflite::micro::GetTensorShape(input),
|
||||
tflite::micro::GetTensorData<int8_t>(input),
|
||||
tflite::micro::GetTensorShape(filter),
|
||||
tflite::micro::GetTensorData<int8_t>(filter),
|
||||
@ -238,19 +213,20 @@ TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node,
|
||||
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
TFLITE_DCHECK(node->user_data != nullptr);
|
||||
const OpData& data = *(static_cast<const OpData*>(node->user_data));
|
||||
const auto& data =
|
||||
*(static_cast<const OpDataFullyConnected*>(node->user_data));
|
||||
|
||||
const TfLiteEvalTensor* input =
|
||||
tflite::micro::GetEvalInput(context, node, kInputTensor);
|
||||
tflite::micro::GetEvalInput(context, node, kFullyConnectedInputTensor);
|
||||
const TfLiteEvalTensor* filter =
|
||||
tflite::micro::GetEvalInput(context, node, kWeightsTensor);
|
||||
tflite::micro::GetEvalInput(context, node, kFullyConnectedWeightsTensor);
|
||||
const TfLiteEvalTensor* bias =
|
||||
(NumInputs(node) == 3)
|
||||
? tflite::micro::GetEvalInput(context, node, kBiasTensor)
|
||||
: nullptr;
|
||||
(NumInputs(node) == 3) ? tflite::micro::GetEvalInput(
|
||||
context, node, kFullyConnectedBiasTensor)
|
||||
: nullptr;
|
||||
|
||||
TfLiteEvalTensor* output =
|
||||
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
|
||||
tflite::micro::GetEvalOutput(context, node, kFullyConnectedOutputTensor);
|
||||
|
||||
return EvalQuantizedInt8(context, node, data, input, filter, bias, output);
|
||||
}
|
||||
|
||||
@ -329,6 +329,7 @@ tensorflow/lite/micro/kernels/ethosu.cc \
|
||||
tensorflow/lite/micro/kernels/flexbuffers_generated_data.cc \
|
||||
tensorflow/lite/micro/kernels/floor.cc \
|
||||
tensorflow/lite/micro/kernels/fully_connected.cc \
|
||||
tensorflow/lite/micro/kernels/fully_connected_common.cc \
|
||||
tensorflow/lite/micro/kernels/hard_swish.cc \
|
||||
tensorflow/lite/micro/kernels/kernel_runner.cc \
|
||||
tensorflow/lite/micro/kernels/kernel_util.cc \
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user