Small Softmax cleanups:

- Remove OpData; use SoftmaxParams directly (sketched below).
- Only call CalculateSoftmaxOpData for the quantized case; rename it to CalculateSoftmaxParams.
- Add stricter type checks to CalculateSoftmaxParams.
- Use static_cast instead of reinterpret_cast.
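
For reference, a condensed sketch of the pattern this change adopts, pieced together from the per-file diffs below. It is simplified, not the exact upstream code, and assumes the usual TFLite Micro headers; SoftmaxParams is the existing struct from tensorflow/lite/kernels/internal/types.h, so no kernel-local OpData struct is needed:

TfLiteStatus CalculateSoftmaxParams(TfLiteContext* context,
                                    const TfLiteTensor* input,
                                    TfLiteTensor* output,
                                    const TfLiteSoftmaxParams* params,
                                    SoftmaxParams* op_data) {
  if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
    // Quantized path: derive the fixed-point scaling once, up front.
    static const int kScaledDiffIntegerBits = 5;
    int input_left_shift;
    tflite::PreprocessSoftmaxScaling(
        static_cast<double>(params->beta),
        static_cast<double>(input->params.scale), kScaledDiffIntegerBits,
        &op_data->input_multiplier, &input_left_shift);
    // SoftmaxParams stores the shift as int32_t, hence the local int.
    op_data->input_left_shift = input_left_shift;
    op_data->diff_min = -1.0 * tflite::CalculateInputRadius(
                                   kScaledDiffIntegerBits, input_left_shift);
  } else {
    // Float path: only beta is needed.
    TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
    op_data->beta = static_cast<double>(params->beta);
  }
  return kTfLiteOk;
}

SoftmaxEval then fills a local SoftmaxParams and passes it straight to the float/quantized helpers, so the per-call copy into a throwaway SoftmaxParams inside each helper goes away.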

PiperOrigin-RevId: 303175991
Change-Id: I5a1e746d53ff7758c5e31535cce2961e71ce8fb4
Robert David 2020-03-26 12:55:10 -07:00 committed by TensorFlower Gardener
parent e8590130e3
commit 3dc8f81602
3 changed files with 80 additions and 98 deletions

File 1 of 3

@@ -25,35 +25,37 @@ namespace micro {
 namespace activations {
 namespace {
 
-struct OpData {
-  int32_t input_multiplier = 0;
-  int input_left_shift = 0;
-  int32_t input_range_radius = 0;
-  int diff_min = 0;
-};
-
-TfLiteStatus CalculateSoftmaxOpData(TfLiteContext* context,
+TfLiteStatus CalculateSoftmaxParams(TfLiteContext* context,
                                     const TfLiteTensor* input,
                                     TfLiteTensor* output,
                                     const TfLiteSoftmaxParams* params,
-                                    OpData* data) {
+                                    SoftmaxParams* op_data) {
   if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
     if (input->type == kTfLiteUInt8) {
+      TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteUInt8);
       TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
     } else {
+      TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteInt8);
+      TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteInt8);
       TF_LITE_ENSURE_EQ(context, output->params.zero_point, -128);
     }
     TF_LITE_ENSURE(context, (output->params.scale == 1.f / 256) ||
                                 (output->params.scale == 1.f / 255));
 
     static const int kScaledDiffIntegerBits = 5;
 
+    int input_left_shift;
     tflite::PreprocessSoftmaxScaling(
         params->beta, input->params.scale, kScaledDiffIntegerBits,
-        &data->input_multiplier, &data->input_left_shift);
-    data->diff_min = -1.0 * tflite::CalculateInputRadius(
-                                kScaledDiffIntegerBits, data->input_left_shift);
+        &op_data->input_multiplier, &input_left_shift);
+    op_data->input_left_shift = input_left_shift;
+    op_data->diff_min =
+        -1.0 * tflite::CalculateInputRadius(kScaledDiffIntegerBits,
+                                            op_data->input_left_shift);
+  } else {
+    TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
+    TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32);
+    op_data->beta = static_cast<double>(params->beta);
   }
   return kTfLiteOk;
 }
@@ -75,26 +77,19 @@ TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
   return kTfLiteOk;
 }
 
-// Takes a 4D tensor and perform softmax along the forth dimension.
+// Takes a tensor and performs softmax along the last dimension.
 void SoftmaxFloat(const TfLiteTensor* input, TfLiteTensor* output,
-                  TfLiteSoftmaxParams* params) {
-  SoftmaxParams op_params;
-  op_params.beta = params->beta;
+                  const SoftmaxParams& op_data) {
   tflite::reference_ops::Softmax(
-      op_params, GetTensorShape(input), GetTensorData<float>(input),
+      op_data, GetTensorShape(input), GetTensorData<float>(input),
       GetTensorShape(output), GetTensorData<float>(output));
 }
 
 void SoftmaxQuantized(const TfLiteTensor* input, TfLiteTensor* output,
-                      TfLiteSoftmaxParams* params, OpData* data) {
-  SoftmaxParams op_params;
-  op_params.input_multiplier = data->input_multiplier;
-  op_params.input_left_shift = data->input_left_shift;
-  op_params.diff_min = data->diff_min;
+                      const SoftmaxParams& op_data) {
   if (input->type == kTfLiteUInt8) {
     tflite::reference_ops::Softmax(
-        op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
+        op_data, GetTensorShape(input), GetTensorData<uint8_t>(input),
         GetTensorShape(output), GetTensorData<uint8_t>(output));
   } else {
     const unsigned int num_dims = NumDimensions(input);
@@ -106,30 +101,29 @@ void SoftmaxQuantized(const TfLiteTensor* input, TfLiteTensor* output,
         MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
     arm_softmax_s8(GetTensorData<int8_t>(input), outer_size, depth,
-                   op_params.input_multiplier, op_params.input_left_shift,
-                   op_params.diff_min, GetTensorData<int8_t>(output));
+                   op_data.input_multiplier, op_data.input_left_shift,
+                   op_data.diff_min, GetTensorData<int8_t>(output));
   }
 }
 
 TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
-  auto* params = reinterpret_cast<TfLiteSoftmaxParams*>(node->builtin_data);
+  auto* params = static_cast<TfLiteSoftmaxParams*>(node->builtin_data);
   const TfLiteTensor* input = GetInput(context, node, 0);
   TfLiteTensor* output = GetOutput(context, node, 0);
 
-  OpData local_data_object;
-  OpData* data = &local_data_object;
+  SoftmaxParams op_data;
   TF_LITE_ENSURE_STATUS(
-      CalculateSoftmaxOpData(context, input, output, params, data));
+      CalculateSoftmaxParams(context, input, output, params, &op_data));
 
   switch (input->type) {
     case kTfLiteFloat32: {
-      SoftmaxFloat(input, output, params);
+      SoftmaxFloat(input, output, op_data);
       return kTfLiteOk;
     }
-    case kTfLiteUInt8:
-    case kTfLiteInt8: {
-      SoftmaxQuantized(input, output, params, data);
+    case kTfLiteInt8:
+    case kTfLiteUInt8: {
+      SoftmaxQuantized(input, output, op_data);
       return kTfLiteOk;
     }
     default:

File 2 of 3

@@ -29,27 +29,23 @@ namespace micro {
 namespace activations {
 namespace {
 
-struct OpData {
-  int32_t input_multiplier = 0;
-  int input_left_shift = 0;
-  int32_t input_range_radius = 0;
-  int diff_min = 0;
-};
-
-TfLiteStatus CalculateSoftmaxOpData(TfLiteContext* context,
+TfLiteStatus CalculateSoftmaxParams(TfLiteContext* context,
                                     const TfLiteTensor* input,
                                     TfLiteTensor* output,
                                     const TfLiteSoftmaxParams* params,
-                                    OpData* data) {
+                                    SoftmaxParams* op_data) {
   if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
     if (input->type == kTfLiteUInt8) {
+      TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteUInt8);
       TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
     } else {
+      TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteInt8);
       if (output->type == kTfLiteInt16) {
         TF_LITE_ENSURE_EQ(context, output->params.zero_point, -32768);
         // NOTE: Current int16 softmax output does not require symmetric scaling
         // - so no need to verify scale here.
       } else {
+        TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteInt8);
         TF_LITE_ENSURE_EQ(context, output->params.zero_point, -128);
         TF_LITE_ENSURE(context, output->params.scale == 1.f / 256);
       }
@@ -57,12 +53,19 @@ TfLiteStatus CalculateSoftmaxOpData(TfLiteContext* context,
     static const int kScaledDiffIntegerBits = 5;
 
+    int input_left_shift;
     tflite::PreprocessSoftmaxScaling(
         static_cast<double>(params->beta),
         static_cast<double>(input->params.scale), kScaledDiffIntegerBits,
-        &data->input_multiplier, &data->input_left_shift);
-    data->diff_min = -1.0 * tflite::CalculateInputRadius(
-                                kScaledDiffIntegerBits, data->input_left_shift);
+        &op_data->input_multiplier, &input_left_shift);
+    op_data->input_left_shift = input_left_shift;
+    op_data->diff_min =
+        -1.0 * tflite::CalculateInputRadius(kScaledDiffIntegerBits,
+                                            op_data->input_left_shift);
+  } else {
+    TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
+    TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32);
+    op_data->beta = static_cast<double>(params->beta);
   }
   return kTfLiteOk;
 }
@@ -86,56 +89,49 @@ TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
 // Takes a tensor and performs softmax along the last dimension.
 void SoftmaxFloat(const TfLiteTensor* input, TfLiteTensor* output,
-                  TfLiteSoftmaxParams* params) {
-  SoftmaxParams op_params;
-  op_params.beta = static_cast<double>(params->beta);
+                  const SoftmaxParams& op_data) {
   tflite::reference_ops::Softmax(
-      op_params, GetTensorShape(input), GetTensorData<float>(input),
+      op_data, GetTensorShape(input), GetTensorData<float>(input),
       GetTensorShape(output), GetTensorData<float>(output));
 }
 
 void SoftmaxQuantized(const TfLiteTensor* input, TfLiteTensor* output,
-                      TfLiteSoftmaxParams* params, OpData* data) {
-  SoftmaxParams op_params;
-  op_params.input_multiplier = data->input_multiplier;
-  op_params.input_left_shift = data->input_left_shift;
-  op_params.diff_min = data->diff_min;
+                      const SoftmaxParams& op_data) {
   if (input->type == kTfLiteUInt8) {
     tflite::reference_ops::Softmax(
-        op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
+        op_data, GetTensorShape(input), GetTensorData<uint8_t>(input),
         GetTensorShape(output), GetTensorData<uint8_t>(output));
   } else {
     if (output->type == kTfLiteInt16) {
       tflite::reference_ops::Softmax(
-          op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
+          op_data, GetTensorShape(input), GetTensorData<int8_t>(input),
           GetTensorShape(output), GetTensorData<int16_t>(output));
     } else {
       tflite::reference_ops::Softmax(
-          op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
+          op_data, GetTensorShape(input), GetTensorData<int8_t>(input),
           GetTensorShape(output), GetTensorData<int8_t>(output));
     }
   }
 }
 
 TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
-  auto* params = reinterpret_cast<TfLiteSoftmaxParams*>(node->builtin_data);
+  auto* params = static_cast<TfLiteSoftmaxParams*>(node->builtin_data);
   const TfLiteTensor* input = GetInput(context, node, 0);
   TfLiteTensor* output = GetOutput(context, node, 0);
 
-  OpData local_data_object;
-  OpData* data = &local_data_object;
+  SoftmaxParams op_data;
   TF_LITE_ENSURE_STATUS(
-      CalculateSoftmaxOpData(context, input, output, params, data));
+      CalculateSoftmaxParams(context, input, output, params, &op_data));
 
   switch (input->type) {
     case kTfLiteFloat32: {
-      SoftmaxFloat(input, output, params);
+      SoftmaxFloat(input, output, op_data);
       return kTfLiteOk;
     }
     case kTfLiteInt8:
     case kTfLiteUInt8: {
-      SoftmaxQuantized(input, output, params, data);
+      SoftmaxQuantized(input, output, op_data);
       return kTfLiteOk;
     }
     default:
@@ -149,11 +145,14 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
 } // namespace activations
 
 TfLiteRegistration* Register_SOFTMAX() {
-  static TfLiteRegistration r = {};
-  r.init = activations::Init;
-  r.free = activations::Free;
-  r.prepare = activations::SoftmaxPrepare;
-  r.invoke = activations::SoftmaxEval;
+  static TfLiteRegistration r = {activations::Init,
+                                 activations::Free,
+                                 activations::SoftmaxPrepare,
+                                 activations::SoftmaxEval,
+                                 nullptr,
+                                 0,
+                                 nullptr,
+                                 0};
   return &r;
 }

File 3 of 3

@@ -117,21 +117,14 @@ inline void Softmax(const SoftmaxParams& params,
 namespace activations {
 namespace {
 
-struct OpData {
-  int32_t input_multiplier = 0;
-  int input_left_shift = 0;
-  int32_t input_range_radius = 0;
-  int diff_min = 0;
-};
-
 // This size will work for both the hotword (1) and ambient music (0):
-static OpData kStaticOpData;
+static SoftmaxParams kStaticOpData;
 
 TfLiteStatus CalculateSoftmaxOpData(TfLiteContext* context,
                                     const TfLiteTensor* input,
                                     TfLiteTensor* output,
                                     const TfLiteSoftmaxParams* params,
-                                    OpData* data) {
+                                    SoftmaxParams* op_data) {
   if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
     if (input->type == kTfLiteUInt8) {
       TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
@@ -148,12 +141,14 @@ TfLiteStatus CalculateSoftmaxOpData(TfLiteContext* context,
     static const int kScaledDiffIntegerBits = 5;
 
+    int input_left_shift;
     tflite::PreprocessSoftmaxScaling(
-        static_cast<double>(params->beta),
-        static_cast<double>(input->params.scale), kScaledDiffIntegerBits,
-        &data->input_multiplier, &data->input_left_shift);
-    data->diff_min = -1.0 * tflite::CalculateInputRadius(
-                                kScaledDiffIntegerBits, data->input_left_shift);
+        params->beta, input->params.scale, kScaledDiffIntegerBits,
+        &op_data->input_multiplier, &input_left_shift);
+    op_data->input_left_shift = input_left_shift;
+    op_data->diff_min =
+        -1.0 * tflite::CalculateInputRadius(kScaledDiffIntegerBits,
+                                            op_data->input_left_shift);
   }
   return kTfLiteOk;
 }
@@ -161,12 +156,7 @@ TfLiteStatus CalculateSoftmaxOpData(TfLiteContext* context,
 } // namespace
 
 void SoftmaxQuantized(const TfLiteTensor* input, TfLiteTensor* output,
-                      TfLiteSoftmaxParams* params, OpData* data) {
-  SoftmaxParams op_params;
-  op_params.input_multiplier = data->input_multiplier;
-  op_params.input_left_shift = data->input_left_shift;
-  op_params.diff_min = data->diff_min;
+                      const SoftmaxParams& op_params) {
   if (output->type == kTfLiteInt16) {
     xtensa::hifimini::Softmax(
         op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
@@ -186,7 +176,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
 void Free(TfLiteContext* context, void* buffer) {}
 
 TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
-  auto* params = reinterpret_cast<TfLiteSoftmaxParams*>(node->builtin_data);
+  auto* params = static_cast<TfLiteSoftmaxParams*>(node->builtin_data);
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
@@ -194,27 +184,26 @@ TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
   TfLiteTensor* output = GetOutput(context, node, 0);
   TF_LITE_ENSURE(context, NumDimensions(input) >= 1);
 
-  // TODO(b/132070898): Use statically slotted OpData structures until a
+  // TODO(b/132070898): Use statically slotted SoftmaxParams structures until a
   // scratch memory API is ready.
-  OpData* op_data = &kStaticOpData;
-  node->user_data = op_data;
+  SoftmaxParams* op_params = &kStaticOpData;
+  node->user_data = op_params;
 
   TF_LITE_ENSURE_STATUS(
-      CalculateSoftmaxOpData(context, input, output, params, op_data));
+      CalculateSoftmaxOpData(context, input, output, params, op_params));
 
   return kTfLiteOk;
 }
 
 TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
-  auto* params = reinterpret_cast<TfLiteSoftmaxParams*>(node->builtin_data);
-  auto* op_data = reinterpret_cast<OpData*>(node->user_data);
+  auto* op_params = static_cast<SoftmaxParams*>(node->user_data);
   const TfLiteTensor* input = GetInput(context, node, 0);
   TfLiteTensor* output = GetOutput(context, node, 0);
 
   switch (input->type) {
     case kTfLiteInt8: {
-      SoftmaxQuantized(input, output, params, op_data);
+      SoftmaxQuantized(input, output, *op_params);
       return kTfLiteOk;
     }
     default: