Remove normalization_epsilon parameter of MeanStddevNormalization; change it to a compile-time constant.
PiperOrigin-RevId: 279430358
Change-Id: I25aee6a065617e94b5b8d0a20f9cb2d3dce62314
parent 43d77b42e7
commit 7e48fd7fcf
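The change is mechanical but touches every implementation and call site: the runtime normalization_epsilon argument is dropped from MeanStddevNormalization (and its portable counterpart), and the kernel hard-codes the value instead. A minimal before/after sketch of the public signature, assembled from the hunks below:

// Before: every caller had to thread an epsilon through the call.
void MeanStddevNormalization(const float* input_vector, float* output_vector,
                             int v_size, int n_batch,
                             float normalization_epsilon);

// After: the epsilon lives inside the kernel as a compile-time constant.
void MeanStddevNormalization(const float* input_vector, float* output_vector,
                             int v_size, int n_batch);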
@@ -254,10 +254,8 @@ void ReductionSumVector(const float* input_vector, float* output_vector,
 }
 
 void MeanStddevNormalization(const float* input_vector, float* output_vector,
-                             int v_size, int n_batch,
-                             float normalization_epsilon) {
-  PortableMeanStddevNormalization(input_vector, output_vector, v_size, n_batch,
-                                  normalization_epsilon);
+                             int v_size, int n_batch) {
+  PortableMeanStddevNormalization(input_vector, output_vector, v_size, n_batch);
 }
 
 }  // namespace tensor_utils

@@ -264,10 +264,8 @@ void ReductionSumVector(const float* input_vector, float* output_vector,
 }
 
 void MeanStddevNormalization(const float* input_vector, float* output_vector,
-                             int v_size, int n_batch,
-                             float normalization_epsilon) {
-  PortableMeanStddevNormalization(input_vector, output_vector, v_size, n_batch,
-                                  normalization_epsilon);
+                             int v_size, int n_batch) {
+  PortableMeanStddevNormalization(input_vector, output_vector, v_size, n_batch);
 }
 
 }  // namespace tensor_utils

@@ -627,7 +627,7 @@ void PortableReductionSumVector(const float* input_vector, float* output_vector,
 
 void PortableMeanStddevNormalization(const float* input_vector,
                                      float* output_vector, int v_size,
-                                     int n_batch, float normalization_epsilon) {
+                                     int n_batch) {
   for (int batch = 0; batch < n_batch; ++batch) {
     float sum = 0.0f;
     float sum_sq = 0.0f;

@@ -639,7 +639,8 @@ void PortableMeanStddevNormalization(const float* input_vector,
     float stddev_inv = 0.0f;
     const float variance = sum_sq / v_size - mean * mean;
     if (variance == 0) {
-      stddev_inv = 1.0f / std::sqrt(normalization_epsilon);
+      constexpr float kNormalizationConstant = 1e-8;
+      stddev_inv = 1.0f / std::sqrt(kNormalizationConstant);
     } else {
       stddev_inv = 1.0f / std::sqrt(variance);
     }

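Pieced together from the two hunks above, the portable kernel after this change reads roughly as follows. The diff elides the middle of the loop, so the sum/sum_sq accumulation and the final normalization pass are reconstructed here as the standard E[x^2] - E[x]^2 computation the visible lines imply; treat it as a sketch, not the verbatim source.

#include <cmath>

void PortableMeanStddevNormalization(const float* input_vector,
                                     float* output_vector, int v_size,
                                     int n_batch) {
  for (int batch = 0; batch < n_batch; ++batch) {
    // Accumulate sum and sum of squares over one batch (reconstructed).
    float sum = 0.0f;
    float sum_sq = 0.0f;
    for (int i = 0; i < v_size; ++i) {
      sum += input_vector[i];
      sum_sq += input_vector[i] * input_vector[i];
    }
    const float mean = sum / v_size;
    float stddev_inv = 0.0f;
    const float variance = sum_sq / v_size - mean * mean;
    if (variance == 0) {
      // The former normalization_epsilon argument, now a compile-time constant.
      constexpr float kNormalizationConstant = 1e-8;
      stddev_inv = 1.0f / std::sqrt(kNormalizationConstant);
    } else {
      stddev_inv = 1.0f / std::sqrt(variance);
    }
    // Normalize each element of the batch (reconstructed final pass).
    for (int i = 0; i < v_size; ++i) {
      output_vector[i] = (input_vector[i] - mean) * stddev_inv;
    }
    input_vector += v_size;
    output_vector += v_size;
  }
}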
@@ -258,10 +258,8 @@ void ReductionSumVector(const float* input_vector, float* output_vector,
 }
 
 void MeanStddevNormalization(const float* input_vector, float* output_vector,
-                             int v_size, int n_batch,
-                             float normalization_epsilon) {
-  PortableMeanStddevNormalization(input_vector, output_vector, v_size, n_batch,
-                                  normalization_epsilon);
+                             int v_size, int n_batch) {
+  PortableMeanStddevNormalization(input_vector, output_vector, v_size, n_batch);
 }
 
 }  // namespace tensor_utils

@@ -200,10 +200,9 @@ void PortableReductionSumVector(const float* input_vector, float* output_vector,
                                 int output_size, int reduction_size);
 
 // Layer norm for each batch.
-// normalization_epsilon is added to avoid divergence.
 void PortableMeanStddevNormalization(const float* input_vector,
                                      float* output_vector, int v_size,
-                                     int n_batch, float normalization_epsilon);
+                                     int n_batch);
 
 }  // namespace tensor_utils
 }  // namespace tflite

@@ -439,10 +439,8 @@ void ReductionSumVector(const float* input_vector, float* output_vector,
                         int output_size, int reduction_size);
 
 // Layer norm for each batch.
-// normalization_epsilon is added to avoid divergence.
 void MeanStddevNormalization(const float* input_vector, float* output_vector,
-                             int v_size, int n_batch,
-                             float normalization_epsilon);
+                             int v_size, int n_batch);
 
 }  // namespace tensor_utils
 }  // namespace tflite
 

@@ -1466,7 +1466,6 @@ TEST(uKernels, ReductionSumVectorTest) {
 TEST(uKernels, MeanStddevNormalization) {
   constexpr int kVectorSize = 4;
   constexpr int kBatchSize = 8;  // 9, but large mean, small variance fails
-  constexpr float kNormalizationEpsilon = 1e-8;
 
   // None-zero input.
   static float input[kVectorSize * kBatchSize] = {

@@ -1480,8 +1479,7 @@ TEST(uKernels, MeanStddevNormalization) {
       -100.0f, 0.0f, 200.0f, 300.0f,  // large mean, large variance
   };
   float output[kVectorSize * kBatchSize];
-  MeanStddevNormalization(input, output, kVectorSize, kBatchSize,
-                          kNormalizationEpsilon);
+  MeanStddevNormalization(input, output, kVectorSize, kBatchSize);
   const float ksqrt16 = std::sqrt(1.6f);
   const float ksqrt04 = std::sqrt(0.4f);
   const std::vector<float> expected_output = {

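The test constants follow directly from the arithmetic. For the large mean, large variance batch {-100, 0, 200, 300}: mean = 100 and variance = (100^2 + 0^2 + 200^2 + 300^2)/4 - 100^2 = 25000, so the normalized outputs are -200/sqrt(25000) = -sqrt(1.6), -100/sqrt(25000) = -sqrt(0.4), +sqrt(0.4), and +sqrt(1.6); hence the ksqrt16 and ksqrt04 values used in expected_output.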
@@ -38,11 +38,6 @@ namespace builtin {
 namespace lstm_eval {
 
 namespace {
 
-// Small float to avoid divergence during calculation of deviation for layer
-// norm lstm.
-const float kLayerNormEpsilon = 1e-8;
-
 // Performs an LSTM batch inference step for input specified by input_ptr_batch.
 // The LSTM cell is specified by the pointers to its weights (*_weights_ptr) and
 // biases (*_bias_ptr), and buffers (*_scratch), along with additional

@@ -224,9 +219,8 @@ inline void LstmStepWithAuxInput(
                                        input_gate_scratch);
   }
   if (is_layer_norm_lstm) {
-    tensor_utils::MeanStddevNormalization(input_gate_scratch,
-                                          input_gate_scratch, n_cell, n_batch,
-                                          kLayerNormEpsilon);
+    tensor_utils::MeanStddevNormalization(
+        input_gate_scratch, input_gate_scratch, n_cell, n_batch);
     tensor_utils::VectorBatchVectorCwiseProduct(
         input_layer_norm_coefficients_ptr, n_cell, input_gate_scratch,
         n_batch, input_gate_scratch);

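Every layer-norm call site in the LSTM step uses the kernel in place: the same scratch buffer is passed as both input and output, then scaled by that gate's learned coefficients. The same pattern repeats in the hunks below for the forget, cell, and output gates; sketched here with the names from the hunk above:

// Normalize the gate activations in place, then apply the per-gate
// layer-norm scale. Repeated for each of the four gates.
tensor_utils::MeanStddevNormalization(
    input_gate_scratch, input_gate_scratch, n_cell, n_batch);
tensor_utils::VectorBatchVectorCwiseProduct(
    input_layer_norm_coefficients_ptr, n_cell, input_gate_scratch,
    n_batch, input_gate_scratch);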
@@ -245,8 +239,7 @@ inline void LstmStepWithAuxInput(
   }
   if (is_layer_norm_lstm) {
     tensor_utils::MeanStddevNormalization(forget_gate_scratch,
-                                          forget_gate_scratch, n_cell, n_batch,
-                                          kLayerNormEpsilon);
+                                          forget_gate_scratch, n_cell, n_batch);
     tensor_utils::VectorBatchVectorCwiseProduct(
         forget_layer_norm_coefficients_ptr, n_cell, forget_gate_scratch,
         n_batch, forget_gate_scratch);

@@ -261,7 +254,7 @@ inline void LstmStepWithAuxInput(
                                  n_batch * n_cell, cell_state_ptr);
   if (is_layer_norm_lstm) {
     tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell,
-                                          n_batch, kLayerNormEpsilon);
+                                          n_batch);
     tensor_utils::VectorBatchVectorCwiseProduct(
         cell_layer_norm_coefficients_ptr, n_cell, cell_scratch, n_batch,
         cell_scratch);

@@ -292,8 +285,7 @@ inline void LstmStepWithAuxInput(
   }
   if (is_layer_norm_lstm) {
     tensor_utils::MeanStddevNormalization(output_gate_scratch,
-                                          output_gate_scratch, n_cell, n_batch,
-                                          kLayerNormEpsilon);
+                                          output_gate_scratch, n_cell, n_batch);
     tensor_utils::VectorBatchVectorCwiseProduct(
         output_layer_norm_coefficients_ptr, n_cell, output_gate_scratch,
         n_batch, output_gate_scratch);

@@ -699,9 +691,8 @@ inline void LstmStepWithAuxInput(
                                        input_gate_scratch);
   }
   if (is_layer_norm_lstm) {
-    tensor_utils::MeanStddevNormalization(input_gate_scratch,
-                                          input_gate_scratch, n_cell, n_batch,
-                                          kLayerNormEpsilon);
+    tensor_utils::MeanStddevNormalization(
+        input_gate_scratch, input_gate_scratch, n_cell, n_batch);
     tensor_utils::VectorBatchVectorCwiseProduct(
         input_layer_norm_coefficients_ptr, n_cell, input_gate_scratch,
         n_batch, input_gate_scratch);

@@ -723,8 +714,7 @@ inline void LstmStepWithAuxInput(
   }
   if (is_layer_norm_lstm) {
     tensor_utils::MeanStddevNormalization(forget_gate_scratch,
-                                          forget_gate_scratch, n_cell, n_batch,
-                                          kLayerNormEpsilon);
+                                          forget_gate_scratch, n_cell, n_batch);
     tensor_utils::VectorBatchVectorCwiseProduct(
         forget_layer_norm_coefficients_ptr, n_cell, forget_gate_scratch,
         n_batch, forget_gate_scratch);

@@ -739,7 +729,7 @@ inline void LstmStepWithAuxInput(
                                  n_batch * n_cell, cell_state_ptr);
   if (is_layer_norm_lstm) {
     tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell,
-                                          n_batch, kLayerNormEpsilon);
+                                          n_batch);
     tensor_utils::VectorBatchVectorCwiseProduct(
         cell_layer_norm_coefficients_ptr, n_cell, cell_scratch, n_batch,
         cell_scratch);

@@ -775,8 +765,7 @@ inline void LstmStepWithAuxInput(
   }
   if (is_layer_norm_lstm) {
     tensor_utils::MeanStddevNormalization(output_gate_scratch,
-                                          output_gate_scratch, n_cell, n_batch,
-                                          kLayerNormEpsilon);
+                                          output_gate_scratch, n_cell, n_batch);
     tensor_utils::VectorBatchVectorCwiseProduct(
         output_layer_norm_coefficients_ptr, n_cell, output_gate_scratch,
         n_batch, output_gate_scratch);

@@ -32,8 +32,6 @@ namespace builtin {
 
 namespace {
 
-const float kLayerNormEpsilon = 1e-8;
-
 inline void LstmStepWithAuxInput(
     const float* input_ptr_batch, const float* input_to_input_weights_ptr,
     const float* input_to_forget_weights_ptr,

@@ -157,9 +155,8 @@ inline void LstmStepWithAuxInput(
   if (is_layer_norm_lstm) {
     logger->LogTensorValue(intemediate_tensor_indexes[0], input_gate_scratch,
                            n_cell * n_batch);
-    tensor_utils::MeanStddevNormalization(input_gate_scratch,
-                                          input_gate_scratch, n_cell, n_batch,
-                                          kLayerNormEpsilon);
+    tensor_utils::MeanStddevNormalization(
+        input_gate_scratch, input_gate_scratch, n_cell, n_batch);
     tensor_utils::VectorBatchVectorCwiseProduct(
         input_layer_norm_coefficients_ptr, n_cell, input_gate_scratch,
         n_batch, input_gate_scratch);

@@ -180,8 +177,7 @@ inline void LstmStepWithAuxInput(
     logger->LogTensorValue(intemediate_tensor_indexes[1], forget_gate_scratch,
                            n_cell * n_batch);
     tensor_utils::MeanStddevNormalization(forget_gate_scratch,
-                                          forget_gate_scratch, n_cell, n_batch,
-                                          kLayerNormEpsilon);
+                                          forget_gate_scratch, n_cell, n_batch);
     tensor_utils::VectorBatchVectorCwiseProduct(
         forget_layer_norm_coefficients_ptr, n_cell, forget_gate_scratch,
         n_batch, forget_gate_scratch);

@@ -198,7 +194,7 @@ inline void LstmStepWithAuxInput(
     logger->LogTensorValue(intemediate_tensor_indexes[2], cell_scratch,
                            n_cell * n_batch);
     tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell,
-                                          n_batch, kLayerNormEpsilon);
+                                          n_batch);
     tensor_utils::VectorBatchVectorCwiseProduct(
         cell_layer_norm_coefficients_ptr, n_cell, cell_scratch, n_batch,
         cell_scratch);

@@ -231,8 +227,7 @@ inline void LstmStepWithAuxInput(
     logger->LogTensorValue(intemediate_tensor_indexes[3], output_gate_scratch,
                            n_cell * n_batch);
     tensor_utils::MeanStddevNormalization(output_gate_scratch,
-                                          output_gate_scratch, n_cell, n_batch,
-                                          kLayerNormEpsilon);
+                                          output_gate_scratch, n_cell, n_batch);
     tensor_utils::VectorBatchVectorCwiseProduct(
         output_layer_norm_coefficients_ptr, n_cell, output_gate_scratch,
         n_batch, output_gate_scratch);