Create a fixed-point Softmax that uses asymmetric quantization with int8 input and output.
PiperOrigin-RevId: 226401161
commit 23178bd498 (parent 7a0c9559a9)
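For orientation: the int8 Softmax output added here is quantized with scale 1/256 and zero point -128 (the uint8 path keeps the same scale with zero point 0), so a real probability p in [0, 1] maps to round(p * 256) - 128, clamped to [-128, 127]. A minimal standalone sketch of that mapping (illustration only, not part of the diff; the helper name is made up):

#include <algorithm>
#include <cmath>
#include <cstdint>

// Hypothetical helper: quantize a softmax probability to int8 using the
// output parameters this commit enforces (scale = 1/256, zero_point = -128).
inline int8_t QuantizeSoftmaxOutputInt8(float prob) {
  const int q = static_cast<int>(std::round(prob * 256.0f)) - 128;
  return static_cast<int8_t>(std::min(127, std::max(-128, q)));
}

For example, QuantizeSoftmaxOutputInt8(0.35003f) yields -38, which matches one of the expected values in the Softmax4DInt8 test added below.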
@@ -23,6 +23,7 @@ limitations under the License.
 #include "tensorflow/lite/c/c_api_internal.h"
 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
 #include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/reference/integer_ops/softmax.h"
 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
 #include "tensorflow/lite/kernels/internal/tensor.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
@@ -50,6 +51,20 @@ struct PreluOpData : public OpData {
   int output_shift = 0;
 };
 
+namespace {
+TfLiteStatus CheckInputQuantParams(TfLiteContext* context,
+                                   const TfLiteTensor* input,
+                                   const TfLiteTensor* output) {
+  if (input->type == kTfLiteUInt8) {
+    TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
+    TF_LITE_ENSURE(context, output->params.scale == 1. / 256);
+  } else {
+    TF_LITE_ENSURE_EQ(context, output->params.zero_point, -128);
+  }
+  return kTfLiteOk;
+}
+}  // namespace
+
 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
   // This is a builtin op, so we don't use the contents in 'buffer', if any.
   // Instead, we allocate a new object to carry information from Prepare() to
@@ -215,12 +230,12 @@ TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
   const int num_dims = NumDimensions(input);
   TF_LITE_ENSURE(context, num_dims >= 1 && num_dims <= 4);
 
-  if (input->type == kTfLiteUInt8) {
-    TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
-    TF_LITE_ENSURE(context, output->params.scale == 1. / 256);
+  if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
+    if (CheckInputQuantParams(context, input, output) == kTfLiteError) {
+      return kTfLiteError;
+    }
 
     static const int kScaledDiffIntegerBits = 5;
 
     tflite::PreprocessSoftmaxScaling(
         params->beta, input->params.scale, kScaledDiffIntegerBits,
         &data->input_multiplier, &data->input_left_shift);
@@ -505,8 +520,8 @@ void Softmax3DFloat(const TfLiteTensor* input, TfLiteTensor* output,
                       GetTensorData<float>(output));
 }
 
-void Softmax1DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
+void Softmax1DQuantizedUint8(const TfLiteTensor* input, TfLiteTensor* output,
                         TfLiteSoftmaxParams* params, OpData* data) {
   // TODO(ahentz): this is arguably a dirty trick. Since the implementation
   // always traverses the last dimension of a 4D tensor, we will pretend our 1D
   // tensor is 4D in a special way. We will convert a (Y) shape into a (1,
@@ -521,8 +536,8 @@ void Softmax1DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
                        GetTensorShape({1, 1, 1, input_size}),
                        GetTensorData<uint8_t>(output));
 }
-void Softmax2DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
+void Softmax2DQuantizedUint8(const TfLiteTensor* input, TfLiteTensor* output,
                         TfLiteSoftmaxParams* params, OpData* data) {
   // TODO(ahentz): this is arguably a dirty trick. Since the implementation
   // always traverses the last dimension of a 4D tensor, we will pretend our 2D
   // tensor is 4D in a special way. We will convert a (X, Y) shape into a (X,
@@ -540,8 +555,8 @@ void Softmax2DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
                        GetTensorData<uint8_t>(output));
 }
 
-void Softmax3DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
+void Softmax3DQuantizedUint8(const TfLiteTensor* input, TfLiteTensor* output,
                         TfLiteSoftmaxParams* params, OpData* data) {
   const int batch_size = input->dims->data[0];
   const int intermediate_size = input->dims->data[1];
   const int input_size = input->dims->data[2];
@@ -566,8 +581,8 @@ void Softmax4DFloat(const TfLiteTensor* input, TfLiteTensor* output,
                       GetTensorData<float>(output));
 }
 
-void Softmax4DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
+void Softmax4DQuantizedUint8(const TfLiteTensor* input, TfLiteTensor* output,
                         TfLiteSoftmaxParams* params, OpData* data) {
   SoftmaxParams op_params;
   op_params.input_multiplier = data->input_multiplier;
   op_params.input_left_shift = data->input_left_shift;
@@ -577,6 +592,63 @@ void Softmax4DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
                       GetTensorData<uint8_t>(output));
 }
 
+// TODO(jianlijianli): Try merging Softmax<n>DQuantizedInt8 with
+// Softmax<n>DQuantized, which needs a larger refactor.
+void Softmax1DQuantizedInt8(const TfLiteTensor* input, TfLiteTensor* output,
+                            TfLiteSoftmaxParams* params, OpData* data) {
+  const int input_size = input->dims->data[0];
+  SoftmaxParams op_params;
+  op_params.input_multiplier = data->input_multiplier;
+  op_params.input_left_shift = data->input_left_shift;
+  op_params.diff_min = data->diff_min;
+  reference_integer_ops::Softmax(
+      op_params, GetTensorShape({1, 1, 1, input_size}),
+      GetTensorData<int8_t>(input), GetTensorShape({1, 1, 1, input_size}),
+      GetTensorData<int8_t>(output));
+}
+
+void Softmax2DQuantizedInt8(const TfLiteTensor* input, TfLiteTensor* output,
+                            TfLiteSoftmaxParams* params, OpData* data) {
+  const int batch_size = input->dims->data[0];
+  const int input_size = input->dims->data[1];
+  SoftmaxParams op_params;
+  op_params.input_multiplier = data->input_multiplier;
+  op_params.input_left_shift = data->input_left_shift;
+  op_params.diff_min = data->diff_min;
+  reference_integer_ops::Softmax(op_params,
+                                 GetTensorShape({batch_size, 1, 1, input_size}),
+                                 GetTensorData<int8_t>(input),
+                                 GetTensorShape({batch_size, 1, 1, input_size}),
+                                 GetTensorData<int8_t>(output));
+}
+
+void Softmax3DQuantizedInt8(const TfLiteTensor* input, TfLiteTensor* output,
+                            TfLiteSoftmaxParams* params, OpData* data) {
+  const int batch_size = input->dims->data[0];
+  const int intermediate_size = input->dims->data[1];
+  const int input_size = input->dims->data[2];
+  SoftmaxParams op_params;
+  op_params.input_multiplier = data->input_multiplier;
+  op_params.input_left_shift = data->input_left_shift;
+  op_params.diff_min = data->diff_min;
+  reference_integer_ops::Softmax(
+      op_params, GetTensorShape({batch_size, intermediate_size, 1, input_size}),
+      GetTensorData<int8_t>(input),
+      GetTensorShape({batch_size, intermediate_size, 1, input_size}),
+      GetTensorData<int8_t>(output));
+}
+
+void Softmax4DQuantizedInt8(const TfLiteTensor* input, TfLiteTensor* output,
+                            TfLiteSoftmaxParams* params, OpData* data) {
+  SoftmaxParams op_params;
+  op_params.input_multiplier = data->input_multiplier;
+  op_params.input_left_shift = data->input_left_shift;
+  op_params.diff_min = data->diff_min;
+  reference_integer_ops::Softmax(
+      op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
+      GetTensorShape(output), GetTensorData<int8_t>(output));
+}
+
 TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
   auto* params = reinterpret_cast<TfLiteSoftmaxParams*>(node->builtin_data);
   OpData* data = reinterpret_cast<OpData*>(node->user_data);
@@ -611,19 +683,19 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
     }
     case kTfLiteUInt8: {
       if (NumDimensions(input) == 1) {
-        Softmax1DQuantized(input, output, params, data);
+        Softmax1DQuantizedUint8(input, output, params, data);
         return kTfLiteOk;
       }
       if (NumDimensions(input) == 2) {
-        Softmax2DQuantized(input, output, params, data);
+        Softmax2DQuantizedUint8(input, output, params, data);
         return kTfLiteOk;
       }
       if (NumDimensions(input) == 3) {
-        Softmax3DQuantized(input, output, params, data);
+        Softmax3DQuantizedUint8(input, output, params, data);
         return kTfLiteOk;
       }
       if (NumDimensions(input) == 4) {
-        Softmax4DQuantized(input, output, params, data);
+        Softmax4DQuantizedUint8(input, output, params, data);
         return kTfLiteOk;
       }
       context->ReportError(
@@ -631,6 +703,30 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
           NumDimensions(input));
       return kTfLiteError;
     }
+    case kTfLiteInt8: {
+      if (NumDimensions(input) == 1) {
+        Softmax1DQuantizedInt8(input, output, params, data);
+        return kTfLiteOk;
+      }
+      if (NumDimensions(input) == 2) {
+        Softmax2DQuantizedInt8(input, output, params, data);
+        return kTfLiteOk;
+      }
+      if (NumDimensions(input) == 3) {
+        Softmax3DQuantizedInt8(input, output, params, data);
+        return kTfLiteOk;
+      }
+      if (NumDimensions(input) == 4) {
+        Softmax4DQuantizedInt8(input, output, params, data);
+        return kTfLiteOk;
+      }
+      context->ReportError(
+          context,
+          "Only 4D tensors supported currently for Int8 kernels, got %dD.",
+          NumDimensions(input));
+      return kTfLiteError;
+    }
+
     default:
       context->ReportError(
           context, "Only float32 and uint8_t supported currently, got %s.",
@@ -44,6 +44,8 @@ class BaseActivationsOpModel : public SingleOpModel {
     input_ = AddInput(input);
     if (input.type == TensorType_UINT8) {
       output_ = AddOutput({input.type, {}, 0, 0, 1. / 256});
+    } else if (input.type == TensorType_INT8) {
+      output_ = AddOutput({TensorType_INT8, {}, 0, 0, 1. / 256, -128});
     } else {
       output_ = AddOutput({input.type, {}});
     }
@@ -52,8 +54,8 @@ class BaseActivationsOpModel : public SingleOpModel {
     BuildInterpreter({GetShape(input_)});
   }
 
-  BaseActivationsOpModel(BuiltinOperator type, const TensorData &input,
-                         const TensorData &output) {
+  BaseActivationsOpModel(BuiltinOperator type, const TensorData& input,
+                         const TensorData& output) {
     input_ = AddInput(input);
     output_ = AddOutput(output);
     SetBuiltinOp(type, BuiltinOptions_NONE, 0);
@@ -323,7 +325,7 @@ TEST(FloatActivationsOpTest, Softmax4D) {
                  })));
 }
 
-TEST(QuantizedActivationsOpTest, Softmax4D) {
+TEST(QuantizedActivationsOpTest, Softmax4DUint8) {
   QuantizedActivationsOpModel m(
       0.1,
       /*input=*/{TensorType_UINT8, {1, 2, 1, 4}, -10, 10});
@@ -362,6 +364,145 @@ TEST(QuantizedActivationsOpTest, Softmax4D) {
                      kQuantizedTolerance)));
 }
 
+// Test quantized softmax with int8 input and output. With the same input as in
+// QuantizedActivationsOpTest.Softmax1D, the dequantized output is identical.
+TEST(QuantizedActivationsOpTest, Softmax1DInt8) {
+  QuantizedActivationsOpModel m(0.1,
+                                /*input=*/{TensorType_INT8, {8}, -10, 10});
+  m.SetInput<int8_t>({0, -6, 2, 4, 3, -2, 10, 1});
+  m.Invoke();
+  EXPECT_THAT(
+      m.GetDequantizedOutput<int8_t>(),
+      ElementsAreArray(ArrayFloatNear({0.09766, 0.05469, 0.12109, 0.14453,
+                                       0.13281, 0.07813, 0.26563, 0.10938},
+                                      kQuantizedTolerance)));
+}
+
+// Test quantized softmax with int8 input and output. With the same input as in
+// QuantizedActivationsOpTest.Softmax2D, the dequantized output is identical.
+TEST(QuantizedActivationsOpTest, Softmax2DInt8) {
+  QuantizedActivationsOpModel m(0.1,
+                                /*input=*/{TensorType_INT8, {2, 4}, -10, 10});
+  m.SetInput<int8_t>({
+      0, -6, 2, 4,   //
+      3, -2, 10, 1,  //
+  });
+  m.Invoke();
+  EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
+              ElementsAreArray(ArrayFloatNear(
+                  {
+                      .23463, .12877, .28658, .35003,  //
+                      .22528, .13664, .45365, .18443,  //
+                  },
+                  kQuantizedTolerance)));
+
+  // Same input, but a different shape.
+  QuantizedActivationsOpModel m2(0.1,
+                                 /*input=*/{TensorType_INT8, {4, 2}, -10, 10});
+  m2.SetInput<int8_t>({
+      0, -6,  //
+      2, 4,   //
+      3, -2,  //
+      10, 1,  //
+  });
+  m2.Invoke();
+  EXPECT_THAT(m2.GetDequantizedOutput<int8_t>(),
+              ElementsAreArray(ArrayFloatNear(
+                  {
+                      0.645656, 0.354344,  //
+                      0.450166, 0.549834,  //
+                      0.622459, 0.377541,  //
+                      0.710949, 0.28905,   //
+                  },
+                  kQuantizedTolerance)));
+}
+
+// Test quantized softmax with int8 input and output. With the same input as in
+// QuantizedActivationsOpTest.Softmax3D, the dequantized output is identical.
+TEST(QuantizedActivationsOpTest, Softmax3DInt8) {
+  QuantizedActivationsOpModel m(
+      0.1,
+      /*input=*/{TensorType_INT8, {1, 2, 4}, -10, 10});
+  m.SetInput<int8_t>({
+      0, -6, 2, 4,   // depth = 0
+      3, -2, 10, 1,  // depth = 1
+  });
+  m.Invoke();
+  EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
+              ElementsAreArray(ArrayFloatNear(
+                  {
+                      .23463, .12877, .28658, .35003,  //
+                      .22528, .13664, .45365, .18443,  //
+                  },
+                  kQuantizedTolerance)));
+
+  // Same input, but a different shape.
+  QuantizedActivationsOpModel m2(
+      0.1,
+      /*input=*/{TensorType_INT8, {4, 1, 2}, -10, 10});
+  m2.SetInput<int8_t>({
+      0, -6,  //
+      2, 4,   //
+      3, -2,  //
+      10, 1,  //
+  });
+  m2.Invoke();
+  EXPECT_THAT(m2.GetDequantizedOutput<int8_t>(),
+              ElementsAreArray(ArrayFloatNear(
+                  {
+                      0.645656, 0.354344,  //
+                      0.450166, 0.549834,  //
+                      0.622459, 0.377541,  //
+                      0.710949, 0.28905,   //
+                  },
+                  kQuantizedTolerance)));
+}
+
+// Test quantized softmax with int8 input and output. With the same input as in
+// QuantizedActivationsOpTest.Softmax4D, the dequantized output is identical.
+TEST(QuantizedActivationsOpTest, Softmax4DInt8) {
+  QuantizedActivationsOpModel m(
+      0.1,
+      /*input=*/{TensorType_INT8, {1, 2, 1, 4}, -10, 10});
+  m.SetInput<int8_t>({
+      0, -6, 2, 4,   // depth = 0
+      3, -2, 10, 1,  // depth = 1
+  });
+  m.Invoke();
+  EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAreArray({
+                                         -68, -95, -54, -38,  //
+                                         -70, -93, -12, -81,  //
+                                     }));
+  EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
+              ElementsAreArray(ArrayFloatNear(
+                  {
+                      .23463, .12877, .28658, .35003,  //
+                      .22528, .13664, .45365, .18443,  //
+                  },
+                  kQuantizedTolerance)));
+
+  // Same input, but a different shape.
+  QuantizedActivationsOpModel m2(
+      0.1,
+      /*input=*/{TensorType_INT8, {4, 1, 1, 2}, -10, 10});
+  m2.SetInput<int8_t>({
+      0, -6,  //
+      2, 4,   //
+      3, -2,  //
+      10, 1,  //
+  });
+  m2.Invoke();
+  EXPECT_THAT(m2.GetDequantizedOutput<int8_t>(),
+              ElementsAreArray(ArrayFloatNear(
+                  {
+                      0.645656, 0.354344,  //
+                      0.450166, 0.549834,  //
+                      0.622459, 0.377541,  //
+                      0.710949, 0.28905,   //
+                  },
+                  kQuantizedTolerance)));
+}
+
 TEST(FloatActivationsOpTest, Softmax3D) {
   FloatActivationsOpModel m(0.1,
                             /*input=*/{TensorType_FLOAT32, {1, 2, 4}});
@@ -393,7 +534,7 @@ TEST(FloatActivationsOpTest, Softmax3D) {
                  })));
 }
 
-TEST(QuantizedActivationsOpTest, Softmax3D) {
+TEST(QuantizedActivationsOpTest, Softmax3DUint8) {
   QuantizedActivationsOpModel m(
       0.1,
       /*input=*/{TensorType_UINT8, {1, 2, 4}, -10, 10});
@@ -443,7 +584,7 @@ TEST(FloatActivationsOpTest, Softmax1D) {
       {.09752, .05352, .11911, .14548, .13164, .07984, .26509, .10778})));
 }
 
-TEST(QuantizedActivationsOpTest, Softmax1D) {
+TEST(QuantizedActivationsOpTest, Softmax1DUint8) {
   QuantizedActivationsOpModel m(0.1,
                                 /*input=*/{TensorType_UINT8, {8}, -10, 10});
   m.SetInput<uint8_t>({0, -6, 2, 4, 3, -2, 10, 1});
@@ -486,7 +627,7 @@ TEST(FloatActivationsOpTest, Softmax2D) {
                  })));
 }
 
-TEST(QuantizedActivationsOpTest, Softmax2D) {
+TEST(QuantizedActivationsOpTest, Softmax2DUint8) {
   QuantizedActivationsOpModel m(0.1,
                                 /*input=*/{TensorType_UINT8, {2, 4}, -10, 10});
   m.SetInput<uint8_t>({
@@ -1,12 +1,12 @@
+load("//tensorflow/lite:build_def.bzl", "tflite_copts")
+load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite")
+
 package(default_visibility = [
     "//visibility:public",
 ])
 
 licenses(["notice"])  # Apache 2.0
 
-load("//tensorflow/lite:build_def.bzl", "tflite_copts")
-load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite")
-
 tflite_deps_intel = [
     "@arm_neon_2_x86_sse",
 ]
@@ -314,6 +314,7 @@ cc_library(
         "reference/depthwiseconv_uint8.h",
         "reference/fully_connected.h",
         "reference/integer_ops/dequantize.h",
+        "reference/integer_ops/softmax.h",
        "reference/reference_ops.h",
        "reference/softmax.h",
    ],
@@ -131,6 +131,23 @@ int CountLeadingZeros(T integer_input) {
 #endif
 }
 
+inline int32 GetReciprocal(int32 x, int x_integer_digits,
+                           int* num_bits_over_unit) {
+  int headroom_plus_one = CountLeadingZeros(static_cast<uint32>(x));
+  // This is the number of bits to the left of the binary point above 1.0.
+  // Consider x=1.25. In that case shifted_scale=0.8 and
+  // no later adjustment will be needed.
+  *num_bits_over_unit = x_integer_digits - headroom_plus_one;
+  const int32 shifted_sum_minus_one =
+      static_cast<int32>((static_cast<uint32>(x) << headroom_plus_one) -
+                         (static_cast<uint32>(1) << 31));
+
+  gemmlowp::FixedPoint<int32, 0> shifted_scale =
+      gemmlowp::one_over_one_plus_x_for_x_in_0_1(
+          gemmlowp::FixedPoint<int32, 0>::FromRaw(shifted_sum_minus_one));
+  return shifted_scale.raw();
+}
+
 // DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING
 // BROADCASTING.
 //
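The new GetReciprocal helper factors out the normalization step the softmax kernels use to divide by the accumulated sum of exponentials. A worked example of the arithmetic, written as a sketch with plain integers instead of gemmlowp types (it assumes x_integer_digits = 12, i.e. kAccumulationIntegerBits, and uses the GCC/Clang __builtin_clz intrinsic; this is not code from the commit):

#include <cstdint>
#include <iostream>

int main() {
  const int x_integer_digits = 12;    // kAccumulationIntegerBits
  const uint32_t raw = 2u << 19;      // sum_of_exps = 2.0 in Q12.19
  const int headroom_plus_one = __builtin_clz(raw);                     // 11
  const int num_bits_over_unit = x_integer_digits - headroom_plus_one;  // 1
  // Shifting left by headroom_plus_one normalizes the sum into [1, 2);
  // 2.0 normalizes to exactly 1.0, so one_over_one_plus_x_for_x_in_0_1 is
  // evaluated at x = 0 and returns 1.0. The reciprocal is therefore
  // 1.0 * 2^-num_bits_over_unit = 0.5, i.e. 1 / 2.0.
  std::cout << "num_bits_over_unit = " << num_bits_over_unit << "\n";
  return 0;
}

The caller applies the remaining 2^-num_bits_over_unit factor when it rescales each per-element exponential via RoundingDivideByPOT, as the softmax kernels below do.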
tensorflow/lite/kernels/internal/reference/integer_ops/softmax.h (new file, 102 lines)
@@ -0,0 +1,102 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_SOFTMAX_H_
+#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_SOFTMAX_H_
+
+#include "tensorflow/lite/kernels/internal/common.h"
+
+namespace tflite {
+namespace reference_integer_ops {
+
+// Quantized softmax with int8 input and output.
+inline void Softmax(const SoftmaxParams& params,
+                    const RuntimeShape& input_shape, const int8* input_data,
+                    const RuntimeShape& output_shape, int8* output_data) {
+  const int32 input_beta_multiplier = params.input_multiplier;
+  const int32 input_beta_left_shift = params.input_left_shift;
+  const int diff_min = params.diff_min;
+  // The representation chosen for the input to the exp() function is Q5.26.
+  // We need to leave extra space since values that we skip might be as large as
+  // -32 before multiplying by input_beta_multiplier, and therefore as large as
+  // -16 afterwards. Note that exp(-8) is definitely not insignificant to
+  // accumulation, but exp(-16) definitely is.
+  static const int kScaledDiffIntegerBits = 5;
+  static const int kAccumulationIntegerBits = 12;
+  using FixedPointScaledDiff =
+      gemmlowp::FixedPoint<int32, kScaledDiffIntegerBits>;
+  using FixedPointAccum = gemmlowp::FixedPoint<int32, kAccumulationIntegerBits>;
+  using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
+
+  const int trailing_dim = input_shape.DimensionsCount() - 1;
+  const int outer_size =
+      MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
+  const int depth =
+      MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
+
+  for (int i = 0; i < outer_size; ++i) {
+    int8 max_in_row = -128;
+    for (int c = 0; c < depth; ++c) {
+      max_in_row = std::max(max_in_row, input_data[i * depth + c]);
+    }
+
+    FixedPointAccum sum_of_exps = FixedPointAccum::Zero();
+    for (int c = 0; c < depth; ++c) {
+      int32 input_diff =
+          static_cast<int32>(input_data[i * depth + c]) - max_in_row;
+      if (input_diff >= diff_min) {
+        const int32 input_diff_rescaled =
+            MultiplyByQuantizedMultiplierGreaterThanOne(
+                input_diff, input_beta_multiplier, input_beta_left_shift);
+        const FixedPointScaledDiff scaled_diff_f8 =
+            FixedPointScaledDiff::FromRaw(input_diff_rescaled);
+        sum_of_exps = sum_of_exps + gemmlowp::Rescale<kAccumulationIntegerBits>(
+                                        exp_on_negative_values(scaled_diff_f8));
+      }
+    }
+
+    int num_bits_over_unit;
+    FixedPoint0 shifted_scale = FixedPoint0::FromRaw(GetReciprocal(
+        sum_of_exps.raw(), kAccumulationIntegerBits, &num_bits_over_unit));
+
+    for (int c = 0; c < depth; ++c) {
+      int32 input_diff =
+          static_cast<int32>(input_data[i * depth + c]) - max_in_row;
+      if (input_diff >= diff_min) {
+        const int32 input_diff_rescaled =
+            MultiplyByQuantizedMultiplierGreaterThanOne(
+                input_diff, input_beta_multiplier, input_beta_left_shift);
+        const FixedPointScaledDiff scaled_diff_f8 =
+            FixedPointScaledDiff::FromRaw(input_diff_rescaled);
+
+        FixedPoint0 exp_in_0 = exp_on_negative_values(scaled_diff_f8);
+        const int32 unsat_output = gemmlowp::RoundingDivideByPOT(
+            (shifted_scale * exp_in_0).raw(), num_bits_over_unit + 31 - 8);
+        const int32 shifted_output = unsat_output - 128;
+
+        output_data[i * depth + c] = static_cast<int8>(
+            std::max(std::min(shifted_output, static_cast<int32>(127)),
+                     static_cast<int32>(-128)));
+
+      } else {
+        output_data[i * depth + c] = -128;
+      }
+    }
+  }
+}
+
+}  // namespace reference_integer_ops
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_SOFTMAX_H_
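To make the structure of the new int8 kernel easier to follow, here is a floating-point sketch of the same per-row passes (row max, sum of exponentials of the differences, then normalize and requantize). It is an illustration only, not code from this commit, and it ignores the fixed-point details the real kernel handles (Q5.26 diffs, Q12 accumulation, and skipping elements below diff_min, which are written as -128):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Float sketch of reference_integer_ops::Softmax. The input zero point cancels
// in the differences, so it does not appear; `scale` stands for
// beta * input_scale, which the real kernel folds into input_multiplier and
// input_left_shift. The output uses the fixed quantization scale 1/256 with
// zero point -128. `output` must already have the same size as `input`.
void SoftmaxInt8FloatSketch(const std::vector<int8_t>& input, float scale,
                            int depth, std::vector<int8_t>* output) {
  const int outer_size = static_cast<int>(input.size()) / depth;
  for (int i = 0; i < outer_size; ++i) {
    // Pass 1: row maximum, so every exp() argument below is <= 0.
    const int8_t max_in_row = *std::max_element(
        input.begin() + i * depth, input.begin() + (i + 1) * depth);
    // Pass 2: sum of exponentials of the scaled differences.
    float sum_of_exps = 0.f;
    for (int c = 0; c < depth; ++c) {
      sum_of_exps += std::exp((input[i * depth + c] - max_in_row) * scale);
    }
    // Pass 3: normalize each element and requantize to int8, clamping to
    // [-128, 127].
    for (int c = 0; c < depth; ++c) {
      const float prob =
          std::exp((input[i * depth + c] - max_in_row) * scale) / sum_of_exps;
      const int q = static_cast<int>(std::round(prob * 256.f)) - 128;
      (*output)[i * depth + c] =
          static_cast<int8_t>(std::min(127, std::max(-128, q)));
    }
  }
}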
@@ -102,19 +102,9 @@ inline void Softmax(const SoftmaxParams& params,
       }
     }
 
-    int32 fixed_sum_of_exps = sum_of_exps.raw();
-    int headroom_plus_one =
-        CountLeadingZeros(static_cast<uint32>(fixed_sum_of_exps));
-    // This is the number of bits to the left of the binary point above 1.0.
-    // Consider fixed_sum_of_exps=1.25. In that case shifted_scale=0.8 and
-    // no later adjustment will be needed.
-    int num_bits_over_unit = kAccumulationIntegerBits - headroom_plus_one;
-    int32 shifted_sum_minus_one = static_cast<int32>(
-        (static_cast<uint32>(fixed_sum_of_exps) << headroom_plus_one) -
-        (static_cast<uint32>(1) << 31));
-
-    FixedPoint0 shifted_scale = gemmlowp::one_over_one_plus_x_for_x_in_0_1(
-        FixedPoint0::FromRaw(shifted_sum_minus_one));
+    int num_bits_over_unit;
+    FixedPoint0 shifted_scale = FixedPoint0::FromRaw(GetReciprocal(
+        sum_of_exps.raw(), kAccumulationIntegerBits, &num_bits_over_unit));
 
     for (int c = 0; c < depth; ++c) {
       int32 input_diff =
@@ -706,6 +706,11 @@ class Softmax
   }
 
   int GetVersion(const OperatorSignature& op_signature) const override {
+    const string& input_name = op_signature.op->inputs[0];
+    const Array& input_array = op_signature.model->GetArray(input_name);
+    if (input_array.data_type == ArrayDataType::kInt8) {
+      return 2;
+    }
     return 1;
   }
 };
@@ -1556,10 +1561,6 @@ class TensorFlowUnsupported : public BaseOperator {
     return std::unique_ptr<flexbuffers::Builder>(fbb.release());
   }
 
-  // TODO(wvo): hack to make this code compile with 2 different API versions.
-  // Please remove once OS/internal versions are in sync.
-  // See hardcoded values in the switch below.
-
   void ReadOptions(const flexbuffers::Map& m,
                    TensorFlowUnsupportedOperator* op) const {
     ::tensorflow::NodeDef node_def;
@@ -1569,6 +1570,10 @@ class TensorFlowUnsupported : public BaseOperator {
     for (size_t i = 0; i < keys.size(); ++i) {
       const auto key = keys[i].AsKey();
       const auto& value = m[key];
+      // TODO(wvo): hack to make this code compile with 2 different API
+      // versions.
+      // Please remove once OS/internal versions are in sync.
+      // See hardcoded values in the switch below.
       switch (value.GetType()) {
         case 5:  // flexbuffers::FBT_STRING:
           (*attr)[key].set_s(value.AsString().c_str());