Small Softmax cleanups:
- Remove OpData. Use SoftmaxParams directly.
- Only call CalculateSoftmaxOpData for quantized case, rename to CalculateSoftmaxParams.
- Add stricter type checks to CalculateSoftmaxParams.
- Use static_cast instead of reinterpret_cast.

PiperOrigin-RevId: 303175991
Change-Id: I5a1e746d53ff7758c5e31535cce2961e71ce8fb4
Commit 3dc8f81602 (parent e8590130e3)
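For orientation, the reshaped call path looks roughly like the condensed sketch below. It is pieced together from the hunks that follow and reuses their names, but it is not a complete kernel: the per-type ENSURE checks and the default-case handling are trimmed.

TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
  // builtin_data is a void* that holds a TfLiteSoftmaxParams for this op, so a
  // static_cast is sufficient; no reinterpret_cast needed.
  auto* params = static_cast<TfLiteSoftmaxParams*>(node->builtin_data);
  const TfLiteTensor* input = GetInput(context, node, 0);
  TfLiteTensor* output = GetOutput(context, node, 0);

  // SoftmaxParams is filled in directly; the intermediate OpData struct is gone.
  SoftmaxParams op_data;
  TF_LITE_ENSURE_STATUS(
      CalculateSoftmaxParams(context, input, output, params, &op_data));

  switch (input->type) {
    case kTfLiteFloat32:
      SoftmaxFloat(input, output, op_data);  // uses op_data.beta
      return kTfLiteOk;
    case kTfLiteInt8:
    case kTfLiteUInt8:
      SoftmaxQuantized(input, output, op_data);  // uses multiplier/shift/diff_min
      return kTfLiteOk;
    default:
      return kTfLiteError;  // the hunks below cut off before the real default-case handling
  }
}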
@@ -25,35 +25,37 @@ namespace micro {
 namespace activations {
 namespace {
 
-struct OpData {
-  int32_t input_multiplier = 0;
-  int input_left_shift = 0;
-  int32_t input_range_radius = 0;
-  int diff_min = 0;
-};
-
-TfLiteStatus CalculateSoftmaxOpData(TfLiteContext* context,
+TfLiteStatus CalculateSoftmaxParams(TfLiteContext* context,
                                     const TfLiteTensor* input,
                                     TfLiteTensor* output,
                                     const TfLiteSoftmaxParams* params,
-                                    OpData* data) {
+                                    SoftmaxParams* op_data) {
   if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
     if (input->type == kTfLiteUInt8) {
+      TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteUInt8);
       TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
     } else {
+      TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteInt8);
+      TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteInt8);
       TF_LITE_ENSURE_EQ(context, output->params.zero_point, -128);
     }
 
     TF_LITE_ENSURE(context, (output->params.scale == 1.f / 256) ||
                                 (output->params.scale == 1.f / 255));
 
     static const int kScaledDiffIntegerBits = 5;
 
+    int input_left_shift;
     tflite::PreprocessSoftmaxScaling(
         params->beta, input->params.scale, kScaledDiffIntegerBits,
-        &data->input_multiplier, &data->input_left_shift);
-    data->diff_min = -1.0 * tflite::CalculateInputRadius(
-                                kScaledDiffIntegerBits, data->input_left_shift);
+        &op_data->input_multiplier, &input_left_shift);
+    op_data->input_left_shift = input_left_shift;
+    op_data->diff_min =
+        -1.0 * tflite::CalculateInputRadius(kScaledDiffIntegerBits,
+                                            op_data->input_left_shift);
+  } else {
+    TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
+    TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32);
+    op_data->beta = static_cast<double>(params->beta);
   }
   return kTfLiteOk;
 }
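As background for the new float branch above, where op_data->beta is the only field the float path needs: the float kernel ultimately computes the standard beta-scaled softmax. Below is a minimal standalone C++ sketch of that formula, for illustration only; the kernel itself delegates to tflite::reference_ops::Softmax.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// softmax(x)_i = exp(beta * x_i) / sum_j exp(beta * x_j), computed with the
// usual max-subtraction trick for numerical stability. Assumes x is non-empty.
std::vector<float> NaiveSoftmax(const std::vector<float>& x, float beta) {
  const float max_val = *std::max_element(x.begin(), x.end());
  std::vector<float> out(x.size());
  float sum = 0.f;
  for (std::size_t i = 0; i < x.size(); ++i) {
    out[i] = std::exp(beta * (x[i] - max_val));
    sum += out[i];
  }
  for (float& v : out) v /= sum;
  return out;
}

The quantized branch instead precomputes input_multiplier, input_left_shift, and diff_min from beta and the input scale, so the same beta * (x_i - max) term can be evaluated in fixed point at run time.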
@@ -75,26 +77,19 @@ TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
   return kTfLiteOk;
 }
 
-// Takes a 4D tensor and perform softmax along the forth dimension.
+// Takes a tensor and performs softmax along the last dimension.
 void SoftmaxFloat(const TfLiteTensor* input, TfLiteTensor* output,
-                  TfLiteSoftmaxParams* params) {
-  SoftmaxParams op_params;
-  op_params.beta = params->beta;
+                  const SoftmaxParams& op_data) {
   tflite::reference_ops::Softmax(
-      op_params, GetTensorShape(input), GetTensorData<float>(input),
+      op_data, GetTensorShape(input), GetTensorData<float>(input),
       GetTensorShape(output), GetTensorData<float>(output));
 }
 
 void SoftmaxQuantized(const TfLiteTensor* input, TfLiteTensor* output,
-                      TfLiteSoftmaxParams* params, OpData* data) {
-  SoftmaxParams op_params;
-  op_params.input_multiplier = data->input_multiplier;
-  op_params.input_left_shift = data->input_left_shift;
-  op_params.diff_min = data->diff_min;
-
+                      const SoftmaxParams& op_data) {
   if (input->type == kTfLiteUInt8) {
     tflite::reference_ops::Softmax(
-        op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
+        op_data, GetTensorShape(input), GetTensorData<uint8_t>(input),
         GetTensorShape(output), GetTensorData<uint8_t>(output));
   } else {
     const unsigned int num_dims = NumDimensions(input);
@@ -106,30 +101,29 @@ void SoftmaxQuantized(const TfLiteTensor* input, TfLiteTensor* output,
         MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
 
     arm_softmax_s8(GetTensorData<int8_t>(input), outer_size, depth,
-                   op_params.input_multiplier, op_params.input_left_shift,
-                   op_params.diff_min, GetTensorData<int8_t>(output));
+                   op_data.input_multiplier, op_data.input_left_shift,
+                   op_data.diff_min, GetTensorData<int8_t>(output));
   }
 }
 
 TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
-  auto* params = reinterpret_cast<TfLiteSoftmaxParams*>(node->builtin_data);
+  auto* params = static_cast<TfLiteSoftmaxParams*>(node->builtin_data);
 
   const TfLiteTensor* input = GetInput(context, node, 0);
   TfLiteTensor* output = GetOutput(context, node, 0);
 
-  OpData local_data_object;
-  OpData* data = &local_data_object;
+  SoftmaxParams op_data;
   TF_LITE_ENSURE_STATUS(
-      CalculateSoftmaxOpData(context, input, output, params, data));
+      CalculateSoftmaxParams(context, input, output, params, &op_data));
 
   switch (input->type) {
     case kTfLiteFloat32: {
-      SoftmaxFloat(input, output, params);
+      SoftmaxFloat(input, output, op_data);
       return kTfLiteOk;
     }
-    case kTfLiteUInt8:
-    case kTfLiteInt8: {
-      SoftmaxQuantized(input, output, params, data);
+    case kTfLiteInt8:
+    case kTfLiteUInt8: {
+      SoftmaxQuantized(input, output, params, op_data);
       return kTfLiteOk;
     }
     default:
@@ -29,27 +29,23 @@ namespace micro {
 namespace activations {
 namespace {
 
-struct OpData {
-  int32_t input_multiplier = 0;
-  int input_left_shift = 0;
-  int32_t input_range_radius = 0;
-  int diff_min = 0;
-};
-
-TfLiteStatus CalculateSoftmaxOpData(TfLiteContext* context,
+TfLiteStatus CalculateSoftmaxParams(TfLiteContext* context,
                                     const TfLiteTensor* input,
                                     TfLiteTensor* output,
                                     const TfLiteSoftmaxParams* params,
-                                    OpData* data) {
+                                    SoftmaxParams* op_data) {
   if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
     if (input->type == kTfLiteUInt8) {
+      TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteUInt8);
       TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
     } else {
+      TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteInt8);
       if (output->type == kTfLiteInt16) {
         TF_LITE_ENSURE_EQ(context, output->params.zero_point, -32768);
         // NOTE: Current int16 softmax output does not require symmetric scaling
         // - so no need to verify scale here.
       } else {
+        TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteInt8);
         TF_LITE_ENSURE_EQ(context, output->params.zero_point, -128);
         TF_LITE_ENSURE(context, output->params.scale == 1.f / 256);
       }
@@ -57,12 +53,19 @@ TfLiteStatus CalculateSoftmaxOpData(TfLiteContext* context,
 
     static const int kScaledDiffIntegerBits = 5;
 
+    int input_left_shift;
     tflite::PreprocessSoftmaxScaling(
         static_cast<double>(params->beta),
         static_cast<double>(input->params.scale), kScaledDiffIntegerBits,
-        &data->input_multiplier, &data->input_left_shift);
-    data->diff_min = -1.0 * tflite::CalculateInputRadius(
-                                kScaledDiffIntegerBits, data->input_left_shift);
+        &op_data->input_multiplier, &input_left_shift);
+    op_data->input_left_shift = input_left_shift;
+    op_data->diff_min =
+        -1.0 * tflite::CalculateInputRadius(kScaledDiffIntegerBits,
+                                            op_data->input_left_shift);
+  } else {
+    TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
+    TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32);
+    op_data->beta = static_cast<double>(params->beta);
   }
   return kTfLiteOk;
 }
@@ -86,56 +89,49 @@ TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
 
 // Takes a tensor and performs softmax along the last dimension.
 void SoftmaxFloat(const TfLiteTensor* input, TfLiteTensor* output,
-                  TfLiteSoftmaxParams* params) {
-  SoftmaxParams op_params;
-  op_params.beta = static_cast<double>(params->beta);
+                  const SoftmaxParams& op_data) {
   tflite::reference_ops::Softmax(
-      op_params, GetTensorShape(input), GetTensorData<float>(input),
+      op_data, GetTensorShape(input), GetTensorData<float>(input),
       GetTensorShape(output), GetTensorData<float>(output));
 }
 
 void SoftmaxQuantized(const TfLiteTensor* input, TfLiteTensor* output,
-                      TfLiteSoftmaxParams* params, OpData* data) {
-  SoftmaxParams op_params;
-  op_params.input_multiplier = data->input_multiplier;
-  op_params.input_left_shift = data->input_left_shift;
-  op_params.diff_min = data->diff_min;
+                      const SoftmaxParams& op_data) {
   if (input->type == kTfLiteUInt8) {
     tflite::reference_ops::Softmax(
-        op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
+        op_data, GetTensorShape(input), GetTensorData<uint8_t>(input),
         GetTensorShape(output), GetTensorData<uint8_t>(output));
   } else {
     if (output->type == kTfLiteInt16) {
       tflite::reference_ops::Softmax(
-          op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
+          op_data, GetTensorShape(input), GetTensorData<int8_t>(input),
           GetTensorShape(output), GetTensorData<int16_t>(output));
     } else {
       tflite::reference_ops::Softmax(
-          op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
+          op_data, GetTensorShape(input), GetTensorData<int8_t>(input),
           GetTensorShape(output), GetTensorData<int8_t>(output));
     }
   }
 }
 
 TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
-  auto* params = reinterpret_cast<TfLiteSoftmaxParams*>(node->builtin_data);
+  auto* params = static_cast<TfLiteSoftmaxParams*>(node->builtin_data);
 
   const TfLiteTensor* input = GetInput(context, node, 0);
   TfLiteTensor* output = GetOutput(context, node, 0);
 
-  OpData local_data_object;
-  OpData* data = &local_data_object;
+  SoftmaxParams op_data;
   TF_LITE_ENSURE_STATUS(
-      CalculateSoftmaxOpData(context, input, output, params, data));
+      CalculateSoftmaxParams(context, input, output, params, &op_data));
 
   switch (input->type) {
     case kTfLiteFloat32: {
-      SoftmaxFloat(input, output, params);
+      SoftmaxFloat(input, output, op_data);
       return kTfLiteOk;
     }
     case kTfLiteInt8:
     case kTfLiteUInt8: {
-      SoftmaxQuantized(input, output, params, data);
+      SoftmaxQuantized(input, output, op_data);
       return kTfLiteOk;
     }
     default:
@@ -149,11 +145,14 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace activations
 
 TfLiteRegistration* Register_SOFTMAX() {
-  static TfLiteRegistration r = {};
-  r.init = activations::Init;
-  r.free = activations::Free;
-  r.prepare = activations::SoftmaxPrepare;
-  r.invoke = activations::SoftmaxEval;
+  static TfLiteRegistration r = {activations::Init,
+                                 activations::Free,
+                                 activations::SoftmaxPrepare,
+                                 activations::SoftmaxEval,
+                                 nullptr,
+                                 0,
+                                 nullptr,
+                                 0};
   return &r;
 }
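For readers puzzling over the trailing nullptr, 0, nullptr, 0 in the brace initializer above: my reading of TfLiteRegistration from this era of the C API is that the positional initializers line up with its fields as annotated below. The field names are an assumption taken from tensorflow/lite/c/common.h, not from this diff, so treat the annotation as a sketch rather than the authoritative layout.

static TfLiteRegistration r = {activations::Init,            // init
                               activations::Free,            // free
                               activations::SoftmaxPrepare,  // prepare
                               activations::SoftmaxEval,     // invoke
                               nullptr,                      // profiling_string
                               0,                            // builtin_code
                               nullptr,                      // custom_name
                               0};                           // version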
@@ -117,21 +117,14 @@ inline void Softmax(const SoftmaxParams& params,
 namespace activations {
 namespace {
 
-struct OpData {
-  int32_t input_multiplier = 0;
-  int input_left_shift = 0;
-  int32_t input_range_radius = 0;
-  int diff_min = 0;
-};
-
 // This size will work for both the hotword (1) and ambient music (0):
-static OpData kStaticOpData;
+static SoftmaxParams kStaticOpData;
 
 TfLiteStatus CalculateSoftmaxOpData(TfLiteContext* context,
                                     const TfLiteTensor* input,
                                     TfLiteTensor* output,
                                     const TfLiteSoftmaxParams* params,
-                                    OpData* data) {
+                                    SoftmaxParams* op_data) {
   if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
     if (input->type == kTfLiteUInt8) {
       TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
@@ -148,12 +141,14 @@ TfLiteStatus CalculateSoftmaxOpData(TfLiteContext* context,
 
     static const int kScaledDiffIntegerBits = 5;
 
+    int input_left_shift;
     tflite::PreprocessSoftmaxScaling(
-        static_cast<double>(params->beta),
-        static_cast<double>(input->params.scale), kScaledDiffIntegerBits,
-        &data->input_multiplier, &data->input_left_shift);
-    data->diff_min = -1.0 * tflite::CalculateInputRadius(
-                                kScaledDiffIntegerBits, data->input_left_shift);
+        params->beta, input->params.scale, kScaledDiffIntegerBits,
+        &op_data->input_multiplier, &input_left_shift);
+    op_data->input_left_shift = input_left_shift;
+    op_data->diff_min =
+        -1.0 * tflite::CalculateInputRadius(kScaledDiffIntegerBits,
+                                            op_data->input_left_shift);
   }
   return kTfLiteOk;
 }
@@ -161,12 +156,7 @@ TfLiteStatus CalculateSoftmaxOpData(TfLiteContext* context,
 }  // namespace
 
 void SoftmaxQuantized(const TfLiteTensor* input, TfLiteTensor* output,
-                      TfLiteSoftmaxParams* params, OpData* data) {
-  SoftmaxParams op_params;
-  op_params.input_multiplier = data->input_multiplier;
-  op_params.input_left_shift = data->input_left_shift;
-  op_params.diff_min = data->diff_min;
-
+                      const SoftmaxParams& op_params) {
   if (output->type == kTfLiteInt16) {
     xtensa::hifimini::Softmax(
         op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
@@ -186,7 +176,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
 void Free(TfLiteContext* context, void* buffer) {}
 
 TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
-  auto* params = reinterpret_cast<TfLiteSoftmaxParams*>(node->builtin_data);
+  auto* params = static_cast<TfLiteSoftmaxParams*>(node->builtin_data);
 
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
@@ -194,27 +184,26 @@ TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
   TfLiteTensor* output = GetOutput(context, node, 0);
   TF_LITE_ENSURE(context, NumDimensions(input) >= 1);
 
-  // TODO(b/132070898): Use statically slotted OpData structures until a
+  // TODO(b/132070898): Use statically slotted SoftmaxParams structures until a
   // scratch memory API is ready.
-  OpData* op_data = &kStaticOpData;
-  node->user_data = op_data;
+  SoftmaxParams* op_params = &kStaticOpData;
+  node->user_data = op_params;
 
   TF_LITE_ENSURE_STATUS(
-      CalculateSoftmaxOpData(context, input, output, params, op_data));
+      CalculateSoftmaxOpData(context, input, output, params, op_params));
 
   return kTfLiteOk;
 }
 
 TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
-  auto* params = reinterpret_cast<TfLiteSoftmaxParams*>(node->builtin_data);
-  auto* op_data = reinterpret_cast<OpData*>(node->user_data);
+  auto* op_params = static_cast<SoftmaxParams*>(node->user_data);
 
   const TfLiteTensor* input = GetInput(context, node, 0);
   TfLiteTensor* output = GetOutput(context, node, 0);
 
   switch (input->type) {
     case kTfLiteInt8: {
-      SoftmaxQuantized(input, output, params, op_data);
+      SoftmaxQuantized(input, output, *op_params);
       return kTfLiteOk;
     }
     default:
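Condensing the last two hunks, the pattern this port uses to carry per-node data from Prepare to Eval is a statically allocated SoftmaxParams handed around through node->user_data; the TODO in the diff flags this as a stopgap until a scratch memory API exists. A trimmed sketch, with the tensor lookups and checks elided:

static SoftmaxParams kStaticOpData;  // static slot, per the comment in the diff

TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
  auto* params = static_cast<TfLiteSoftmaxParams*>(node->builtin_data);
  SoftmaxParams* op_params = &kStaticOpData;
  node->user_data = op_params;  // stashed here so Eval can find it
  // ... GetInput/GetOutput, then
  // CalculateSoftmaxOpData(context, input, output, params, op_params) ...
  return kTfLiteOk;
}

TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
  auto* op_params = static_cast<SoftmaxParams*>(node->user_data);
  // ... SoftmaxQuantized(input, output, *op_params) for the int8 case ...
  return kTfLiteOk;
}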