Add another integer LSTM reference runtime for a special hardware platform. Its quantization support will be added separately.

PiperOrigin-RevId: 296320108
Change-Id: Ifb3a22667322a5af49426b5a1fc4066a50beac51
Jian Li 2020-02-20 16:07:36 -08:00 committed by TensorFlower Gardener
parent 10aff5d518
commit 7f6685951b
12 changed files with 2096 additions and 71 deletions

View File

@ -614,12 +614,12 @@ cc_library(
":op_macros",
"//tensorflow/lite/c:common",
"//tensorflow/lite/experimental/ruy/profiler:instrumentation",
"//tensorflow/lite/kernels/internal:common",
"//tensorflow/lite/kernels/internal:compatibility",
"//tensorflow/lite/kernels/internal:kernel_utils",
"//tensorflow/lite/kernels/internal:optimized_base",
"//tensorflow/lite/kernels/internal:quantization_util",
"//tensorflow/lite/kernels/internal:tensor",
"//tensorflow/lite/kernels/internal:tensor_utils",
"@gemmlowp",
],
)

View File

@ -111,6 +111,31 @@ void MatrixBatchVectorMultiplyAccumulate(
n_output, output_zp, scratch, output, context);
}
void MatrixBatchVectorMultiply(const int8_t* input, int32_t input_zeropoint,
const int8_t* input_to_gate_weights,
int32_t input_to_gate_effective_scale_a,
int32_t input_to_gate_effective_scale_b,
int32_t n_batch, int32_t n_input, int32_t n_cell,
int8_t* gate_output, int8_t gate_output_zp) {
PortableMatrixBatchVectorMultiply(
input, input_zeropoint, input_to_gate_weights,
input_to_gate_effective_scale_a, input_to_gate_effective_scale_b, n_batch,
n_input, n_cell, gate_output, gate_output_zp);
}
void MatrixBatchVectorMultiply(const int16_t* hidden,
const int8_t* hidden_to_output_weights,
int32_t proj_effective_scale_a,
int32_t proj_effective_scale_b,
const int32_t* gate_bias, int32_t n_batch,
int32_t n_hidden, int32_t n_output,
int32_t output_zp, int8_t* proj_output) {
PortableMatrixBatchVectorMultiply(hidden, hidden_to_output_weights,
proj_effective_scale_a,
proj_effective_scale_b, gate_bias, n_batch,
n_hidden, n_output, output_zp, proj_output);
}
void MatrixScalarMultiplyAccumulate(const int8_t* matrix, int32_t scalar,
int32_t n_row, int32_t n_col,
int32_t* output) {
@ -127,16 +152,36 @@ void ApplyLayerNorm(const int16_t* input, const int16_t* layer_norm_weights,
n_batch, n_input, output);
}
void ApplyLayerNormFloat(const int16_t* input,
const int16_t* layer_norm_weights,
int32_t layer_norm_scale_a, int32_t layer_norm_scale_b,
const int32_t* bias, int n_batch, int n_input,
int16_t* output) {
PortableApplyLayerNormFloat(input, layer_norm_weights, layer_norm_scale_a,
layer_norm_scale_b, bias, n_batch, n_input,
output);
}
void ApplySigmoid(const int16_t* input, int32_t n_batch, int32_t n_input,
int16_t* output) {
NEON_OR_PORTABLE(ApplySigmoid, input, n_batch, n_input, output);
}
void ApplySigmoidFloat(const int16_t* input, int32_t n_batch, int32_t n_input,
int16_t* output) {
PortableApplySigmoidFloat(input, n_batch, n_input, output);
}
void ApplyTanh(int32_t integer_bits, const int16_t* input, int32_t n_batch,
int32_t n_input, int16_t* output) {
NEON_OR_PORTABLE(ApplyTanh, integer_bits, input, n_batch, n_input, output);
}
void ApplyTanhFloat(const int16_t* input, int32_t n_batch, int32_t n_input,
int32_t integer_bits, int16_t* output) {
PortableApplyTanhFloat(input, n_batch, n_input, integer_bits, output);
}
void CwiseMul(const int16_t* input_1, const int16_t* input_2, int n_batch,
int n_input, int shift, int16_t* output) {
NEON_OR_PORTABLE(CwiseMul, input_1, input_2, n_batch, n_input, shift, output);
@ -260,6 +305,19 @@ void MeanStddevNormalization(const float* input_vector, float* output_vector,
PortableMeanStddevNormalization(input_vector, output_vector, v_size, n_batch);
}
void TwoGateSaturationgAdd(const int8_t* input, int8_t input_zp,
const int8_t* recurrent, int8_t recurrent_zp,
int32_t input_effective_scale_a,
int32_t input_effective_scale_b,
int32_t recurrent_effective_scale_a,
int32_t recurrent_effective_scale_b, int32_t n_batch,
int32_t n_cell, int16_t* output) {
PortableTwoGateSaturationgAdd(
input, input_zp, recurrent, recurrent_zp, input_effective_scale_a,
input_effective_scale_b, recurrent_effective_scale_a,
recurrent_effective_scale_b, n_batch, n_cell, output);
}
} // namespace tensor_utils
} // namespace tflite

View File

@ -126,6 +126,31 @@ void MatrixBatchVectorMultiplyAccumulate(
shift, n_batch, n_input, n_output, output_zp, scratch, output, context);
}
void MatrixBatchVectorMultiply(const int8_t* input, int32_t input_zeropoint,
const int8_t* input_to_gate_weights,
int32_t input_to_gate_effective_scale_a,
int32_t input_to_gate_effective_scale_b,
int32_t n_batch, int32_t n_input, int32_t n_cell,
int8_t* gate_output, int8_t gate_output_zp) {
PortableMatrixBatchVectorMultiply(
input, input_zeropoint, input_to_gate_weights,
input_to_gate_effective_scale_a, input_to_gate_effective_scale_b, n_batch,
n_input, n_cell, gate_output, gate_output_zp);
}
void MatrixBatchVectorMultiply(const int16_t* hidden,
const int8_t* hidden_to_output_weights,
int32_t proj_effective_scale_a,
int32_t proj_effective_scale_b,
const int32_t* gate_bias, int32_t n_batch,
int32_t n_hidden, int32_t n_output,
int32_t output_zp, int8_t* proj_output) {
PortableMatrixBatchVectorMultiply(hidden, hidden_to_output_weights,
proj_effective_scale_a,
proj_effective_scale_b, gate_bias, n_batch,
n_hidden, n_output, output_zp, proj_output);
}
void MatrixScalarMultiplyAccumulate(const int8_t* matrix, int32_t scalar,
int32_t n_row, int32_t n_col,
int32_t* output) {
@ -141,16 +166,36 @@ void ApplyLayerNorm(const int16_t* input, const int16_t* layer_norm_weights,
output);
}
void ApplyLayerNormFloat(const int16_t* input,
const int16_t* layer_norm_weights,
int32_t layer_norm_scale_a, int32_t layer_norm_scale_b,
const int32_t* bias, int n_batch, int n_input,
int16_t* output) {
PortableApplyLayerNormFloat(input, layer_norm_weights, layer_norm_scale_a,
layer_norm_scale_b, bias, n_batch, n_input,
output);
}
void ApplySigmoid(const int16_t* input, int32_t n_batch, int32_t n_input,
int16_t* output) {
PortableApplySigmoid(input, n_batch, n_input, output);
}
void ApplySigmoidFloat(const int16_t* input, int32_t n_batch, int32_t n_input,
int16_t* output) {
PortableApplySigmoidFloat(input, n_batch, n_input, output);
}
void ApplyTanh(int32_t integer_bits, const int16_t* input, int32_t n_batch,
int32_t n_input, int16_t* output) {
PortableApplyTanh(integer_bits, input, n_batch, n_input, output);
}
void ApplyTanhFloat(const int16_t* input, int32_t n_batch, int32_t n_input,
int32_t integer_bits, int16_t* output) {
PortableApplyTanhFloat(input, n_batch, n_input, integer_bits, output);
}
void CwiseMul(const int16_t* input_1, const int16_t* input_2, int n_batch,
int n_input, int shift, int16_t* output) {
PortableCwiseMul(input_1, input_2, n_batch, n_input, shift, output);
@ -274,6 +319,19 @@ void MeanStddevNormalization(const float* input_vector, float* output_vector,
PortableMeanStddevNormalization(input_vector, output_vector, v_size, n_batch);
}
void TwoGateSaturationgAdd(const int8_t* input, int8_t input_zp,
const int8_t* recurrent, int8_t recurrent_zp,
int32_t input_effective_scale_a,
int32_t input_effective_scale_b,
int32_t recurrent_effective_scale_a,
int32_t recurrent_effective_scale_b, int32_t n_batch,
int32_t n_cell, int16_t* output) {
PortableTwoGateSaturationgAdd(
input, input_zp, recurrent, recurrent_zp, input_effective_scale_a,
input_effective_scale_b, recurrent_effective_scale_a,
recurrent_effective_scale_b, n_batch, n_cell, output);
}
} // namespace tensor_utils
} // namespace tflite

View File

@ -340,6 +340,74 @@ void PortableMatrixBatchVectorMultiplyAccumulate(
n_output, output_zp, output);
}
void PortableMatrixBatchVectorMultiply(const int8_t* input,
int32_t input_zeropoint,
const int8_t* input_to_gate_weights,
int32_t input_to_gate_effective_scale_a,
int32_t input_to_gate_effective_scale_b,
int32_t n_batch, int32_t n_input,
int32_t n_cell, int8_t* gate_output,
int8_t gate_output_zp) {
const int32_t int8_max = std::numeric_limits<int8>::max();
const int32_t int8_min = std::numeric_limits<int8>::min();
for (int batch = 0; batch < n_batch; ++batch) {
for (int row = 0; row < n_cell; ++row) {
int32_t acc = 0;
for (int col = 0; col < n_input; ++col) {
int32_t input_val = input[batch * n_input + col];
int8_t weights_val = input_to_gate_weights[row * n_input + col];
acc += (input_val - input_zeropoint) * weights_val;
}
acc = MultiplyByQuantizedMultiplier(acc, input_to_gate_effective_scale_a,
input_to_gate_effective_scale_b);
acc += gate_output_zp;
if (acc > int8_max) {
acc = int8_max;
}
if (acc < int8_min) {
acc = int8_min;
}
gate_output[batch * n_cell + row] = static_cast<int8_t>(acc);
}
}
}
void PortableMatrixBatchVectorMultiply(
const int16_t* hidden, const int8_t* hidden_to_output_weights,
int32_t proj_effective_scale_a, int32_t proj_effective_scale_b,
const int32_t* gate_bias, int32_t n_batch, int32_t n_hidden,
int32_t n_output, int32_t output_zp, int8_t* proj_output) {
const int16_t int8_max = std::numeric_limits<int8>::max();
const int16_t int8_min = std::numeric_limits<int8>::min();
for (int batch = 0; batch < n_batch; ++batch) {
for (int row = 0; row < n_output; ++row) {
int64_t acc = gate_bias[row];
for (int col = 0; col < n_hidden; ++col) {
int16_t input_val = hidden[batch * n_hidden + col];
int8_t weights_val = hidden_to_output_weights[row * n_hidden + col];
int64_t curr = acc;
acc += input_val * weights_val;
if (input_val * weights_val > 0 && acc < curr) {
acc = std::numeric_limits<int32>::max();
}
if (input_val * weights_val < 0 && acc > curr) {
acc = std::numeric_limits<int32>::min();
}
}
acc = MultiplyByQuantizedMultiplier(acc, proj_effective_scale_a,
proj_effective_scale_b);
acc += output_zp;
if (acc > int8_max) {
acc = int8_max;
}
if (acc < int8_min) {
acc = int8_min;
}
proj_output[batch * n_output + row] = acc;
}
}
}
void PortableApplyLayerNorm(const int16_t* input,
const int16_t* layer_norm_weights,
const int32_t* bias, int32_t layer_norm_scale_a,
@ -390,6 +458,52 @@ void PortableApplyLayerNorm(const int16_t* input,
}
}
void PortableApplyLayerNormFloat(const int16_t* input,
const int16_t* layer_norm_weights,
int32_t layer_norm_scale_a,
int32_t layer_norm_scale_b,
const int32_t* bias, int n_batch, int n_input,
int16_t* output) {
const int32_t int16_max = std::numeric_limits<int16>::max();
const int32_t int16_min = std::numeric_limits<int16>::min();
// This is to suppress a lint warning.
const double two = 2.0;
const float layer_norm_scale =
layer_norm_scale_a *
std::pow(two, static_cast<double>(layer_norm_scale_b - 31));
const float bias_scale = std::pow(two, -10) * layer_norm_scale;
for (int batch = 0; batch < n_batch; ++batch) {
float sum = 0.0f;
float sum_sq = 0.0f;
for (int i = 0; i < n_input; ++i) {
const int index = batch * n_input + i;
const float value = static_cast<float>(input[index]);
sum += value;
sum_sq += value * value;
}
const float mean = sum / n_input;
float stddev_inv = 0.0f;
const float variance = sum_sq / n_input - mean * mean;
if (variance == 0) {
stddev_inv = 1.0f / sqrt(1e-8);
} else {
stddev_inv = 1.0f / sqrt(variance);
}
for (int i = 0; i < n_input; ++i) {
const int index = batch * n_input + i;
const float normalized_value =
(static_cast<float>(input[index]) - mean) * stddev_inv;
const float weighted_normalized_value =
normalized_value * layer_norm_weights[i] * layer_norm_scale +
bias[i] * bias_scale;
const int32_t quant_output = static_cast<int32>(
std::round(weighted_normalized_value * std::pow(2, 12)));
output[index] = std::min(int16_max, std::max(int16_min, quant_output));
}
}
}
void PortableMatrixScalarMultiplyAccumulate(const int8_t* matrix,
int32_t scalar, int32_t n_row,
int32_t n_col, int32_t* output) {
@ -416,6 +530,24 @@ void PortableApplySigmoid(const int16_t* input, int32_t n_batch,
}
}
void PortableApplySigmoidFloat(const int16_t* input, int32_t n_batch,
int32_t n_input, int16_t* output) {
const int32_t int16_max = std::numeric_limits<int16>::max();
const int32_t int16_min = std::numeric_limits<int16>::min();
for (int batch = 0; batch < n_batch; ++batch) {
for (int i = 0; i < n_input; ++i) {
const int index = batch * n_input + i;
const float float_input = input[index] * std::pow(2, -12);
const float float_output = 1.0f / (1.0f + std::exp(-float_input));
const int32_t quant_output =
static_cast<int32>(float_output * std::pow(2, 15));
const int32_t quant_output_clamped =
std::min(int16_max, std::max(int16_min, quant_output));
output[index] = static_cast<int16>(quant_output_clamped);
}
}
}
template <int IntegerBits>
void PortableApplyTanhImpl(const int16_t* input, int32_t n_batch,
int32_t n_input, int16_t* output) {
@ -452,6 +584,27 @@ void PortableApplyTanh(int32_t integer_bits, const int16_t* input,
#undef DISPATCH_TANH
}
void PortableApplyTanhFloat(const int16_t* input, int32_t n_batch,
int32_t n_input, int32_t integer_bits,
int16_t* output) {
const int32_t int16_max = std::numeric_limits<int16>::max();
const int32_t int16_min = std::numeric_limits<int16>::min();
const double two = 2.0;
for (int batch = 0; batch < n_batch; ++batch) {
for (int i = 0; i < n_input; ++i) {
const int index = batch * n_input + i;
const float float_input =
input[index] * std::pow(two, static_cast<double>(integer_bits));
const float float_output = std::tanh(float_input);
const int32_t quant_output =
static_cast<int32>(float_output * std::pow(2, 15));
const int32_t quant_output_clamped =
std::min(int16_max, std::max(int16_min, quant_output));
output[index] = static_cast<int16>(quant_output_clamped);
}
}
}
void PortableCwiseMul(const int16_t* input_1, const int16_t* input_2,
int n_batch, int n_input, int shift, int16_t* output) {
for (int batch = 0; batch < n_batch; ++batch) {
@ -666,5 +819,34 @@ void PortableMeanStddevNormalization(const float* input_vector,
}
}
void PortableTwoGateSaturationgAdd(const int8_t* input, int8_t input_zp,
const int8_t* recurrent, int8_t recurrent_zp,
int32_t input_effective_scale_a,
int32_t input_effective_scale_b,
int32_t recurrent_effective_scale_a,
int32_t recurrent_effective_scale_b,
int32_t n_batch, int32_t n_cell,
int16_t* output) {
const int32_t int16_max = std::numeric_limits<int16>::max();
const int32_t int16_min = std::numeric_limits<int16>::min();
for (int i = 0; i < n_batch * n_cell; ++i) {
int32_t x = static_cast<int32>(input[i]) - static_cast<int32>(input_zp);
int32_t h =
static_cast<int32>(recurrent[i]) - static_cast<int32>(recurrent_zp);
int32_t x_scaled = MultiplyByQuantizedMultiplier(x, input_effective_scale_a,
input_effective_scale_b);
int32_t h_scaled = MultiplyByQuantizedMultiplier(
h, recurrent_effective_scale_a, recurrent_effective_scale_b);
int32_t y = h_scaled + x_scaled;
if (y > int16_max) {
y = int16_max;
}
if (y < int16_min) {
y = int16_min;
}
output[i] = static_cast<int16_t>(y);
}
}
} // namespace tensor_utils
} // namespace tflite

View File

@ -152,6 +152,31 @@ void MatrixScalarMultiplyAccumulate(const int8_t* matrix, int32_t scalar,
PortableMatrixScalarMultiplyAccumulate(matrix, scalar, n_row, n_col, output);
}
void MatrixBatchVectorMultiply(const int8_t* input, int32_t input_zeropoint,
const int8_t* input_to_gate_weights,
int32_t input_to_gate_effective_scale_a,
int32_t input_to_gate_effective_scale_b,
int32_t n_batch, int32_t n_input, int32_t n_cell,
int8_t* gate_output, int8_t gate_output_zp) {
PortableMatrixBatchVectorMultiply(
input, input_zeropoint, input_to_gate_weights,
input_to_gate_effective_scale_a, input_to_gate_effective_scale_b, n_batch,
n_input, n_cell, gate_output, gate_output_zp);
}
void MatrixBatchVectorMultiply(const int16_t* hidden,
const int8_t* hidden_to_output_weights,
int32_t proj_effective_scale_a,
int32_t proj_effective_scale_b,
const int32_t* gate_bias, int32_t n_batch,
int32_t n_hidden, int32_t n_output,
int32_t output_zp, int8_t* proj_output) {
PortableMatrixBatchVectorMultiply(hidden, hidden_to_output_weights,
proj_effective_scale_a,
proj_effective_scale_b, gate_bias, n_batch,
n_hidden, n_output, output_zp, proj_output);
}
void ApplyLayerNorm(const int16_t* input, const int16_t* layer_norm_weights,
const int32_t* bias, int32_t layer_norm_scale_a,
int32_t layer_norm_scale_b, int32_t variance_limit,
@ -161,16 +186,36 @@ void ApplyLayerNorm(const int16_t* input, const int16_t* layer_norm_weights,
output);
}
void ApplyLayerNormFloat(const int16_t* input,
const int16_t* layer_norm_weights,
int32_t layer_norm_scale_a, int32_t layer_norm_scale_b,
const int32_t* bias, int n_batch, int n_input,
int16_t* output) {
PortableApplyLayerNormFloat(input, layer_norm_weights, layer_norm_scale_a,
layer_norm_scale_b, bias, n_batch, n_input,
output);
}
void ApplySigmoid(const int16_t* input, int32_t n_batch, int32_t n_input,
int16_t* output) {
PortableApplySigmoid(input, n_batch, n_input, output);
}
void ApplySigmoidFloat(const int16_t* input, int32_t n_batch, int32_t n_input,
int16_t* output) {
PortableApplySigmoidFloat(input, n_batch, n_input, output);
}
void ApplyTanh(int32_t integer_bits, const int16_t* input, int32_t n_batch,
int32_t n_input, int16_t* output) {
PortableApplyTanh(integer_bits, input, n_batch, n_input, output);
}
void ApplyTanhFloat(const int16_t* input, int32_t n_batch, int32_t n_input,
int32_t integer_bits, int16_t* output) {
PortableApplyTanhFloat(input, n_batch, n_input, integer_bits, output);
}
void CwiseMul(const int16_t* input_1, const int16_t* input_2, int n_batch,
int n_input, int shift, int16_t* output) {
PortableCwiseMul(input_1, input_2, n_batch, n_input, shift, output);
@ -265,6 +310,19 @@ void MeanStddevNormalization(const float* input_vector, float* output_vector,
PortableMeanStddevNormalization(input_vector, output_vector, v_size, n_batch);
}
void TwoGateSaturationgAdd(const int8_t* input, int8_t input_zp,
const int8_t* recurrent, int8_t recurrent_zp,
int32_t input_effective_scale_a,
int32_t input_effective_scale_b,
int32_t recurrent_effective_scale_a,
int32_t recurrent_effective_scale_b, int32_t n_batch,
int32_t n_cell, int16_t* output) {
PortableTwoGateSaturationgAdd(
input, input_zp, recurrent, recurrent_zp, input_effective_scale_a,
input_effective_scale_b, recurrent_effective_scale_a,
recurrent_effective_scale_b, n_batch, n_cell, output);
}
} // namespace tensor_utils
} // namespace tflite

View File

@ -122,6 +122,21 @@ void PortableMatrixBatchVectorMultiplyAccumulate(
int32_t n_batch, int32_t n_input, int32_t n_output, int32_t output_zp,
int32_t* scratch, int8_t* output, CpuBackendContext* context);
void PortableMatrixBatchVectorMultiply(const int8_t* input,
int32_t input_zeropoint,
const int8_t* input_to_gate_weights,
int32_t input_to_gate_effective_scale_a,
int32_t input_to_gate_effective_scale_b,
int32_t n_batch, int32_t n_input,
int32_t n_cell, int8_t* gate_output,
int8_t gate_output_zp);
void PortableMatrixBatchVectorMultiply(
const int16_t* hidden, const int8_t* hidden_to_output_weights,
int32_t proj_effective_scale_a, int32_t proj_effective_scale_b,
const int32_t* gate_bias, int32_t n_batch, int32_t n_hidden,
int32_t n_output, int32_t output_zp, int8_t* proj_output);
void PortableMatrixScalarMultiplyAccumulate(const int8_t* matrix,
int32_t scalar, int32_t n_row,
int32_t n_col, int32_t* output);
@ -132,12 +147,26 @@ void PortableApplyLayerNorm(const int16_t* input,
int32_t layer_norm_scale_b, int32_t variance_limit,
int n_batch, int n_input, int16_t* output);
void PortableApplyLayerNormFloat(const int16_t* input,
const int16_t* layer_norm_weights,
int32_t layer_norm_scale_a,
int32_t layer_norm_scale_b,
const int32_t* bias, int n_batch, int n_input,
int16_t* output);
void PortableApplySigmoid(const int16_t* input, int32_t n_batch,
int32_t n_input, int16_t* output);
void PortableApplySigmoidFloat(const int16_t* input, int32_t n_batch,
int32_t n_input, int16_t* output);
void PortableApplyTanh(int32_t integer_bits, const int16_t* input,
int32_t n_batch, int32_t n_input, int16_t* output);
void PortableApplyTanhFloat(const int16_t* input, int32_t n_batch,
int32_t n_input, int32_t integer_bits,
int16_t* output);
void PortableCwiseMul(const int16_t* input_1, const int16_t* input_2,
int n_batch, int n_input, int shift, int16_t* output);
@ -197,6 +226,16 @@ void PortableMeanStddevNormalization(const float* input_vector,
float* output_vector, int v_size,
int n_batch);
// Saturate Add.
void PortableTwoGateSaturationgAdd(const int8_t* input, int8_t input_zp,
const int8_t* recurrent, int8_t recurrent_zp,
int32_t input_effective_scale_a,
int32_t input_effective_scale_b,
int32_t recurrent_effective_scale_a,
int32_t recurrent_effective_scale_b,
int32_t n_batch, int32_t n_cell,
int16_t* output);
} // namespace tensor_utils
} // namespace tflite

View File

@ -209,6 +209,27 @@ void MatrixBatchVectorMultiplyAccumulate(
int32_t n_batch, int32_t n_input, int32_t n_output, int32_t output_zp,
int32_t* scratch, int8_t* output, CpuBackendContext* context);
// Same as the 8, 8, 8 integer matmul above, except that it takes a zero
// point and is non-accumulative.
// TODO(b/148688698): remove this function by folding zero point calculation in
// prepare() function.
void MatrixBatchVectorMultiply(const int8_t* input, int32_t input_zeropoint,
const int8_t* input_to_gate_weights,
int32_t input_to_gate_effective_scale_a,
int32_t input_to_gate_effective_scale_b,
int32_t n_batch, int32_t n_input, int32_t n_cell,
int8_t* gate_output, int8_t gate_output_zp);
// Same as above but with a 16 bit hidden input, 8 bit weights and 8 bit output.
// Used in the projection when the hidden state is 16 bit.
void MatrixBatchVectorMultiply(const int16_t* hidden,
const int8_t* hidden_to_output_weights,
int32_t proj_effective_scale_a,
int32_t proj_effective_scale_b,
const int32_t* gate_bias, int32_t n_batch,
int32_t n_hidden, int32_t n_output,
int32_t output_zp, int8_t* proj_output);
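// Illustrative sketch, not part of this change: a minimal call to the
// 8x8->8 overload above. The (effective_scale_a, effective_scale_b) pair
// would normally come from QuantizeMultiplier() on the folded
// input/weight/output scale; the values here are hypothetical and chosen
// only to show how the parameters line up.
inline void ExampleGateMatmul8x8_8() {
  constexpr int kBatch = 1, kInput = 3, kCell = 2;
  const int8_t input[kBatch * kInput] = {10, -4, 7};
  const int8_t weights[kCell * kInput] = {1, 2, 3, -1, -2, -3};
  int8_t gate_output[kBatch * kCell] = {0, 0};
  // Hypothetical Q31 multiplier (~0.5) and shift.
  const int32_t effective_scale_a = 1073741824;
  const int32_t effective_scale_b = -2;
  MatrixBatchVectorMultiply(input, /*input_zeropoint=*/3, weights,
                            effective_scale_a, effective_scale_b,
                            /*n_batch=*/kBatch, /*n_input=*/kInput,
                            /*n_cell=*/kCell, gate_output,
                            /*gate_output_zp=*/0);
}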
// Multiplies a matrix with a scalar and reduce the result on each row to a
// scalar.
// Parameters:
@ -241,6 +262,13 @@ void ApplyLayerNorm(const int16_t* input, const int16_t* layer_norm_weights,
int32_t layer_norm_scale_b, int32_t variance_limit,
int n_batch, int n_input, int16_t* output);
// Same as above but the internal calculation is done in float.
void ApplyLayerNormFloat(const int16_t* input,
const int16_t* layer_norm_weights,
int32_t layer_norm_scale_a, int32_t layer_norm_scale_b,
const int32_t* bias, int n_batch, int n_input,
int16_t* output);
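// Illustrative sketch, not part of this change: the reference implementation
// recovers the float layer norm scale from the quantized pair as
// scale_a * 2^(scale_b - 31), with the bias carrying an extra 2^-10 factor.
// With the multiplier/shift used in the unit test (1895840000, -13) this is
// roughly 1.08e-4. Assumes <cmath> is available.
inline float ExampleLayerNormScaleFromQuantizedPair(int32_t layer_norm_scale_a,
                                                    int32_t layer_norm_scale_b) {
  return static_cast<float>(std::ldexp(
      static_cast<double>(layer_norm_scale_a), layer_norm_scale_b - 31));
}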
// Apply Sigmoid to a quantized vector.
// Parameters:
// - input: batch vector of size n_batch * n_input; 16 bit.
@ -251,6 +279,10 @@ void ApplyLayerNorm(const int16_t* input, const int16_t* layer_norm_weights,
void ApplySigmoid(const int16_t* input, int32_t n_batch, int32_t n_input,
int16_t* output);
// Same as above but the internal calculation is in float.
void ApplySigmoidFloat(const int16_t* input, int32_t n_batch, int32_t n_input,
int16_t* output);
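// Illustrative sketch, not part of this change: a single-element version of
// the float sigmoid path above, which dequantizes the Q3.12 input, applies
// the sigmoid in float and requantizes to Q0.15 with saturation. Assumes
// <cmath> and <algorithm> are available.
inline int16_t ExampleSigmoidFloatOneValue(int16_t q3_12_input) {
  const float x = q3_12_input * std::pow(2.0f, -12.0f);  // dequantize Q3.12
  const float y = 1.0f / (1.0f + std::exp(-x));          // sigmoid in float
  const int32_t q = static_cast<int32_t>(y * std::pow(2.0f, 15.0f));  // Q0.15
  return static_cast<int16_t>(
      std::min<int32_t>(32767, std::max<int32_t>(-32768, q)));
}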
// Apply Tanh to a quantized vector.
// Parameters:
// - integer_bits: the integer bits of the input.
@ -263,6 +295,12 @@ void ApplySigmoid(const int16_t* input, int32_t n_batch, int32_t n_input,
void ApplyTanh(int32_t integer_bits, const int16_t* input, int32_t n_batch,
int32_t n_input, int16_t* output);
// Apply Tanh to a quantized vector. The internal calculation is in float.
// - Input has 2^(integer_bits) as scale.
// - Output has Q0.15 as scale.
void ApplyTanhFloat(const int16_t* input, int32_t n_batch, int32_t n_input,
int32_t integer_bits, int16_t* output);
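// Illustrative sketch, not part of this change: a single-element version of
// the float tanh path above. The input is scaled by 2^(integer_bits) before
// the tanh (integer_bits = -12 in the unit test, i.e. a Q3.12 input) and the
// result is requantized to Q0.15. Assumes <cmath> and <algorithm>.
inline int16_t ExampleTanhFloatOneValue(int16_t input, int32_t integer_bits) {
  const float x =
      input * std::pow(2.0f, static_cast<float>(integer_bits));  // dequantize
  const float y = std::tanh(x);
  const int32_t q = static_cast<int32_t>(y * std::pow(2.0f, 15.0f));  // Q0.15
  return static_cast<int16_t>(
      std::min<int32_t>(32767, std::max<int32_t>(-32768, q)));
}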
// Element-wise multiplication of two quantized vectors.
// Parameters:
// - input_1: batch vector of size n_batch * n_input; 16 bit.
@ -553,6 +591,16 @@ void ReductionSumVector(const int8_t* input_vector, int32_t* output_vector,
// Layer norm for each batch.
void MeanStddevNormalization(const float* input_vector, float* output_vector,
int v_size, int n_batch);
// Saturate Add with rescale on both inputs.
void TwoGateSaturationgAdd(const int8_t* input, int8_t input_zp,
const int8_t* recurrent, int8_t recurrent_zp,
int32_t input_effective_scale_a,
int32_t input_effective_scale_b,
int32_t recurrent_effective_scale_a,
int32_t recurrent_effective_scale_b, int32_t n_batch,
int32_t n_cell, int16_t* output);
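// Illustrative sketch, not part of this change: a single-element version of
// the saturating add above. Each operand is shifted by its zero point and
// rescaled (the integer kernel uses MultiplyByQuantizedMultiplier; a plain
// float rescale is used here purely for illustration), then the sum is
// clamped to the int16 range. Assumes <algorithm> is available.
inline int16_t ExampleTwoGateSaturatingAddOneValue(
    int8_t input, int8_t input_zp, int8_t recurrent, int8_t recurrent_zp,
    float input_scale, float recurrent_scale) {
  const float x = (input - input_zp) * input_scale;
  const float h = (recurrent - recurrent_zp) * recurrent_scale;
  const int32_t y = static_cast<int32_t>(x + h);
  return static_cast<int16_t>(
      std::min<int32_t>(32767, std::max<int32_t>(-32768, y)));
}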
} // namespace tensor_utils
} // namespace tflite

View File

@ -520,6 +520,102 @@ TEST(uKernels, QuantMatrixBatchVectorMultiplyAccumulate8x8_8Test) {
EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
}
// Quantized matmul with 2 * 30 input and 9 * 30 matrix with zero point.
TEST(uKernels, QuantMatrixBatchVectorMultiply8x8_8WithZPTest) {
const int32_t input_zp = 3;
const std::vector<int8_t> input = {
4, -41, 5, -41, 22, 17, -30, 24, 13, -47, 18, 9, -11, -30, 16,
-47, 12, 36, -20, 27, -3, 0, -51, -31, 3, -8, -38, 43, 23, 12,
11, -23, -26, 23, 14, -9, -44, 22, 21, -30, 3, -47, -26, -21, -24,
-44, 34, -11, -23, -28, 26, -38, 19, 35, 9, 23, 6, -42, -25, 28,
};
const std::vector<int8_t> input_to_gate_weights = {
13, -7, -20, -22, 8, -46, 9, -2, -18, -42, 40, 28, -7, 24, 34,
-7, -24, -24, 19, 14, -19, -6, -2, -3, 5, -36, -13, 6, -27, 36,
-23, 0, 20, -37, -23, 9, 17, -41, 33, -15, -18, -42, -41, -34, -16,
-6, 12, -14, -15, -20, -14, 21, -3, -1, -26, 54, 51, 35, -14, 9,
-2, 13, -6, 39, 34, -21, 39, -51, 19, -44, 52, 0, -2, -38, -35,
-33, 4, -22, -37, 27, -23, 3, -10, 5, 32, 6, 1, -35, 24, -19,
46, 43, -55, 5, 38, -14, 32, -43, -44, -17, -13, -28, 56, 28, -42,
4, 10, -7, 25, -15, -9, -25, -14, -15, 6, -10, -22, 40, -72, 18,
-6, -18, -2, 37, -13, -10, 11, -9, 32, -28, 19, -2, 4, -31, 50,
-15, 23, -34, -9, 41, -6, -34, 17, 2, 24, -15, 21, -17, -8, -20,
1, -63, 19, -40, 12, -5, 5, -6, 1, 19, -9, -23, 5, -34, 11,
26, 21, 54, 34, -43, -29, 1, 16, 31, -56, -28, 57, -15, -23, 37,
-17, -3, -6, 29, 18, 77, 17, -20, -14, -19, 8, -24, -7, -45, -3,
0, -25, -8, 6, 9, 3, -15, 51, 4, -15, -19, -16, -14, -47, -52,
25, 9, 58, 26, -9, -27, 49, -6, -21, 21, 18, 12, -9, -9, 14,
31, -26, -19, -50, 17, 35, 11, -10, 22, -16, -43, -2, 26, 55, -20,
-7, 21, 33, -20, 26, -15, -22, 30, 27, 3, -34, 26, 12, -1, 19,
26, -25, 10, 30, 30, -14, -23, -23, -35, -16, 26, -41, 11, 1, 21,
};
const int32_t multiplier = 1347771520;
const int32_t shift = -7;
const int32_t output_zp = -11;
std::vector<int8_t> output = {1, 2, 3, 4, 5, 6, 5, 4, 3,
2, 1, 2, 8, -1, -2, 11, 17, 18};
MatrixBatchVectorMultiply(
input.data(), input_zp, input_to_gate_weights.data(), multiplier, shift,
/*n_batch=*/2, /*n_input=*/30, /*n_cell=*/9, output.data(), output_zp);
const std::vector<int8_t> expected_output = {6, -9, -4, -32, -10, -17,
-25, -25, 14, -19, 3, 10,
-12, 10, 0, 1, -57, -41};
EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
}
// Quantized matmul with 2 * 30 16-bit input and 9 * 30 matrix, with bias and
// output zero point.
TEST(uKernels, QuantMatrixBatchVectorMultiply16x8_8WithZPTest) {
const std::vector<int16_t> input = {
400, -41, 5, -41, 22, 17, -30, 24, 130, -47, 18, 9, -11, -30, 16,
-47, 12, 36, -20, 27, -3, 0, -51, -31, 3, -8, -38, 43, 23, 12,
11, -23, -26, 23, 14, -9, -44, 22, 21, -30, 3, -47, -26, -21, -24,
-44, 34, -11, -23, -28, 26, -38, 19, 35, 9, 23, 6, -42, -25, 28,
};
const std::vector<int8_t> input_to_gate_weights = {
13, -7, -20, -22, 8, -46, 9, -2, -18, -42, 40, 28, -7, 24, 34,
-7, -24, -24, 19, 14, -19, -6, -2, -3, 5, -36, -13, 6, -27, 36,
-23, 0, 20, -37, -23, 9, 17, -41, 33, -15, -18, -42, -41, -34, -16,
-6, 12, -14, -15, -20, -14, 21, -3, -1, -26, 54, 51, 35, -14, 9,
-2, 13, -6, 39, 34, -21, 39, -51, 19, -44, 52, 0, -2, -38, -35,
-33, 4, -22, -37, 27, -23, 3, -10, 5, 32, 6, 1, -35, 24, -19,
46, 43, -55, 5, 38, -14, 32, -43, -44, -17, -13, -28, 56, 28, -42,
4, 10, -7, 25, -15, -9, -25, -14, -15, 6, -10, -22, 40, -72, 18,
-6, -18, -2, 37, -13, -10, 11, -9, 32, -28, 19, -2, 4, -31, 50,
-15, 23, -34, -9, 41, -6, -34, 17, 2, 24, -15, 21, -17, -8, -20,
1, -63, 19, -40, 12, -5, 5, -6, 1, 19, -9, -23, 5, -34, 11,
26, 21, 54, 34, -43, -29, 1, 16, 31, -56, -28, 57, -15, -23, 37,
-17, -3, -6, 29, 18, 77, 17, -20, -14, -19, 8, -24, -7, -45, -3,
0, -25, -8, 6, 9, 3, -15, 51, 4, -15, -19, -16, -14, -47, -52,
25, 9, 58, 26, -9, -27, 49, -6, -21, 21, 18, 12, -9, -9, 14,
31, -26, -19, -50, 17, 35, 11, -10, 22, -16, -43, -2, 26, 55, -20,
-7, 21, 33, -20, 26, -15, -22, 30, 27, 3, -34, 26, 12, -1, 19,
26, -25, 10, 30, 30, -14, -23, -23, -35, -16, 26, -41, 11, 1, 21,
};
const std::vector<int32_t> input_zeropoint_times_weights = {
0, 2, 3, 4, 5, 4, 3, 2, 10,
};
const int32_t multiplier = 1347771520;
const int32_t shift = -8;
const int32_t output_zp = -11;
std::vector<int8_t> output = {1, 2, 3, 4, 5, 6, 5, 4, 3,
2, 1, 2, 8, -1, -2, 11, 17, 18};
MatrixBatchVectorMultiply(
input.data(), input_to_gate_weights.data(), multiplier, shift,
input_zeropoint_times_weights.data(),
/*n_batch=*/2, /*n_hidden=*/30, /*n_output=*/9, output_zp, output.data());
const std::vector<int8_t> expected_output = {4, -24, -5, 10, -7, -13,
-39, 2, 3, -16, -5, -1,
-12, -1, -6, -6, -33, -25};
EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
}
// Quantized matmul with 9 * 30 matrix.
TEST(uKernels, MatrixScalarMultiplyAccumulateTest) {
std::vector<int32_t> output = {
@ -585,6 +681,37 @@ TEST(uKernels, QuantApplyLayerNormTest) {
EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
}
// Quantized layer norm of n_batch = 2 and n_input = 15 with float internal
// calculation.
TEST(uKernels, QuantApplyLayerNormFloatTest) {
const std::vector<int16_t> input = {
-310, 596, 34, -68, 475, 92, 672, -54, -913, -200,
-1194, -836, -620, -237, 991, 533, 721, -736, -8, -941,
-372, -1084, 591, 2557, -779, 175, 582, 956, -287, 944,
};
const std::vector<int16_t> layer_norm_weights = {
21849, 22882, 20626, 23854, 24779, 26354, 12980, 26231,
23716, 27271, 24937, 22647, 24715, 22854, 19646,
};
const std::vector<int32_t> bias_weight = {
-14175520, -13805465, -16027609, -13786809, -13321033,
-14399810, -15055368, -14536623, -14508746, -13784007,
-15206609, -15125830, -14996304, -14847597, -12814379,
};
const int32_t multiplier = 1895840000;
const int32_t shift = -13;
std::vector<int16_t> output(2 * 15, 0);
ApplyLayerNormFloat(input.data(), layer_norm_weights.data(), multiplier,
shift, bias_weight.data(), 2, 15, output.data());
const std::vector<int16_t> expected_output = {
-9408, 5844, -4803, -5297, 4826, -2392, 927, -5286,
-20353, -7851, -26534, -18701, -15830, -8623, 10312, -2524,
-136, -16053, -8206, -19160, -13299, -14407, -1233, 20617,
-18594, -6736, -2272, 2597, -11620, 1566};
EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
}
// Quantized tanh with Q0.15 input and Q0.15 output.
TEST(uKernels, QuantTanh0Test) {
const std::vector<int16_t> input = {
@ -631,6 +758,29 @@ TEST(uKernels, QuantTanh3Test) {
EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
}
// Quantized tanh with float calculation.
TEST(uKernels, QuantTanhFloatTest) {
const std::vector<int16_t> input = {
-1, 0, 1, -35, 264, 289, 8, 27, -37, -1310,
-120, 127, -16, 106, 370, -583, -299, 93, -548, 548,
653, -29, -53, 1058, -52, -164, -149, -635, 201, -1297,
-145, 899, -176, -35, 264, 289, 8, 27, -37, -1310,
-120, 127, -16, 106, 370, -583, -299, 93, -548, 548,
653, -29, -53, 1058, -52, -164, -149, -635, 201, -1297,
};
std::vector<int16_t> output(4 * 15, 0);
ApplyTanhFloat(input.data(), 4, 15, -12, output.data());
const std::vector<int16_t> expected_output = {
-8, 0, 8, -279, 2109, 2308, 63, 215, -295, -10136,
-959, 1015, -127, 847, 2951, -4632, -2387, 743, -4358, 4358,
5180, -231, -423, 8280, -415, -1311, -1191, -5039, 1606, -10042,
-1159, 7078, -1407, -279, 2109, 2308, 63, 215, -295, -10136,
-959, 1015, -127, 847, 2951, -4632, -2387, 743, -4358, 4358,
5180, -231, -423, 8280, -415, -1311, -1191, -5039, 1606, -10042};
EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
}
// Quantized tanh with Q4.11 input and Q0.15 output.
TEST(uKernels, QuantTanh4Test) {
const std::vector<int16_t> input = {
@ -676,6 +826,30 @@ TEST(uKernels, QuantSigmoidTest) {
EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
}
// Quantized sigmoid with Q3.12 input and Q0.15 output, with float internal
// calculation.
TEST(uKernels, QuantSigmoidFloatTest) {
const std::vector<int16_t> input = {
-10500, 1398, -6963, -7404, 485, -5401, -1757, -7668, -19248,
-9692, -24249, -17923, -15840, -10026, 5249, -89, 1787, -16178,
-6691, -19524, -13439, -24048, -1123, 32767, -17267, -3378, 823,
11482, -11139, 7508, -10500, 1398, -6963, -7404, 485, -5401,
-1757, -7668, -19248, -9692, -24249, -17923, -15840, -10026, 5249,
-89, 1787, -16178, -6691, -19524, -13439, -24048, -1123, 32767,
-17267, -3378, 823, 11482, -11139, 7508,
};
std::vector<int16_t> output(4 * 15, 0);
ApplySigmoidFloat(input.data(), 4, 15, output.data());
const std::vector<int16_t> expected_output = {
2343, 19153, 5061, 4617, 17352, 6915, 12922, 4368, 295, 2811,
87, 407, 671, 2608, 25647, 16206, 19902, 619, 5352, 276,
1187, 92, 14151, 32757, 476, 9986, 18024, 30895, 2026, 28249,
2343, 19153, 5061, 4617, 17352, 6915, 12922, 4368, 295, 2811,
87, 407, 671, 2608, 25647, 16206, 19902, 619, 5352, 276,
1187, 92, 14151, 32757, 476, 9986, 18024, 30895, 2026, 28249};
EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
}
// Quantized Multiply with 16bit output and 15 bit shift.
TEST(uKernels, QuantMul16bitOut15ShiftTest) {
const std::vector<int16_t> input1 = {
@ -1745,6 +1919,33 @@ TEST(uKernels, ReductionSumVectorIntegerTest) {
EXPECT_THAT(result1, testing::ElementsAreArray({3, 6, -1, 3, 15}));
}
void TwoGateSaturationgAdd(const int8_t* input, int8_t input_zp,
const int8_t* recurrent, int8_t recurrent_zp,
int32_t input_effective_scale_a,
int32_t input_effective_scale_b,
int32_t recurrent_effective_scale_a,
int32_t recurrent_effective_scale_b, int32_t n_batch,
int32_t n_cell, int16_t* output);
TEST(uKernels, TwoGateSaturateAddTest) {
const std::vector<int8_t> input1 = {1, 2, 3, 4, 55, 66, 77};
const std::vector<int8_t> input2 = {100, 2, 3, 4, 55, 66, 77};
const int32_t input1_zp = 10;
const int32_t input2_zp = -5;
const int32_t multiplier1 = 1347771520;
const int32_t shift1 = -7;
const int32_t multiplier2 = 1047577121;
const int32_t shift2 = -6;
std::vector<int16_t> output(7);
TwoGateSaturationgAdd(input1.data(), input1_zp, input2.data(), input2_zp,
multiplier1, shift1, multiplier2, shift2, 1, 7,
output.data());
const std::vector<int16_t> expected_output = {1, 0, 0, 0, 0, 1, 1};
EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
}
namespace {
// Parameterized test: mean, difference, tolerance.
// Input is constructed as [mean-2*diff, mean-diff, mean+diff, mean+2*diff]

View File

@ -1,4 +1,4 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -59,7 +59,8 @@ struct OpData {
namespace full {
namespace {
TfLiteStatus PopulateQuantizedLstmParams(
TfLiteStatus PopulateQuantizedLstmParams8x8_16(
TfLiteContext* context, TfLiteNode* node,
lstm_eval::IntegerLstmParameter* integer_lstm_param) {
// Calculate quantized clip for projection and cell.
@ -366,6 +367,361 @@ TfLiteStatus PopulateQuantizedLstmParams(
return kTfLiteOk;
}
TfLiteStatus PopulateQuantizedLstmParams8x8_8(
TfLiteContext* context, TfLiteNode* node,
lstm_eval::IntegerLstmParameter* integer_lstm_param) {
// Get all tensors.
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input_to_input_weights =
GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
const TfLiteTensor* input_to_forget_weights =
GetInput(context, node, kInputToForgetWeightsTensor);
const TfLiteTensor* input_to_cell_weights =
GetInput(context, node, kInputToCellWeightsTensor);
const TfLiteTensor* input_to_output_weights =
GetInput(context, node, kInputToOutputWeightsTensor);
const TfLiteTensor* recurrent_to_input_weights =
GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
const TfLiteTensor* recurrent_to_forget_weights =
GetInput(context, node, kRecurrentToForgetWeightsTensor);
const TfLiteTensor* recurrent_to_cell_weights =
GetInput(context, node, kRecurrentToCellWeightsTensor);
const TfLiteTensor* recurrent_to_output_weights =
GetInput(context, node, kRecurrentToOutputWeightsTensor);
const TfLiteTensor* cell_to_input_weights =
GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
const TfLiteTensor* cell_to_forget_weights =
GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
const TfLiteTensor* cell_to_output_weights =
GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);
const TfLiteTensor* input_layer_norm_coefficients =
GetOptionalInputTensor(context, node, kInputLayerNormCoefficientsTensor);
const TfLiteTensor* forget_layer_norm_coefficients =
GetOptionalInputTensor(context, node, kForgetLayerNormCoefficientsTensor);
const TfLiteTensor* cell_layer_norm_coefficients =
GetOptionalInputTensor(context, node, kCellLayerNormCoefficientsTensor);
const TfLiteTensor* output_layer_norm_coefficients =
GetOptionalInputTensor(context, node, kOutputLayerNormCoefficientsTensor);
const TfLiteTensor* input_gate_bias =
GetOptionalInputTensor(context, node, kInputGateBiasTensor);
const TfLiteTensor* forget_gate_bias =
GetInput(context, node, kForgetGateBiasTensor);
const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
const TfLiteTensor* output_gate_bias =
GetInput(context, node, kOutputGateBiasTensor);
const TfLiteTensor* projection_weights =
GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
const TfLiteTensor* projection_bias =
GetOptionalInputTensor(context, node, kProjectionBiasTensor);
TfLiteTensor* activation_state =
GetVariableInput(context, node, kInputActivationStateTensor);
TF_LITE_ENSURE(context, activation_state != nullptr);
TfLiteTensor* cell_state =
GetVariableInput(context, node, kInputCellStateTensor);
TF_LITE_ENSURE(context, cell_state != nullptr);
// Since we have already checked that weights are all there or none, we can
// check the existence of only one to get the condition.
const bool use_cifg = (input_to_input_weights == nullptr);
const bool use_peephole = (cell_to_output_weights != nullptr);
const bool is_layer_norm_lstm = (forget_layer_norm_coefficients != nullptr);
const bool use_projection = (projection_weights != nullptr);
// Weights and states.
int8_t* input_to_input_weight_ptr = nullptr;
int8_t* recurrent_to_input_weight_ptr = nullptr;
int8_t* cell_to_input_weight_ptr = nullptr;
int8_t* input_to_forget_weight_ptr = nullptr;
int8_t* recurrent_to_forget_weight_ptr = nullptr;
int8_t* cell_to_forget_weight_ptr = nullptr;
int8_t* input_to_cell_weight_ptr = nullptr;
int8_t* recurrent_to_cell_weight_ptr = nullptr;
int8_t* input_to_output_weight_ptr = nullptr;
int8_t* recurrent_to_output_weight_ptr = nullptr;
int8_t* cell_to_output_weight_ptr = nullptr;
int8_t* proj_weight_ptr = nullptr;
int16_t* layer_norm_input_weight_ptr = nullptr;
int16_t* layer_norm_forget_weight_ptr = nullptr;
int16_t* layer_norm_cell_weight_ptr = nullptr;
int16_t* layer_norm_output_weight_ptr = nullptr;
int32_t* input_bias_ptr = nullptr;
int32_t* forget_bias_ptr = nullptr;
int32_t* cell_bias_ptr = nullptr;
int32_t* output_bias_ptr = nullptr;
int32_t* proj_bias_ptr = nullptr;
int16_t* cell_ptr = nullptr;
int8_t* activation_ptr = nullptr;
// Scales.
const float default_scale = 1.0;
float input_scale = default_scale;
float input_to_input_weight_scale = default_scale;
float recurrent_to_input_weight_scale = default_scale;
float cell_to_input_weight_scale = default_scale;
float input_to_forget_weight_scale = default_scale;
float recurrent_to_forget_weight_scale = default_scale;
float cell_to_forget_weight_scale = default_scale;
float input_to_cell_weight_scale = default_scale;
float recurrent_to_cell_weight_scale = default_scale;
float input_to_output_weight_scale = default_scale;
float recurrent_to_output_weight_scale = default_scale;
float cell_to_output_weight_scale = default_scale;
float proj_weight_scale = default_scale;
float layer_norm_input_scale = default_scale;
float layer_norm_forget_scale = default_scale;
float layer_norm_cell_scale = default_scale;
float layer_norm_output_scale = default_scale;
float activation_scale = default_scale;
// Effective scales.
float effective_input_to_input_scale = default_scale;
float effective_recurrent_to_input_scale = default_scale;
float effective_cell_to_input_scale = default_scale;
float effective_input_to_forget_scale = default_scale;
float effective_recurrent_to_forget_scale = default_scale;
float effective_cell_to_forget_scale = default_scale;
float effective_input_to_cell_scale = default_scale;
float effective_recurrent_to_cell_scale = default_scale;
float effective_input_to_output_scale = default_scale;
float effective_recurrent_to_output_scale = default_scale;
float effective_cell_to_output_scale = default_scale;
float effective_proj_scale = default_scale;
// Zero points
int input_zp = 0;
int activation_zp = 0;
// Populate all the values.
if (!use_cifg) {
input_to_input_weight_ptr = input_to_input_weights->data.int8;
recurrent_to_input_weight_ptr = recurrent_to_input_weights->data.int8;
input_bias_ptr = input_gate_bias->data.i32;
input_to_input_weight_scale = input_to_input_weights->params.scale;
recurrent_to_input_weight_scale = recurrent_to_input_weights->params.scale;
}
if (use_peephole) {
if (!use_cifg) {
cell_to_input_weight_ptr = cell_to_input_weights->data.int8;
cell_to_input_weight_scale = cell_to_input_weights->params.scale;
}
cell_to_forget_weight_ptr = cell_to_forget_weights->data.int8;
cell_to_output_weight_ptr = cell_to_output_weights->data.int8;
cell_to_forget_weight_scale = cell_to_forget_weights->params.scale;
cell_to_output_weight_scale = cell_to_output_weights->params.scale;
}
if (is_layer_norm_lstm) {
if (!use_cifg) {
layer_norm_input_weight_ptr = input_layer_norm_coefficients->data.i16;
layer_norm_input_scale = input_layer_norm_coefficients->params.scale;
}
layer_norm_forget_weight_ptr = forget_layer_norm_coefficients->data.i16;
layer_norm_forget_scale = forget_layer_norm_coefficients->params.scale;
layer_norm_cell_weight_ptr = cell_layer_norm_coefficients->data.i16;
layer_norm_cell_scale = cell_layer_norm_coefficients->params.scale;
layer_norm_output_weight_ptr = output_layer_norm_coefficients->data.i16;
layer_norm_output_scale = output_layer_norm_coefficients->params.scale;
}
if (use_projection) {
proj_weight_ptr = projection_weights->data.int8;
proj_weight_scale = projection_weights->params.scale;
if (projection_bias) {
proj_bias_ptr = projection_bias->data.i32;
}
}
activation_scale = activation_state->params.scale;
input_to_forget_weight_ptr = input_to_forget_weights->data.int8;
input_to_forget_weight_scale = input_to_forget_weights->params.scale;
input_to_cell_weight_ptr = input_to_cell_weights->data.int8;
input_to_cell_weight_scale = input_to_cell_weights->params.scale;
input_to_output_weight_ptr = input_to_output_weights->data.int8;
input_to_output_weight_scale = input_to_output_weights->params.scale;
recurrent_to_forget_weight_ptr = recurrent_to_forget_weights->data.int8;
recurrent_to_forget_weight_scale = recurrent_to_forget_weights->params.scale;
recurrent_to_cell_weight_ptr = recurrent_to_cell_weights->data.int8;
recurrent_to_cell_weight_scale = recurrent_to_cell_weights->params.scale;
recurrent_to_output_weight_ptr = recurrent_to_output_weights->data.int8;
recurrent_to_output_weight_scale = recurrent_to_output_weights->params.scale;
forget_bias_ptr = forget_gate_bias->data.i32;
cell_bias_ptr = cell_bias->data.i32;
output_bias_ptr = output_gate_bias->data.i32;
activation_ptr = activation_state->data.int8;
cell_ptr = cell_state->data.i16;
input_scale = input->params.scale;
input_zp = input->params.zero_point;
activation_zp = activation_state->params.zero_point;
std::vector<float> intermediate_scale;
for (int i = 0; i < 12; ++i) {
TfLiteTensor* intermediate =
&context->tensors[node->intermediates->data[i]];
auto* params = reinterpret_cast<TfLiteAffineQuantization*>(
intermediate->quantization.params);
intermediate_scale.push_back(params->scale->data[0]);
integer_lstm_param->intermediate_zp[i] = params->zero_point->data[0];
}
// Calculate effective scales.
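// Each effective scale below is (input scale * weight scale) / output scale,
// so that applying it to the raw int32 accumulator expresses the product in
// the quantization domain of the corresponding intermediate tensor. For
// example, with an input scale of 0.5, a weight scale of 0.25 and an
// intermediate scale of 0.125, the effective scale is 0.5 * 0.25 / 0.125 = 1.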
if (!use_cifg) {
effective_input_to_input_scale =
input_to_input_weight_scale * input_scale / intermediate_scale[1];
effective_recurrent_to_input_scale = recurrent_to_input_weight_scale *
activation_scale /
intermediate_scale[2];
}
effective_input_to_forget_scale =
input_to_forget_weight_scale * input_scale / intermediate_scale[4];
effective_recurrent_to_forget_scale = recurrent_to_forget_weight_scale *
activation_scale /
intermediate_scale[5];
effective_input_to_cell_scale =
input_to_cell_weight_scale * input_scale / intermediate_scale[7];
effective_recurrent_to_cell_scale =
recurrent_to_cell_weight_scale * activation_scale / intermediate_scale[8];
effective_input_to_output_scale =
input_to_output_weight_scale * input_scale / intermediate_scale[10];
effective_recurrent_to_output_scale = recurrent_to_output_weight_scale *
activation_scale /
intermediate_scale[11];
effective_proj_scale =
proj_weight_scale * std::pow(2, -15) / activation_scale;
if (use_peephole) {
if (!use_cifg) {
effective_cell_to_input_scale =
std::pow(2, -15) * cell_to_input_weight_scale / intermediate_scale[0];
}
effective_cell_to_forget_scale =
std::pow(2, -15) * cell_to_forget_weight_scale / intermediate_scale[3];
effective_cell_to_output_scale =
std::pow(2, -15) * cell_to_output_weight_scale / intermediate_scale[9];
}
// Quantize the effective scales.
QuantizeMultiplier(effective_input_to_input_scale,
&integer_lstm_param->effective_input_to_input_scale_a,
&integer_lstm_param->effective_input_to_input_scale_b);
QuantizeMultiplier(effective_recurrent_to_input_scale,
&integer_lstm_param->effective_recurrent_to_input_scale_a,
&integer_lstm_param->effective_recurrent_to_input_scale_b);
QuantizeMultiplier(effective_cell_to_input_scale,
&integer_lstm_param->effective_cell_to_input_scale_a,
&integer_lstm_param->effective_cell_to_input_scale_b);
QuantizeMultiplier(effective_input_to_forget_scale,
&integer_lstm_param->effective_input_to_forget_scale_a,
&integer_lstm_param->effective_input_to_forget_scale_b);
QuantizeMultiplier(
effective_recurrent_to_forget_scale,
&integer_lstm_param->effective_recurrent_to_forget_scale_a,
&integer_lstm_param->effective_recurrent_to_forget_scale_b);
QuantizeMultiplier(effective_cell_to_forget_scale,
&integer_lstm_param->effective_cell_to_forget_scale_a,
&integer_lstm_param->effective_cell_to_forget_scale_b);
QuantizeMultiplier(effective_input_to_cell_scale,
&integer_lstm_param->effective_input_to_cell_scale_a,
&integer_lstm_param->effective_input_to_cell_scale_b);
QuantizeMultiplier(effective_recurrent_to_cell_scale,
&integer_lstm_param->effective_recurrent_to_cell_scale_a,
&integer_lstm_param->effective_recurrent_to_cell_scale_b);
QuantizeMultiplier(effective_input_to_output_scale,
&integer_lstm_param->effective_input_to_output_scale_a,
&integer_lstm_param->effective_input_to_output_scale_b);
QuantizeMultiplier(
effective_recurrent_to_output_scale,
&integer_lstm_param->effective_recurrent_to_output_scale_a,
&integer_lstm_param->effective_recurrent_to_output_scale_b);
QuantizeMultiplier(effective_cell_to_output_scale,
&integer_lstm_param->effective_cell_to_output_scale_a,
&integer_lstm_param->effective_cell_to_output_scale_b);
QuantizeMultiplier(effective_proj_scale,
&integer_lstm_param->effective_proj_scale_a,
&integer_lstm_param->effective_proj_scale_b);
QuantizeMultiplier(layer_norm_input_scale,
&integer_lstm_param->layer_norm_input_scale_a,
&integer_lstm_param->layer_norm_input_scale_b);
QuantizeMultiplier(layer_norm_forget_scale,
&integer_lstm_param->layer_norm_forget_scale_a,
&integer_lstm_param->layer_norm_forget_scale_b);
QuantizeMultiplier(layer_norm_cell_scale,
&integer_lstm_param->layer_norm_cell_scale_a,
&integer_lstm_param->layer_norm_cell_scale_b);
QuantizeMultiplier(layer_norm_output_scale,
&integer_lstm_param->layer_norm_output_scale_a,
&integer_lstm_param->layer_norm_output_scale_b);
{
// Intermediates in the flatbuffer hold Wx, Wh and Wx+Wh.
// The effective Wx and Wh scales are already folded into the
// effective_input/recurrent_to_<...>_scale values above, so use
// intermediate_scale to hold the scale from Wx and Wh to Wx+Wh:
// 0: [1] -> [0]
// 1: [2] -> [0]
// and use intermediate_zp as is.
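// For example, s_1_0 = intermediate_scale[1] / intermediate_scale[0] rescales
// the input gate's Wx product (quantized with scale [1]) into the
// quantization domain of its Wx+Wh sum (scale [0]) before the two terms are
// added; the remaining ratios do the same for the other gates.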
const float s_1_0 = intermediate_scale[1] / intermediate_scale[0];
const float s_2_0 = intermediate_scale[2] / intermediate_scale[0];
const float s_4_3 = intermediate_scale[4] / intermediate_scale[3];
const float s_5_3 = intermediate_scale[5] / intermediate_scale[3];
const float s_7_6 = intermediate_scale[7] / intermediate_scale[6];
const float s_8_6 = intermediate_scale[8] / intermediate_scale[6];
const float s_10_9 = intermediate_scale[10] / intermediate_scale[9];
const float s_11_9 = intermediate_scale[11] / intermediate_scale[9];
QuantizeMultiplier(s_1_0, &integer_lstm_param->intermediate_scale_a[0],
&integer_lstm_param->intermediate_scale_b[0]);
QuantizeMultiplier(s_2_0, &integer_lstm_param->intermediate_scale_a[1],
&integer_lstm_param->intermediate_scale_b[1]);
QuantizeMultiplier(s_4_3, &integer_lstm_param->intermediate_scale_a[2],
&integer_lstm_param->intermediate_scale_b[2]);
QuantizeMultiplier(s_5_3, &integer_lstm_param->intermediate_scale_a[3],
&integer_lstm_param->intermediate_scale_b[3]);
QuantizeMultiplier(s_7_6, &integer_lstm_param->intermediate_scale_a[4],
&integer_lstm_param->intermediate_scale_b[4]);
QuantizeMultiplier(s_8_6, &integer_lstm_param->intermediate_scale_a[5],
&integer_lstm_param->intermediate_scale_b[5]);
QuantizeMultiplier(s_10_9, &integer_lstm_param->intermediate_scale_a[6],
&integer_lstm_param->intermediate_scale_b[6]);
QuantizeMultiplier(s_11_9, &integer_lstm_param->intermediate_scale_a[7],
&integer_lstm_param->intermediate_scale_b[7]);
}
// Calculate quantized clip for projection and cell.
const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
const float cell_clip = params->cell_clip;
const float proj_clip = params->proj_clip;
const TfLiteTensor* cell_tensor =
GetInput(context, node, kInputCellStateTensor);
const TfLiteTensor* output_tensor = GetOutput(context, node, kOutputTensor);
auto* cell_params = reinterpret_cast<TfLiteAffineQuantization*>(
cell_tensor->quantization.params);
auto* proj_params = reinterpret_cast<TfLiteAffineQuantization*>(
output_tensor->quantization.params);
TF_LITE_ENSURE_EQ(context, cell_params->scale->data[0], 1.0 / 32768);
if (cell_clip > 0.0 && cell_clip < 1.0) {
integer_lstm_param->quantized_cell_clip =
static_cast<int>(cell_clip / cell_params->scale->data[0]);
} else {
integer_lstm_param->quantized_cell_clip = 0;
}
if (proj_clip > 0.0) {
integer_lstm_param->quantized_proj_clip =
proj_clip / proj_params->scale->data[0];
} else {
integer_lstm_param->quantized_proj_clip = 0;
}
return kTfLiteOk;
}
} // namespace
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
@ -868,11 +1224,25 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// The weights are of consistent type, so it suffices to check one.
const bool is_hybrid_op = IsHybridOp(input, input_to_output_weights);
// The type of Integer LSTM.
const int num_intermediate_tensors = node->intermediates->size;
if (is_integer) {
TF_LITE_ENSURE(context, num_intermediate_tensors == 5 ||
num_intermediate_tensors == 12);
}
// We use the number of intermediate tensors to distinguish between the 8 bit
// matmul output and the 16 bit matmul output versions.
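// Five intermediate tensors correspond to the 8x8->16 kernel below; twelve
// correspond to the 8x8->8 kernel added in this change.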
const bool is_8x8_16 = num_intermediate_tensors == 5;
TfLiteIntArrayFree(node->temporaries);
if (is_hybrid_op) {
node->temporaries = TfLiteIntArrayCreate(8);
} else if (is_integer) {
node->temporaries = TfLiteIntArrayCreate(6);
if (is_8x8_16) {
node->temporaries = TfLiteIntArrayCreate(6);
} else {
node->temporaries = TfLiteIntArrayCreate(8);
}
} else {
node->temporaries = TfLiteIntArrayCreate(1);
}
@ -1003,42 +1373,78 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
if (is_integer) {
// Populate quantization parameters.
PopulateQuantizedLstmParams(context, node, &op_data->integer_lstm_param);
if (is_8x8_16) {
// Integer LSTM prepare function for 8x8->16.
// This code path needs 5 intermediate tensors per Op.
// Populate quantization parameters.
PopulateQuantizedLstmParams8x8_16(context, node,
&op_data->integer_lstm_param);
// Allocate scratch buffer. Need 6 16-bit buffers of size n_batch * n_cell
// and 1 8-bit buffer of size n_batch * n_cell. We also need 1 32-bit
// buffer of size n_batch * n_cell.
//
// TODO(jianlijianli): Handle cifg case as well, which might save one
// buffer.
for (int scratch_index = 0; scratch_index < 6; ++scratch_index) {
node->temporaries->data[scratch_index] =
op_data->scratch_tensor_index + scratch_index;
TfLiteTensor* scratch_tensor =
GetTemporary(context, node, /*index=*/scratch_index);
scratch_tensor->type = kTfLiteInt16;
if (scratch_index == 4) {
scratch_tensor->type = kTfLiteInt8;
} else if (scratch_index == 5) {
scratch_tensor->type = kTfLiteInt32;
// Allocate scratch buffer. Need 6 16-bit buffers of size n_batch * n_cell
// and 1 8-bit buffer of size n_batch * n_cell. We also need 1 32-bit
// buffer of size n_batch * n_cell.
//
// Handle cifg case as well, which might save one buffer.
for (int scratch_index = 0; scratch_index < 6; ++scratch_index) {
node->temporaries->data[scratch_index] =
op_data->scratch_tensor_index + scratch_index;
TfLiteTensor* scratch_tensor =
GetTemporary(context, node, /*index=*/scratch_index);
scratch_tensor->type = kTfLiteInt16;
if (scratch_index == 4) {
scratch_tensor->type = kTfLiteInt8;
} else if (scratch_index == 5) {
scratch_tensor->type = kTfLiteInt32;
}
scratch_tensor->allocation_type = kTfLiteArenaRw;
const int scratch_dimension[2] = {n_batch, n_cell};
if (!TfLiteIntArrayEqualsArray(scratch_tensor->dims, 2,
scratch_dimension)) {
TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
scratch_buffer_size->data[0] = n_batch;
scratch_buffer_size->data[1] = n_cell;
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(context, scratch_tensor,
scratch_buffer_size));
}
}
scratch_tensor->allocation_type = kTfLiteArenaRw;
const int scratch_dimension[2] = {n_batch, n_cell};
if (!TfLiteIntArrayEqualsArray(scratch_tensor->dims, 2,
scratch_dimension)) {
TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
scratch_buffer_size->data[0] = n_batch;
scratch_buffer_size->data[1] = n_cell;
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(context, scratch_tensor,
scratch_buffer_size));
// Populate precomputed zp * weight.
TF_LITE_ENSURE_OK(context, PopulatePrecomputedZPTimesWeightsWithBias(
context, op_data, node));
} else {
// Integer LSTM prepare function for 8x8->8.
// This code path needs 12 intermediate tensors per Op.
PopulateQuantizedLstmParams8x8_8(context, node,
&op_data->integer_lstm_param);
// Allocate scratch buffer. Need 6 16-bit buffers of size n_batch * n_cell
// and 2 8-bit buffers of size n_batch * n_cell.
//
// Handle cifg case as well, which might save one buffer.
for (int scratch_index = 0; scratch_index < 8; ++scratch_index) {
node->temporaries->data[scratch_index] =
op_data->scratch_tensor_index + scratch_index;
TfLiteTensor* scratch_tensor =
GetTemporary(context, node, /*index=*/scratch_index);
if (scratch_index == 0 || scratch_index == 1) {
scratch_tensor->type = kTfLiteInt8;
} else {
scratch_tensor->type = kTfLiteInt16;
}
scratch_tensor->allocation_type = kTfLiteArenaRw;
const int scratch_dimension[2] = {n_batch, n_cell};
if (!TfLiteIntArrayEqualsArray(scratch_tensor->dims, 2,
scratch_dimension)) {
TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
scratch_buffer_size->data[0] = n_batch;
scratch_buffer_size->data[1] = n_cell;
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(context, scratch_tensor,
scratch_buffer_size));
}
}
}
// Populate precomputed zp * weight.
TF_LITE_ENSURE_OK(context, PopulatePrecomputedZPTimesWeightsWithBias(
context, op_data, node));
}
return kTfLiteOk;
}
@ -1174,26 +1580,51 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
output_scratch_buffer, output,
CpuBackendContext::GetFromContext(context));
} else {
TfLiteTensor* scratch0 = GetTemporary(context, node, /*index=*/0);
TfLiteTensor* scratch1 = GetTemporary(context, node, /*index=*/1);
TfLiteTensor* scratch2 = GetTemporary(context, node, /*index=*/2);
TfLiteTensor* scratch3 = GetTemporary(context, node, /*index=*/3);
TfLiteTensor* scratch4 = GetTemporary(context, node, /*index=*/4);
TfLiteTensor* scratch5 = GetTemporary(context, node, /*index=*/5);
return lstm_eval::EvalInteger8x8_16(
input, input_to_input_weights, input_to_forget_weights,
input_to_cell_weights, input_to_output_weights,
recurrent_to_input_weights, recurrent_to_forget_weights,
recurrent_to_cell_weights, recurrent_to_output_weights,
cell_to_input_weights, cell_to_forget_weights,
cell_to_output_weights, input_layer_norm_coefficients,
forget_layer_norm_coefficients, cell_layer_norm_coefficients,
output_layer_norm_coefficients, input_gate_bias, forget_gate_bias,
cell_bias, output_gate_bias, projection_weights, projection_bias,
params, &op_data->integer_lstm_param, activation_state, cell_state,
output, scratch0, scratch1, scratch2, scratch3, scratch4, scratch5,
CpuBackendContext::GetFromContext(context));
return kTfLiteOk;
const int num_intermediate_tensors = node->intermediates->size;
if (num_intermediate_tensors == 5) {
TfLiteTensor* scratch0 = GetTemporary(context, node, /*index=*/0);
TfLiteTensor* scratch1 = GetTemporary(context, node, /*index=*/1);
TfLiteTensor* scratch2 = GetTemporary(context, node, /*index=*/2);
TfLiteTensor* scratch3 = GetTemporary(context, node, /*index=*/3);
TfLiteTensor* scratch4 = GetTemporary(context, node, /*index=*/4);
TfLiteTensor* scratch5 = GetTemporary(context, node, /*index=*/5);
return lstm_eval::EvalInteger8x8_16(
input, input_to_input_weights, input_to_forget_weights,
input_to_cell_weights, input_to_output_weights,
recurrent_to_input_weights, recurrent_to_forget_weights,
recurrent_to_cell_weights, recurrent_to_output_weights,
cell_to_input_weights, cell_to_forget_weights,
cell_to_output_weights, input_layer_norm_coefficients,
forget_layer_norm_coefficients, cell_layer_norm_coefficients,
output_layer_norm_coefficients, input_gate_bias, forget_gate_bias,
cell_bias, output_gate_bias, projection_weights, projection_bias,
params, &op_data->integer_lstm_param, activation_state,
cell_state, output, scratch0, scratch1, scratch2, scratch3,
scratch4, scratch5, CpuBackendContext::GetFromContext(context));
} else {
TfLiteTensor* scratch0 = GetTemporary(context, node, /*index=*/0);
TfLiteTensor* scratch1 = GetTemporary(context, node, /*index=*/1);
TfLiteTensor* scratch2 = GetTemporary(context, node, /*index=*/2);
TfLiteTensor* scratch3 = GetTemporary(context, node, /*index=*/3);
TfLiteTensor* scratch4 = GetTemporary(context, node, /*index=*/4);
TfLiteTensor* scratch5 = GetTemporary(context, node, /*index=*/5);
TfLiteTensor* scratch6 = GetTemporary(context, node, /*index=*/6);
TfLiteTensor* scratch7 = GetTemporary(context, node, /*index=*/7);
return lstm_eval::EvalInteger8x8_8(
input, input_to_input_weights, input_to_forget_weights,
input_to_cell_weights, input_to_output_weights,
recurrent_to_input_weights, recurrent_to_forget_weights,
recurrent_to_cell_weights, recurrent_to_output_weights,
cell_to_input_weights, cell_to_forget_weights,
cell_to_output_weights, input_layer_norm_coefficients,
forget_layer_norm_coefficients, cell_layer_norm_coefficients,
output_layer_norm_coefficients, input_gate_bias, forget_gate_bias,
cell_bias, output_gate_bias, projection_weights, projection_bias,
params, activation_state, cell_state, output,
&op_data->integer_lstm_param, scratch0, scratch1, scratch2,
scratch3, scratch4, scratch5, scratch6, scratch7);
return kTfLiteOk;
}
}
}
default:


@ -1,4 +1,4 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -816,7 +816,7 @@ inline void LstmStepHybrid(
}
}
// Fully quantized lstm kernel. Currently supports both cifg and non-cifg.
// Fully quantized lstm kernel for 16 bit gate matmul output.
//
// Input activation of size n_batch * n_input:
// input_ptr
@ -895,7 +895,7 @@ inline void LstmStepHybrid(
//
// Temporary pre-allocated storage for the calculation. Each is of size n_cell *
// n_batch.
// scratch_0:
// scratch_0
// scratch_1
// scratch_2
// scratch_3
@ -1142,6 +1142,272 @@ inline void LstmStepInteger(
std::copy_n(output_ptr, n_batch * n_output, activation_ptr);
}
// Fully quantized lstm kernel for 8 bit gate matmul output.
//
// Input activation of size n_batch * n_input:
// input_ptr
//
// LSTM weights:
// Quantized input weights of size 'n_cell * n_input':
// input_to_input_weight_ptr - optional
// input_to_forget_weight_ptr
// input_to_cell_weight_ptr
// input_to_output_weight_ptr
//
// Quantized recurrent weights of size 'n_cell * n_output':
// recurrent_to_input_weight_ptr - optional
// recurrent_to_forget_weight_ptr
// recurrent_to_cell_weight_ptr
// recurrent_to_output_weight_ptr
//
// Quantized peephole weights of size 'n_cell', representing diagonal matrices.
// cell_to_input_weights - optional
// cell_to_forget_weights - optional
// cell_to_output_weights - optional
//
// Quantized projection weights of size 'n_output * n_cell'
// proj_weight_ptr - optional
//
// Weight scales (scalars) for each of the weights above.
// effective_input_to_input_scale_a - optional
// effective_input_to_input_scale_b - optional
// effective_input_to_forget_scale_a
// effective_input_to_forget_scale_b
// effective_input_to_cell_scale_a
// effective_input_to_cell_scale_b
// effective_input_to_output_scale_a
// effective_input_to_output_scale_b
// effective_recurrent_to_input_scale_a - optional
// effective_recurrent_to_input_scale_b - optional
// effective_recurrent_to_forget_scale_a
// effective_recurrent_to_forget_scale_b
// effective_recurrent_to_cell_scale_a
// effective_recurrent_to_cell_scale_b
// effective_recurrent_to_output_scale_a
// effective_recurrent_to_output_scale_b
// effective_proj_scale_a - optional
// effective_proj_scale_b - optional
//
// Gate biases of size 'n_cell':
// input_bias_ptr - optional
// forget_bias_ptr
// cell_bias_ptr
// output_bias_ptr
//
// Layer norm coefficients of size 'n_cell', representing diagonal matrices.
// layer_norm_input_weight_ptr - optional
// layer_norm_forget_weight_ptr - optional
// layer_norm_cell_weight_ptr - optional
// layer_norm_output_weight_ptr - optional
//
// Layer norm scales of size 'n_cell'.
// layer_norm_input_scale_a - optional
// layer_norm_input_scale_b - optional
// layer_norm_forget_scale_a - optional
// layer_norm_forget_scale_b - optional
// layer_norm_cell_scale_a - optional
// layer_norm_cell_scale_b - optional
// layer_norm_output_scale_a - optional
// layer_norm_output_scale_b - optional
//
// Scalar values:
// quantized_cell_clip: quantized clip value for cell.
// quantized_proj_clip: quantized clip value for projection.
// cell_scale: the power of two scale for cell state.
//
// Zero points:
// activation_zp: zero point of activation
// hidden_zp: zero point for hidden state.
//
// Temporary pre-allocated storage for the calculation. Each is of size n_cell *
// n_batch.
// scratch_0
// scratch_1
// scratch_2
// scratch_3
// scratch_4
// scratch_5
// scratch_6
// scratch_7
//
// Outputs:
// activation_ptr - size 'n_batch * n_output' (the recurrent output state)
// cell_ptr - size 'n_batch * n_cell'
// output_ptr - size 'n_batch * n_output'
// TODO(b/148688698): Move zero point calculation into Prepare().
void LstmStepInteger(
const int8_t* input_ptr, int32_t input_zp,
const int8_t* input_to_input_weight_ptr,
int32_t effective_input_to_input_scale_a,
int32_t effective_input_to_input_scale_b,
const int8_t* input_to_forget_weight_ptr,
int32_t effective_input_to_forget_scale_a,
int32_t effective_input_to_forget_scale_b,
const int8_t* input_to_cell_weight_ptr,
int32_t effective_input_to_cell_scale_a,
int32_t effective_input_to_cell_scale_b,
const int8_t* input_to_output_weight_ptr,
int32_t effective_input_to_output_scale_a,
int32_t effective_input_to_output_scale_b,
const int8_t* recurrent_to_input_weight_ptr,
int32_t effective_recurrent_to_input_scale_a,
int32_t effective_recurrent_to_input_scale_b,
const int8_t* recurrent_to_forget_weight_ptr,
int32_t effective_recurrent_to_forget_scale_a,
int32_t effective_recurrent_to_forget_scale_b,
const int8_t* recurrent_to_cell_weight_ptr,
int32_t effective_recurrent_to_cell_scale_a,
int32_t effective_recurrent_to_cell_scale_b,
const int8_t* recurrent_to_output_weight_ptr,
int32_t effective_recurrent_to_output_scale_a,
int32_t effective_recurrent_to_output_scale_b,
const int8_t* cell_to_input_weight_ptr,
int32_t effective_cell_to_input_scale_a,
int32_t effective_cell_to_input_scale_b,
const int8_t* cell_to_forget_weight_ptr,
int32_t effective_cell_to_forget_scale_a,
int32_t effective_cell_to_forget_scale_b,
const int8_t* cell_to_output_weight_ptr,
int32_t effective_cell_to_output_scale_a,
int32_t effective_cell_to_output_scale_b, const int8_t* proj_weight_ptr,
int32_t effective_proj_scale_a, int32_t effective_proj_scale_b,
const int16_t* layer_norm_input_weight_ptr,
int32_t layer_norm_input_scale_a, int32_t layer_norm_input_scale_b,
const int16_t* layer_norm_forget_weight_ptr,
int32_t layer_norm_forget_scale_a, int32_t layer_norm_forget_scale_b,
const int16_t* layer_norm_cell_weight_ptr, int32_t layer_norm_cell_scale_a,
int32_t layer_norm_cell_scale_b,
const int16_t* layer_norm_output_weight_ptr,
int32_t layer_norm_output_scale_a, int32_t layer_norm_output_scale_b,
const int32_t* input_bias_ptr, const int32_t* forget_bias_ptr,
const int32_t* cell_bias_ptr, const int32_t* output_bias_ptr,
const int32_t* proj_bias_ptr, const TfLiteLSTMParams* params,
const int32_t* intermediate_scale_a, const int32_t* intermediate_scale_b,
const int32_t* intermediate_zp, int32 quantized_cell_clip,
int32 quantized_proj_clip, int32 n_batch, int32 n_cell, int32 n_input,
int32 n_output, int32 output_batch_leading_dim, int8_t* activation_ptr,
int32_t activation_zp, int16_t* cell_ptr, int8_t* output_ptr,
int8_t* scratch0, int8_t* scratch1, int16_t* scratch2, int16_t* scratch3,
int16_t* scratch4, int16_t* scratch5, int16_t* scratch6,
int16_t* scratch7) {
// Forget gate.
memset(scratch0, 0, n_batch * n_cell);
memset(scratch1, 0, n_batch * n_cell);
tensor_utils::MatrixBatchVectorMultiply(
input_ptr, input_zp, input_to_forget_weight_ptr,
effective_input_to_forget_scale_a, effective_input_to_forget_scale_b,
n_batch, n_input, n_cell, scratch0, intermediate_zp[4]);
tensor_utils::MatrixBatchVectorMultiply(
activation_ptr, activation_zp, recurrent_to_forget_weight_ptr,
effective_recurrent_to_forget_scale_a,
effective_recurrent_to_forget_scale_b, n_batch, n_output, n_cell,
scratch1, intermediate_zp[5]);
tensor_utils::TwoGateSaturationgAdd(
scratch0, intermediate_zp[4], scratch1, intermediate_zp[5],
intermediate_scale_a[2], intermediate_scale_b[2], intermediate_scale_a[3],
intermediate_scale_b[3], n_batch, n_cell, scratch2);
// Forget gate layer norm.
tensor_utils::ApplyLayerNormFloat(
scratch2, layer_norm_forget_weight_ptr, layer_norm_forget_scale_a,
layer_norm_forget_scale_b, forget_bias_ptr, n_batch, n_cell, scratch2);
// Forget gate sigmoid.
tensor_utils::ApplySigmoidFloat(scratch2, n_batch, n_cell, scratch2);
// Update gate.
memset(scratch0, 0, n_batch * n_cell);
memset(scratch1, 0, n_batch * n_cell);
tensor_utils::MatrixBatchVectorMultiply(
input_ptr, input_zp, input_to_cell_weight_ptr,
effective_input_to_cell_scale_a, effective_input_to_cell_scale_b, n_batch,
n_input, n_cell, scratch0, intermediate_zp[7]);
tensor_utils::MatrixBatchVectorMultiply(
activation_ptr, activation_zp, recurrent_to_cell_weight_ptr,
effective_recurrent_to_cell_scale_a, effective_recurrent_to_cell_scale_b,
n_batch, n_output, n_cell, scratch1, intermediate_zp[8]);
tensor_utils::TwoGateSaturationgAdd(
scratch0, intermediate_zp[7], scratch1, intermediate_zp[8],
intermediate_scale_a[4], intermediate_scale_b[4], intermediate_scale_a[5],
intermediate_scale_b[5], n_batch, n_cell, scratch3);
// Update gate with layer norm.
tensor_utils::ApplyLayerNormFloat(
scratch3, layer_norm_cell_weight_ptr, layer_norm_cell_scale_a,
layer_norm_cell_scale_b, cell_bias_ptr, n_batch, n_cell, scratch3);
// Update gate tanh.
tensor_utils::ApplyTanhFloat(scratch3, n_batch, n_cell, -12, scratch3);
// Output gate.
memset(scratch0, 0, n_batch * n_cell);
memset(scratch1, 0, n_batch * n_cell);
tensor_utils::MatrixBatchVectorMultiply(
input_ptr, input_zp, input_to_output_weight_ptr,
effective_input_to_output_scale_a, effective_input_to_output_scale_b,
n_batch, n_input, n_cell, scratch0, intermediate_zp[10]);
tensor_utils::MatrixBatchVectorMultiply(
activation_ptr, activation_zp, recurrent_to_output_weight_ptr,
effective_recurrent_to_output_scale_a,
effective_recurrent_to_output_scale_b, n_batch, n_output, n_cell,
scratch1, intermediate_zp[11]);
tensor_utils::TwoGateSaturationgAdd(
scratch0, intermediate_zp[10], scratch1, intermediate_zp[11],
intermediate_scale_a[6], intermediate_scale_b[6], intermediate_scale_a[7],
intermediate_scale_b[7], n_batch, n_cell, scratch4);
// Output gate with layer norm.
tensor_utils::ApplyLayerNormFloat(
scratch4, layer_norm_output_weight_ptr, layer_norm_output_scale_a,
layer_norm_output_scale_b, output_bias_ptr, n_batch, n_cell, scratch4);
// Output gate sigmoid.
tensor_utils::ApplySigmoidFloat(scratch4, n_batch, n_cell, scratch4);
// Input gate with CIFG coupling: input_gate = 1 - forget_gate.
tensor_utils::Sub1Vector(scratch2, n_batch * n_cell, scratch5);
// New cell.
tensor_utils::CwiseMul(scratch2, cell_ptr, n_batch, n_cell, 15 + 15 - 15,
scratch6);
tensor_utils::CwiseMul(scratch5, scratch3, n_batch, n_cell, 15 + 15 - 15,
scratch7);
tensor_utils::CwiseAdd(scratch6, scratch7, n_batch, n_cell, cell_ptr);
if (quantized_cell_clip > 0) {
tensor_utils::CwiseClipping(cell_ptr, quantized_cell_clip, n_batch, n_cell);
}
// Cell to hidden.
tensor_utils::ApplyTanhFloat(cell_ptr, n_batch, n_cell, -15, scratch2);
tensor_utils::CwiseMul(scratch4, scratch2, n_batch, n_cell, 15 + 15 - 15,
scratch3);
// Projection.
tensor_utils::MatrixBatchVectorMultiply(
scratch3, proj_weight_ptr, effective_proj_scale_a, effective_proj_scale_b,
proj_bias_ptr, n_batch, n_cell, n_output, activation_zp, output_ptr);
// Projection clipping.
if (quantized_proj_clip > 0) {
tensor_utils::CwiseClipping(output_ptr, quantized_proj_clip, n_batch,
n_output);
}
// Copy output to activation.
memcpy(activation_ptr, output_ptr, n_batch * n_output * sizeof(int8_t));
}
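Two details of the step function above are worth spelling out. The gate activations and cell state are kept in Q-format fixed point, so the `15 + 15 - 15` argument to CwiseMul is simply the right shift that maps a product of two Q15 values back to Q15, and with CIFG the input gate is never computed directly but derived from the forget gate. The following is a self-contained sketch of both ideas; it is illustrative arithmetic only, not the tensor_utils implementation, and it ignores TFLite's rounding divide.

#include <algorithm>
#include <cstdint>

// Multiply two Q15 fixed-point values and return a Q15 result. The raw product
// is Q30, so shifting right by 15 + 15 - 15 = 15 restores Q15 (truncating here,
// whereas the kernel uses a rounding divide).
int16_t MultiplyQ15(int16_t a, int16_t b) {
  const int32_t product = static_cast<int32_t>(a) * static_cast<int32_t>(b);
  const int32_t rescaled = product >> 15;
  return static_cast<int16_t>(std::clamp<int32_t>(rescaled, -32768, 32767));
}

// CIFG coupling: input_gate = 1 - forget_gate, with 1.0 saturated to 32767 in
// Q15. Assumes forget_gate is a sigmoid output, i.e. lies in [0, 32767].
int16_t CifgInputGateQ15(int16_t forget_gate) {
  return static_cast<int16_t>(32767 - forget_gate);
}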
} // namespace
// LINT.IfChange
@ -1692,6 +1958,186 @@ TfLiteStatus EvalInteger8x8_16(
return kTfLiteOk;
}
TfLiteStatus EvalInteger8x8_8(
const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
const TfLiteTensor* input_to_forget_weights,
const TfLiteTensor* input_to_cell_weights,
const TfLiteTensor* input_to_output_weights,
const TfLiteTensor* recurrent_to_input_weights,
const TfLiteTensor* recurrent_to_forget_weights,
const TfLiteTensor* recurrent_to_cell_weights,
const TfLiteTensor* recurrent_to_output_weights,
const TfLiteTensor* cell_to_input_weights,
const TfLiteTensor* cell_to_forget_weights,
const TfLiteTensor* cell_to_output_weights,
const TfLiteTensor* input_layer_norm_coefficients,
const TfLiteTensor* forget_layer_norm_coefficients,
const TfLiteTensor* cell_layer_norm_coefficients,
const TfLiteTensor* output_layer_norm_coefficients,
const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
const TfLiteLSTMParams* params, TfLiteTensor* activation_state,
TfLiteTensor* cell_state, TfLiteTensor* output,
const lstm_eval::IntegerLstmParameter* integer_lstm_param,
TfLiteTensor* scratch0, TfLiteTensor* scratch1, TfLiteTensor* scratch2,
TfLiteTensor* scratch3, TfLiteTensor* scratch4, TfLiteTensor* scratch5,
TfLiteTensor* scratch6, TfLiteTensor* scratch7) {
TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
const int n_input = input->dims->data[input->dims->size - 1];
int max_time, n_batch;
if (input->dims->size == 2) {
max_time = 1;
n_batch = input->dims->data[0];
} else {
max_time = input->dims->data[0];
n_batch = input->dims->data[1];
}
// n_cell and n_output will be the same size when there is no projection.
const int n_cell = input_to_output_weights->dims->data[0];
const int n_output = recurrent_to_output_weights->dims->data[1];
// Weights and states.
const int8_t* input_to_input_weight_ptr =
GetTensorData<int8_t>(input_to_input_weights);
const int8_t* recurrent_to_input_weight_ptr =
GetTensorData<int8_t>(recurrent_to_input_weights);
const int8_t* cell_to_input_weight_ptr =
GetTensorData<int8_t>(cell_to_input_weights);
const int8_t* input_to_forget_weight_ptr =
GetTensorData<int8_t>(input_to_forget_weights);
const int8_t* recurrent_to_forget_weight_ptr =
GetTensorData<int8_t>(recurrent_to_forget_weights);
const int8_t* cell_to_forget_weight_ptr =
GetTensorData<int8_t>(cell_to_forget_weights);
const int8_t* input_to_cell_weight_ptr =
GetTensorData<int8_t>(input_to_cell_weights);
const int8_t* recurrent_to_cell_weight_ptr =
GetTensorData<int8_t>(recurrent_to_cell_weights);
const int8_t* input_to_output_weight_ptr =
GetTensorData<int8_t>(input_to_output_weights);
const int8_t* recurrent_to_output_weight_ptr =
GetTensorData<int8_t>(recurrent_to_output_weights);
const int8_t* cell_to_output_weight_ptr =
GetTensorData<int8_t>(cell_to_output_weights);
const int8_t* proj_weight_ptr = GetTensorData<int8_t>(projection_weights);
const int16_t* layer_norm_input_weight_ptr =
GetTensorData<int16_t>(input_layer_norm_coefficients);
const int16_t* layer_norm_forget_weight_ptr =
GetTensorData<int16_t>(forget_layer_norm_coefficients);
const int16_t* layer_norm_cell_weight_ptr =
GetTensorData<int16_t>(cell_layer_norm_coefficients);
const int16_t* layer_norm_output_weight_ptr =
GetTensorData<int16_t>(output_layer_norm_coefficients);
const int32_t* input_bias_ptr = GetTensorData<int32_t>(input_gate_bias);
const int32_t* forget_bias_ptr = GetTensorData<int32_t>(forget_gate_bias);
const int32_t* cell_bias_ptr = GetTensorData<int32_t>(cell_bias);
const int32_t* output_bias_ptr = GetTensorData<int32_t>(output_gate_bias);
const int32_t* proj_bias_ptr = GetTensorData<int32_t>(projection_bias);
int16_t* cell_ptr = GetTensorData<int16_t>(cell_state);
int8_t* activation_ptr = GetTensorData<int8_t>(activation_state);
int8_t* output_ptr = nullptr;
const int32 input_zp = input->params.zero_point;
const int32 activation_zp = activation_state->params.zero_point;
// Get params for time/batch/sequence.
const int output_batch_leading_dim =
output->dims->data[output->dims->size - 1];
const int input_step = n_batch * n_input;
const int output_step = n_batch * output_batch_leading_dim;
for (int t = 0; t < max_time; t++) {
const int t_rel = t;
output_ptr = output->data.int8 + t_rel * output_step;
// Input can be int8 asymmetric or int16 symmetric.
const int8_t* input_ptr = input->data.int8 + t_rel * input_step;
lstm_eval::LstmStepInteger(
input_ptr, input_zp,
input_to_input_weight_ptr,
integer_lstm_param->effective_input_to_input_scale_a,
integer_lstm_param->effective_input_to_input_scale_b,
input_to_forget_weight_ptr,
integer_lstm_param->effective_input_to_forget_scale_a,
integer_lstm_param->effective_input_to_forget_scale_b,
input_to_cell_weight_ptr,
integer_lstm_param->effective_input_to_cell_scale_a,
integer_lstm_param->effective_input_to_cell_scale_b,
input_to_output_weight_ptr,
integer_lstm_param->effective_input_to_output_scale_a,
integer_lstm_param->effective_input_to_output_scale_b,
recurrent_to_input_weight_ptr,
integer_lstm_param->effective_recurrent_to_input_scale_a,
integer_lstm_param->effective_recurrent_to_input_scale_b,
recurrent_to_forget_weight_ptr,
integer_lstm_param->effective_recurrent_to_forget_scale_a,
integer_lstm_param->effective_recurrent_to_forget_scale_b,
recurrent_to_cell_weight_ptr,
integer_lstm_param->effective_recurrent_to_cell_scale_a,
integer_lstm_param->effective_recurrent_to_cell_scale_b,
recurrent_to_output_weight_ptr,
integer_lstm_param->effective_recurrent_to_output_scale_a,
integer_lstm_param->effective_recurrent_to_output_scale_b,
cell_to_input_weight_ptr,
integer_lstm_param->effective_cell_to_input_scale_a,
integer_lstm_param->effective_cell_to_input_scale_b,
cell_to_forget_weight_ptr,
integer_lstm_param->effective_cell_to_forget_scale_a,
integer_lstm_param->effective_cell_to_forget_scale_b,
cell_to_output_weight_ptr,
integer_lstm_param->effective_cell_to_output_scale_a,
integer_lstm_param->effective_cell_to_output_scale_b,
proj_weight_ptr, integer_lstm_param->effective_proj_scale_a,
integer_lstm_param->effective_proj_scale_b,
layer_norm_input_weight_ptr,
integer_lstm_param->layer_norm_input_scale_a,
integer_lstm_param->layer_norm_input_scale_b,
layer_norm_forget_weight_ptr,
integer_lstm_param->layer_norm_forget_scale_a,
integer_lstm_param->layer_norm_forget_scale_b,
layer_norm_cell_weight_ptr, integer_lstm_param->layer_norm_cell_scale_a,
integer_lstm_param->layer_norm_cell_scale_b,
layer_norm_output_weight_ptr,
integer_lstm_param->layer_norm_output_scale_a,
integer_lstm_param->layer_norm_output_scale_b,
input_bias_ptr, forget_bias_ptr, cell_bias_ptr, output_bias_ptr,
proj_bias_ptr,
params, integer_lstm_param->intermediate_scale_a,
integer_lstm_param->intermediate_scale_b,
integer_lstm_param->intermediate_zp,
integer_lstm_param->quantized_cell_clip,
integer_lstm_param->quantized_proj_clip, n_batch, n_cell, n_input,
n_output, output_batch_leading_dim, activation_ptr, activation_zp,
cell_ptr, output_ptr, GetTensorData<int8_t>(scratch0),
GetTensorData<int8_t>(scratch1), GetTensorData<int16_t>(scratch2),
GetTensorData<int16_t>(scratch3), GetTensorData<int16_t>(scratch4),
GetTensorData<int16_t>(scratch5), GetTensorData<int16_t>(scratch6),
GetTensorData<int16_t>(scratch7));
}
return kTfLiteOk;
}
} // namespace lstm_eval
} // namespace builtin
} // namespace ops


@ -1,4 +1,4 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -28,7 +28,8 @@ namespace ops {
namespace builtin {
namespace lstm_eval {
// Pamameters for quantized lstm.
// Parameters for integer LSTM.
// Consider splitting this into two parameter structs if more fields are added.
struct IntegerLstmParameter {
int32_t effective_input_to_input_scale_a;
int32_t effective_input_to_input_scale_b;
@ -75,24 +76,24 @@ struct IntegerLstmParameter {
int32_t cell_variance_guard;
int32_t output_variance_guard;
// The fields are used for pre-computing zero_point * weight.
// We cannot use temporary tensors since temporary tensors are not alllocated
// yet until end of prepare.
// Forget gate.
// Pre-calculate bias + zero_point * weight.
// Unable to use temporary tensors, since these values are computed in
// Prepare() and scratch buffers are only allocated after Prepare().
std::unique_ptr<int32_t[]> input_to_forget_effective_bias;
std::unique_ptr<int32_t[]> recurrent_to_forget_effective_bias;
// Modulation gate.
std::unique_ptr<int32_t[]> input_to_cell_effective_bias;
std::unique_ptr<int32_t[]> recurrent_to_cell_effective_bias;
// Output gate.
std::unique_ptr<int32_t[]> input_to_output_effective_bias;
std::unique_ptr<int32_t[]> recurrent_to_output_effective_bias;
// Input gate.
std::unique_ptr<int32_t[]> input_to_input_effective_bias;
std::unique_ptr<int32_t[]> recurrent_to_input_effective_bias;
// Projection.
std::unique_ptr<int32_t[]> projection_effective_bias;
// Scale and zero point for intermediate tensors.
// Used only in the 8x8_8 case.
int32_t intermediate_scale_a[8];
int32_t intermediate_scale_b[8];
int32_t intermediate_zp[12];
};
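The effective_*_scale_a / _scale_b pairs above encode a real-valued rescale factor (typically something like input_scale * weight_scale / output_scale) as a 32-bit fixed-point multiplier plus a power-of-two shift. Below is a hedged sketch of how such a pair can be derived with std::frexp; it is illustrative only, and the production code uses its own quantization-util helpers, which are not part of this diff.

#include <cmath>
#include <cstdint>

// Decompose a positive real scale into (multiplier, shift) such that
//   scale ~= multiplier * 2^(shift - 31), with multiplier in [2^30, 2^31).
// In this sketch the multiplier plays the role of the *_scale_a fields and the
// shift the role of the *_scale_b fields.
void DecomposeScale(double scale, int32_t* multiplier, int* shift) {
  if (scale == 0.0) {
    *multiplier = 0;
    *shift = 0;
    return;
  }
  const double significand = std::frexp(scale, shift);  // scale = significand * 2^shift
  int64_t q = static_cast<int64_t>(std::round(significand * (1LL << 31)));
  if (q == (1LL << 31)) {  // Rounding can push the significand up to exactly 1.0.
    q /= 2;
    ++(*shift);
  }
  *multiplier = static_cast<int32_t>(q);
}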
TfLiteStatus EvalFloat(
@ -183,6 +184,32 @@ TfLiteStatus EvalInteger8x8_16(
TfLiteTensor* scratch2, TfLiteTensor* scratch3, TfLiteTensor* scratch4,
TfLiteTensor* scratch5, CpuBackendContext* context);
TfLiteStatus EvalInteger8x8_8(
const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
const TfLiteTensor* input_to_forget_weights,
const TfLiteTensor* input_to_cell_weights,
const TfLiteTensor* input_to_output_weights,
const TfLiteTensor* recurrent_to_input_weights,
const TfLiteTensor* recurrent_to_forget_weights,
const TfLiteTensor* recurrent_to_cell_weights,
const TfLiteTensor* recurrent_to_output_weights,
const TfLiteTensor* cell_to_input_weights,
const TfLiteTensor* cell_to_forget_weights,
const TfLiteTensor* cell_to_output_weights,
const TfLiteTensor* input_layer_norm_coefficients,
const TfLiteTensor* forget_layer_norm_coefficients,
const TfLiteTensor* cell_layer_norm_coefficients,
const TfLiteTensor* output_layer_norm_coefficients,
const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
const TfLiteLSTMParams* params, TfLiteTensor* activation_state,
TfLiteTensor* cell_state, TfLiteTensor* output,
const lstm_eval::IntegerLstmParameter* integer_lstm_param,
TfLiteTensor* scratch0, TfLiteTensor* scratch1, TfLiteTensor* scratch2,
TfLiteTensor* scratch3, TfLiteTensor* scratch4, TfLiteTensor* scratch5,
TfLiteTensor* scratch6, TfLiteTensor* scratch7);
} // namespace lstm_eval
} // namespace builtin
} // namespace ops


@ -2756,6 +2756,483 @@ TEST(LSTMIntegerOpModel, NoCifgYesLayerNormNoYesProjectionYesPeephole) {
}
}
class LSTMIntegerOpModel8x8_8 : public SingleOpModel {
public:
LSTMIntegerOpModel8x8_8(
int n_batch, int n_input, int n_cell, int n_output, bool use_cifg,
bool use_peephole, bool use_projection_weights, bool use_projection_bias,
bool use_layer_norm, float cell_clip, float proj_clip,
const std::vector<std::vector<int>>& input_shapes,
const std::vector<std::pair<float, float>>& ranges,
const std::vector<std::pair<float, int>>& intermediates)
: n_batch_(n_batch),
n_input_(n_input),
n_cell_(n_cell),
n_output_(n_output) {
EXPECT_EQ(input_shapes.size() + 1, ranges.size());
EXPECT_EQ(intermediates.size(), 12);
input_ = AddInput(
{TensorType_INT8, input_shapes[0], ranges[0].first, ranges[0].second});
if (use_cifg) {
input_to_input_weights_ = AddNullInput();
} else {
input_to_input_weights_ = AddInput({TensorType_INT8, input_shapes[1],
ranges[1].first, ranges[1].second});
}
input_to_forget_weights_ = AddInput(
{TensorType_INT8, input_shapes[2], ranges[2].first, ranges[2].second});
input_to_cell_weights_ = AddInput(
{TensorType_INT8, input_shapes[3], ranges[3].first, ranges[3].second});
input_to_output_weights_ = AddInput(
{TensorType_INT8, input_shapes[4], ranges[4].first, ranges[4].second});
if (use_cifg) {
recurrent_to_input_weights_ = AddNullInput();
} else {
recurrent_to_input_weights_ =
AddInput({TensorType_INT8, input_shapes[5], ranges[5].first,
ranges[5].second});
}
recurrent_to_forget_weights_ = AddInput(
{TensorType_INT8, input_shapes[6], ranges[6].first, ranges[6].second});
recurrent_to_cell_weights_ = AddInput(
{TensorType_INT8, input_shapes[7], ranges[7].first, ranges[7].second});
recurrent_to_output_weights_ = AddInput(
{TensorType_INT8, input_shapes[8], ranges[8].first, ranges[8].second});
if (use_peephole) {
if (use_cifg) {
cell_to_input_weights_ = AddNullInput();
} else {
cell_to_input_weights_ = AddInput({TensorType_INT16, input_shapes[9],
ranges[9].first, ranges[9].second});
}
cell_to_forget_weights_ = AddInput({TensorType_INT16, input_shapes[10],
ranges[10].first, ranges[10].second});
cell_to_output_weights_ = AddInput({TensorType_INT16, input_shapes[11],
ranges[11].first, ranges[11].second});
} else {
cell_to_input_weights_ = AddNullInput();
cell_to_forget_weights_ = AddNullInput();
cell_to_output_weights_ = AddNullInput();
}
if (use_cifg) {
input_gate_bias_ = AddNullInput();
} else {
input_gate_bias_ = AddInput({TensorType_INT32, input_shapes[12],
ranges[12].first, ranges[12].second});
}
forget_gate_bias_ = AddInput({TensorType_INT32, input_shapes[13],
ranges[13].first, ranges[13].second});
cell_bias_ = AddInput({TensorType_INT32, input_shapes[14], ranges[14].first,
ranges[14].second});
output_gate_bias_ = AddInput({TensorType_INT32, input_shapes[15],
ranges[15].first, ranges[15].second});
if (use_projection_weights) {
projection_weights_ = AddInput({TensorType_INT8, input_shapes[16],
ranges[16].first, ranges[16].second});
if (use_projection_bias) {
projection_bias_ = AddInput({TensorType_INT32, input_shapes[17],
ranges[17].first, ranges[17].second});
} else {
projection_bias_ = AddNullInput();
}
} else {
projection_weights_ = AddNullInput();
projection_bias_ = AddNullInput();
}
// Adding the 2 input state tensors.
input_activation_state_ = AddInput({TensorType_INT16, input_shapes[18],
ranges[18].first, ranges[18].second},
true);
input_cell_state_ = AddInput({TensorType_INT16, input_shapes[19],
ranges[19].first, ranges[19].second},
true);
// Layer norm weights.
if (use_layer_norm) {
if (use_cifg) {
input_layer_norm_coefficients_ = AddNullInput();
} else {
input_layer_norm_coefficients_ =
AddInput({TensorType_INT16, input_shapes[20], ranges[20].first,
ranges[20].second});
}
forget_layer_norm_coefficients_ =
AddInput({TensorType_INT16, input_shapes[21], ranges[21].first,
ranges[21].second});
cell_layer_norm_coefficients_ =
AddInput({TensorType_INT16, input_shapes[22], ranges[22].first,
ranges[22].second});
output_layer_norm_coefficients_ =
AddInput({TensorType_INT16, input_shapes[23], ranges[23].first,
ranges[23].second});
}
for (int i = 0; i < intermediates.size(); ++i) {
intermediates_[i] =
AddIntermediate(TensorType_INT16, {intermediates[i].first},
{intermediates[i].second});
}
output_ = AddOutput({TensorType_INT8,
{n_batch, n_output},
ranges[24].first,
ranges[24].second});
SetBuiltinOp(BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions,
CreateLSTMOptions(builder_, ActivationFunctionType_TANH,
cell_clip, proj_clip)
.Union());
// Do not apply delegate yet since tensor values are not known (and more
// specifically scales in quantized tensors are not known).
BuildInterpreter(input_shapes, /*allow_fp32_relax_to_fp16=*/false,
/*apply_delegate=*/false);
}
void SetInputToInputWeights(const std::vector<float>& f) {
QuantizeAndPopulate<int8_t>(input_to_input_weights_, f);
}
void SetInputToForgetWeights(const std::vector<float>& f) {
QuantizeAndPopulate<int8_t>(input_to_forget_weights_, f);
}
void SetInputToCellWeights(const std::vector<float>& f) {
QuantizeAndPopulate<int8_t>(input_to_cell_weights_, f);
}
void SetInputToOutputWeights(const std::vector<float>& f) {
QuantizeAndPopulate<int8_t>(input_to_output_weights_, f);
}
void SetRecurrentToInputWeights(const std::vector<float>& f) {
QuantizeAndPopulate<int8_t>(recurrent_to_input_weights_, f);
}
void SetRecurrentToForgetWeights(const std::vector<float>& f) {
QuantizeAndPopulate<int8_t>(recurrent_to_forget_weights_, f);
}
void SetRecurrentToCellWeights(const std::vector<float>& f) {
QuantizeAndPopulate<int8_t>(recurrent_to_cell_weights_, f);
}
void SetRecurrentToOutputWeights(const std::vector<float>& f) {
QuantizeAndPopulate<int8_t>(recurrent_to_output_weights_, f);
}
void SetCellToInputWeights(const std::vector<float>& f) {
QuantizeAndPopulate<int16_t>(cell_to_input_weights_, f);
}
void SetCellToForgetWeights(const std::vector<float>& f) {
QuantizeAndPopulate<int16_t>(cell_to_forget_weights_, f);
}
void SetCellToOutputWeights(const std::vector<float>& f) {
QuantizeAndPopulate<int16_t>(cell_to_output_weights_, f);
}
void SetInputLayerNormCoefficients(const std::vector<float>& f) {
QuantizeAndPopulate<int16_t>(input_layer_norm_coefficients_, f);
}
void SetForgetLayerNormCoefficients(const std::vector<float>& f) {
QuantizeAndPopulate<int16_t>(forget_layer_norm_coefficients_, f);
}
void SetCellLayerNormCoefficients(const std::vector<float>& f) {
QuantizeAndPopulate<int16_t>(cell_layer_norm_coefficients_, f);
}
void SetOutputLayerNormCoefficients(const std::vector<float>& f) {
QuantizeAndPopulate<int16_t>(output_layer_norm_coefficients_, f);
}
void SetInputGateBias(const std::vector<float>& f) {
QuantizeAndPopulate<int32_t>(input_gate_bias_, f);
}
void SetForgetGateBias(const std::vector<float>& f) {
QuantizeAndPopulate<int32_t>(forget_gate_bias_, f);
}
void SetCellBias(const std::vector<float>& f) {
QuantizeAndPopulate<int32_t>(cell_bias_, f);
}
void SetOutputGateBias(const std::vector<float>& f) {
QuantizeAndPopulate<int32_t>(output_gate_bias_, f);
}
void SetProjectionWeights(const std::vector<float>& f) {
QuantizeAndPopulate<int8_t>(projection_weights_, f);
}
void SetProjectionBias(const std::vector<float>& f) {
QuantizeAndPopulate<int32_t>(projection_bias_, f);
}
void SetInput(const std::vector<float>& f) {
QuantizeAndPopulate<int8_t>(input_, f);
}
std::vector<int8_t> GetOutput() { return ExtractVector<int8_t>(output_); }
int num_inputs() { return n_input_; }
int num_outputs() { return n_output_; }
int num_cells() { return n_cell_; }
int num_batches() { return n_batch_; }
protected:
int input_;
int input_to_input_weights_;
int input_to_forget_weights_;
int input_to_cell_weights_;
int input_to_output_weights_;
int recurrent_to_input_weights_;
int recurrent_to_forget_weights_;
int recurrent_to_cell_weights_;
int recurrent_to_output_weights_;
int cell_to_input_weights_;
int cell_to_forget_weights_;
int cell_to_output_weights_;
int input_layer_norm_coefficients_;
int forget_layer_norm_coefficients_;
int cell_layer_norm_coefficients_;
int output_layer_norm_coefficients_;
int input_gate_bias_;
int forget_gate_bias_;
int cell_bias_;
int output_gate_bias_;
int projection_weights_;
int projection_bias_;
int input_activation_state_;
int input_cell_state_;
int intermediates_[12];
int output_;
int output_state_;
int cell_state_;
int n_batch_;
int n_input_;
int n_cell_;
int n_output_;
};
TEST(LSTMIntegerOpModel8x8_8, CifgYesLayerNormNoYesProjectionNoPeephole) {
// Hyper parameters.
const int n_batch = 2;
const int n_input = 5;
const int n_cell = 4;
const int n_output = 3;
const float cell_clip = 0.0;
const float proj_clip = 0.0;
// Model related weights.
const std::vector<float> input_to_input_weights = {
0.5, 0.6, 0.7, -0.8, -0.9, 0.1, 0.2, 0.3, -0.4, 0.5,
-0.8, 0.7, -0.6, 0.5, -0.4, -0.5, -0.4, -0.3, -0.2, -0.1};
const std::vector<float> input_to_forget_weights = {
-0.6, -0.1, 0.3, 0.2, 0.9, -0.5, -0.2, -0.4, 0.3, -0.8,
-0.4, 0.3, -0.5, -0.4, -0.6, 0.3, -0.4, -0.6, -0.5, -0.5};
const std::vector<float> input_to_cell_weights = {
-0.4, -0.3, -0.2, -0.1, -0.5, 0.5, -0.2, -0.3, -0.2, -0.6,
0.6, -0.1, -0.4, -0.3, -0.7, 0.7, -0.9, -0.5, 0.8, 0.6};
const std::vector<float> input_to_output_weights = {
-0.8, -0.4, -0.2, -0.9, -0.1, -0.7, 0.3, -0.3, -0.8, -0.2,
0.6, -0.2, 0.4, -0.7, -0.3, -0.5, 0.1, 0.5, -0.6, -0.4};
const std::vector<float> input_gate_bias = {0.03, 0.15, 0.22, 0.38};
const std::vector<float> forget_gate_bias = {0.1, -0.3, -0.2, 0.1};
const std::vector<float> cell_gate_bias = {-0.05, 0.72, 0.25, 0.08};
const std::vector<float> output_gate_bias = {0.05, -0.01, 0.2, 0.1};
const std::vector<float> recurrent_to_input_weights = {
-0.2, -0.3, 0.4, 0.1, -0.5, 0.9, -0.2, -0.3, -0.7, 0.05, -0.2, -0.6};
const std::vector<float> recurrent_to_cell_weights = {
-0.3, 0.2, 0.1, -0.3, 0.8, -0.08, -0.2, 0.3, 0.8, -0.6, -0.1, 0.2};
const std::vector<float> recurrent_to_forget_weights = {
-0.5, -0.3, -0.5, -0.2, 0.6, 0.4, 0.9, 0.3, -0.1, 0.2, 0.5, 0.2};
const std::vector<float> recurrent_to_output_weights = {
0.3, -0.1, 0.1, -0.2, -0.5, -0.7, -0.2, -0.6, -0.1, -0.4, -0.7, -0.2};
const std::vector<float> input_layer_norm_coefficients = {0.1, 0.2, 0.3, 0.5};
const std::vector<float> forget_layer_norm_coefficients = {0.2, 0.2, 0.4,
0.3};
const std::vector<float> cell_layer_norm_coefficients = {0.7, 0.2, 0.3, 0.8};
const std::vector<float> output_layer_norm_coefficients = {0.6, 0.2, 0.2,
0.5};
const std::vector<float> projection_weights = {
-0.1, 0.2, 0.01, -0.2, 0.1, 0.5, 0.3, 0.08, 0.07, 0.2, -0.4, 0.2};
const std::vector<float> projection_bias = {0.1, 0.3, 0.5};
// Input shapes.
const std::vector<std::vector<int32_t>> inputs = {
{n_batch, n_input}, // input tensor
{0}, // input_to_input_weight tensor
{n_cell, n_input}, // input_to_forget_weight tensor
{n_cell, n_input}, // input_to_cell_weight tensor
{n_cell, n_input}, // input_to_output_weight tensor
{0}, // recurrent_to_input_weight tensor
{n_cell, n_output}, // recurrent_to_forget_weight tensor
{n_cell, n_output}, // recurrent_to_cell_weight tensor
{n_cell, n_output}, // recurrent_to_output_weight tensor
{0}, // cell_to_input_weight tensor
{0}, // cell_to_forget_weight tensor
{0}, // cell_to_output_weight tensor
{0}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_bias tensor
{n_cell}, // output_gate_bias tensor
{n_output, n_cell}, // projection_weight tensor
{n_output}, // projection_bias tensor
{n_batch, n_output}, // activation_state tensor
{n_batch, n_cell}, // cell_state tensor
{0}, // input_layer_norm_coefficient tensor
{n_cell}, // forget_layer_norm_coefficient tensor
{n_cell}, // cell_layer_norm_coefficient tensor
{n_cell}, // output_layer_norm_coefficient tensor
};
// Input ranges.
const std::vector<std::pair<float, float>> ranges = {
{-1.0, 127.0 / 128}, // input tensor
{-1.0, 1.0}, // input_to_input_weight tensor
{-1.0, 1.0}, // input_to_forget_weight tensor
{-1.0, 1.0}, // input_to_cell_weight tensor
{-1.0, 1.0}, // input_to_output_weight tensor
{-1.0, 1.0}, // recurrent_to_input_weight tensor
{-1.0, 1.0}, // recurrent_to_forget_weight tensor
{-1.0, 1.0}, // recurrent_to_cell_weight tensor
{-1.0, 1.0}, // recurrent_to_output_weight tensor
{-1, 1}, // cell_to_input_weight tensor
{-1, 1}, // cell_to_forget_weight tensor
{-1, 1}, // cell_to_output_weight tensor
{-100, 100}, // input_gate_bias tensor
{-100, 100}, // forget_gate_bias tensor
{-100, 100}, // cell_bias tensor
{-100, 100}, // output_gate_bias tensor
{-0.5, 0.5}, // projection_weight tensor
{-1, 1}, // projection_bias tensor
{-1.0, 32767.0 / 32768}, // activation_state tensor
{-1.0, 32767.0 / 32768}, // cell_state tensor
{-1.00001, 1.0}, // input_layer_norm_coefficient tensor
{-1.00001, 1.0}, // forget_layer_norm_coefficient tensor
{-1.00001, 1.0}, // cell_layer_norm_coefficient tensor
{-1.00001, 1.0}, // output_layer_norm_coefficient tensor
// Output scale is the same as input activation scale and only activation
// scale is used in the op, so this is only provided for clarity.
{-1.0, 32767.0 / 32768}, // output tensor.
};
// The scale and zero point of intermediate tensors.
std::vector<std::pair<float, int>> intermediates = {
{0.007059, 0}, {0.007812, 0}, {0.007059, 0}, {0.007812, 0},
{0.007, 0}, {0.007059, 0}, {0.007, 0}, {0.007, 0},
{0.007059, 0}, {0.007, 0}, {0.007, 0}, {0.3, 0}};
// Create model.
LSTMIntegerOpModel8x8_8 lstm(n_batch, n_input, n_cell, n_output,
/*use_cifg=*/true, /*use_peephole=*/false,
/*use_projection_weights=*/true,
/*use_projection_bias=*/true,
/*use_layer_norm=*/true, cell_clip, proj_clip,
inputs, ranges, intermediates);
// Set weights.
// lstm.SetInputToInputWeights(input_to_input_weights);
lstm.SetInputToCellWeights(input_to_cell_weights);
lstm.SetInputToForgetWeights(input_to_forget_weights);
lstm.SetInputToOutputWeights(input_to_output_weights);
// lstm.SetInputGateBias(input_gate_bias);
lstm.SetCellBias(cell_gate_bias);
lstm.SetForgetGateBias(forget_gate_bias);
lstm.SetOutputGateBias(output_gate_bias);
// lstm.SetRecurrentToInputWeights(recurrent_to_input_weights);
lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights);
lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights);
lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights);
lstm.SetProjectionWeights(projection_weights);
lstm.SetProjectionBias(projection_bias);
// lstm.SetInputLayerNormCoefficients(input_layer_norm_coefficients);
lstm.SetForgetLayerNormCoefficients(forget_layer_norm_coefficients);
lstm.SetCellLayerNormCoefficients(cell_layer_norm_coefficients);
lstm.SetOutputLayerNormCoefficients(output_layer_norm_coefficients);
// Model inputs, laid out as sequence -> batch -> input.
const std::vector<std::vector<float>> lstm_input = {
{
0.7, 0.8, 0.1, 0.2, 0.3, //
0.8, 0.1, 0.2, 0.4, 0.5, //
},
{
0.2, 0.7, 0.7, 0.1, 0.7, //
0.3, 0.2, 0.9, 0.8, 0.1, //
},
{
0.7, 0.8, 0.1, 0.2, 0.3, //
0.3, 0.2, 0.9, 0.8, 0.1, //
},
};
// Expected outputs.
const std::vector<std::vector<int8_t>> expected_output = {
{127, 127, 127, 127, 127, 127},
{127, 127, 127, 127, 127, 127},
{127, 127, 127, 127, 127, 127},
};
// Invoke and verify the result.
const int input_sequence_size = lstm_input.size();
EXPECT_GT(input_sequence_size, 0);
for (int i = 0; i < input_sequence_size; ++i) {
lstm.SetInput(lstm_input[i]);
lstm.Invoke();
EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(expected_output[i]));
}
}
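As a sanity check on the expected values above: 127 is the top of the int8 range, and with the output tensor declared over [-1.0, 32767/32768] the derived scale is roughly 2/255 with a zero point near zero, so 127 dequantizes to approximately +1.0, i.e. a fully saturated output. A rough back-of-the-envelope version of that arithmetic follows; it uses one common asymmetric-quantization convention, and the exact scale and zero-point rounding are up to the test harness.

#include <cmath>
#include <cstdio>

int main() {
  // Asymmetric int8 quantization over the declared output range.
  const double range_min = -1.0;
  const double range_max = 32767.0 / 32768.0;
  const double scale = (range_max - range_min) / 255.0;              // ~0.00784
  const double zero_point = std::round(-128.0 - range_min / scale);  // ~0
  std::printf("scale=%f zero_point=%f dequant(127)=%f\n", scale, zero_point,
              scale * (127.0 - zero_point));  // ~1.0, i.e. saturated
  return 0;
}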
#ifdef GTEST_HAS_DEATH_TEST
TEST(LSTMOpModel, InvalidTypeTest) {
const int n_batch = 1;