Rename a few remaining cell_ptr and cell_scale to cell_state_ptr and cell_state_scale.
PiperOrigin-RevId: 317682665
Change-Id: I89ddd7893fd1478bb7e4b9ce8873d3c5e084deb1
commit 3a22b091cc
parent 75a3975ab8
@@ -976,7 +976,7 @@ inline void LstmStepHybrid(
 // Scalar values:
 // quantized_cell_clip: quantized clip value for cell.
 // quantized_proj_clip: quantized clip value for projection.
-// cell_scale: the power of two scale for cell state.
+// cell_state_scale: the power of two scale for cell state.
 //
 // Zero points:
 // output_state_zp: zero point of output state
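The renamed comment describes a power-of-two quantization scheme: the int16 cell state value q represents the real value q * 2^cell_state_scale. A minimal dequantization sketch (the helper name and the example exponent -11 are assumptions for illustration, not part of this change):

#include <cmath>
#include <cstdint>
#include <cstdio>

// Hypothetical helper: dequantize a power-of-two-scaled cell state value.
// real_value = quantized_value * 2^cell_state_scale.
float DequantizeCellState(int16_t q, int32_t cell_state_scale) {
  return static_cast<float>(q) * std::ldexp(1.0f, cell_state_scale);
}

int main() {
  // Assumed example: cell_state_scale = -11 gives a resolution of 2^-11
  // and an int16 cell state range of roughly [-16, 16).
  const int32_t cell_state_scale = -11;
  printf("%f\n", DequantizeCellState(4096, cell_state_scale));  // prints 2.0
}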
@@ -1043,9 +1043,10 @@ inline void LstmStepInteger(
     int32_t layer_norm_output_scale_a, int32_t layer_norm_output_scale_b,
     const int32_t* input_gate_bias_ptr, const int32_t* forget_gate_bias_ptr,
     const int32_t* cell_gate_bias_ptr, const int32_t* output_gate_bias_ptr,
-    int16_t quantized_cell_clip, int8_t quantized_proj_clip, int32_t cell_scale,
-    int32_t input_variance_guard, int32_t forget_variance_guard,
-    int32_t cell_variance_guard, int32_t output_variance_guard,
+    int16_t quantized_cell_clip, int8_t quantized_proj_clip,
+    int32_t cell_state_scale, int32_t input_variance_guard,
+    int32_t forget_variance_guard, int32_t cell_variance_guard,
+    int32_t output_variance_guard,
     const int32_t* input_to_forget_effective_bias,
     const int32_t* recurrent_to_forget_effective_bias,
     const int32_t* input_to_cell_effective_bias,
@@ -1056,7 +1057,7 @@ inline void LstmStepInteger(
     const int32_t* recurrent_to_input_effective_bias,
     const int32_t* projection_effective_bias, int n_batch, int n_cell,
     int n_input, int n_output, int8_t* output_state_ptr,
-    int32_t output_state_zp, int16_t* cell_ptr, int8_t* output_ptr,
+    int32_t output_state_zp, int16_t* cell_state_ptr, int8_t* output_ptr,
     int16_t* scratch0, int16_t* scratch1, int16_t* scratch2, int16_t* scratch3,
     int8_t* scratch4, int32_t* scratch5, CpuBackendContext* context) {
   ruy::profiler::ScopeLabel label("LstmStepInteger");
@@ -1106,7 +1107,7 @@ inline void LstmStepInteger(
       scratch5, forget_gate_scratch, context);
   if (use_peephole) {
     tensor_utils::VectorBatchVectorCwiseProductAccumulate(
-        cell_to_forget_weight_ptr, n_output, cell_ptr, n_batch,
+        cell_to_forget_weight_ptr, n_output, cell_state_ptr, n_batch,
         effective_cell_to_forget_scale_a, effective_cell_to_forget_scale_b,
         forget_gate_scratch);
   }
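For the peephole connection, VectorBatchVectorCwiseProductAccumulate multiplies a weight vector elementwise with each batch's cell state and accumulates the rescaled result into the gate scratch. A rough sketch of that shape of computation, assuming the (multiplier, shift) pair follows the usual Q31-multiplier convention; it uses a plain truncating shift where the real kernel does saturating, rounding fixed-point arithmetic:

#include <cstdint>

// Illustrative only: elementwise multiply-accumulate with a (multiplier,
// shift) rescale, in the spirit of the tensor_utils call above.
void CwiseProductAccumulateSketch(const int16_t* weights, int v_size,
                                  const int16_t* batch_vector, int n_batch,
                                  int32_t multiplier, int32_t shift,
                                  int16_t* result) {
  for (int b = 0; b < n_batch; ++b) {
    for (int i = 0; i < v_size; ++i) {
      int64_t prod = int64_t{weights[i]} * batch_vector[b * v_size + i];
      // x * multiplier * 2^shift / 2^31, truncating instead of rounding.
      prod = (prod * multiplier) >> (31 - shift);
      result[b * v_size + i] += static_cast<int16_t>(prod);
    }
  }
}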
@@ -1160,7 +1161,7 @@ inline void LstmStepInteger(
       scratch5, input_gate_scratch, context);
   if (use_peephole) {
     tensor_utils::VectorBatchVectorCwiseProductAccumulate(
-        cell_to_input_weight_ptr, n_output, cell_ptr, n_batch,
+        cell_to_input_weight_ptr, n_output, cell_state_ptr, n_batch,
         effective_cell_to_input_scale_a, effective_cell_to_input_scale_b,
         input_gate_scratch);
   }
@@ -1175,18 +1176,19 @@ inline void LstmStepInteger(
         input_gate_scratch);
   }
 
-  // New cell.
-  tensor_utils::CwiseMul(forget_gate_scratch, cell_ptr, n_batch, n_cell, 15,
-                         forget_gate_scratch);
+  // New cell state.
+  tensor_utils::CwiseMul(forget_gate_scratch, cell_state_ptr, n_batch, n_cell,
+                         15, forget_gate_scratch);
 
   tensor_utils::CwiseMul(input_gate_scratch, cell_gate_scratch, n_batch, n_cell,
-                         30 + cell_scale, cell_gate_scratch);
+                         30 + cell_state_scale, cell_gate_scratch);
 
   tensor_utils::CwiseAdd(forget_gate_scratch, cell_gate_scratch, n_batch,
-                         n_cell, cell_ptr);
+                         n_cell, cell_state_ptr);
 
   if (quantized_cell_clip > 0) {
-    tensor_utils::CwiseClipping(cell_ptr, quantized_cell_clip, n_batch, n_cell);
+    tensor_utils::CwiseClipping(cell_state_ptr, quantized_cell_clip, n_batch,
+                                n_cell);
   }
 
   // Output gate.
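The shift amounts in this hunk encode the fixed-point bookkeeping: the gates come out of sigmoid/tanh in Q0.15, so forget_gate * cell_state needs a 15-bit shift to stay on the cell scale, while input_gate * cell_gate is a Q0.30 product and needs 30 + cell_state_scale bits to land on the 2^cell_state_scale cell scale. A single-element sketch, assuming cell_state_scale = -11 (so that shift is 19) and truncating shifts in place of TFLite's rounding, saturating ops:

#include <algorithm>
#include <cstdint>

// Sketch of one element of the integer cell-state update, assuming gates
// in Q0.15 and a cell state whose real value is q * 2^cell_state_scale.
int16_t UpdateCellStateSketch(int16_t forget_gate, int16_t input_gate,
                              int16_t cell_gate, int16_t old_cell,
                              int32_t cell_state_scale,  // e.g. -11 (assumed)
                              int16_t quantized_cell_clip) {
  // forget * cell: Q0.15 times the cell scale; shifting by 15 restores
  // the cell scale.
  int32_t keep = (int32_t{forget_gate} * old_cell) >> 15;
  // input * cell_gate: Q0.15 * Q0.15 = Q0.30; shifting by
  // 30 + cell_state_scale (19 bits when cell_state_scale = -11) converts
  // the product to the cell scale.
  int32_t write = (int32_t{input_gate} * cell_gate) >> (30 + cell_state_scale);
  int32_t cell = keep + write;
  // The real kernel saturates; here only the optional clip bounds the value.
  if (quantized_cell_clip > 0) {
    cell = std::min<int32_t>(std::max<int32_t>(cell, -quantized_cell_clip),
                             quantized_cell_clip);
  }
  return static_cast<int16_t>(cell);
}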
@@ -1202,7 +1204,7 @@ inline void LstmStepInteger(
       scratch5, output_gate_scratch, context);
   if (use_peephole) {
     tensor_utils::VectorBatchVectorCwiseProductAccumulate(
-        cell_to_output_weight_ptr, n_output, cell_ptr, n_batch,
+        cell_to_output_weight_ptr, n_output, cell_state_ptr, n_batch,
         effective_cell_to_output_scale_a, effective_cell_to_output_scale_b,
         output_gate_scratch);
   }
@@ -1218,8 +1220,8 @@ inline void LstmStepInteger(
         output_gate_scratch);
 
   // Hidden.
-  tensor_utils::ApplyTanh(15 + cell_scale, cell_ptr, n_batch, n_cell,
-                          input_gate_scratch);
+  tensor_utils::ApplyTanh(15 + cell_state_scale, cell_state_ptr, n_batch,
+                          n_cell, input_gate_scratch);
 
   tensor_utils::CwiseMul(output_gate_scratch, input_gate_scratch,
                          effective_hidden_scale_a, effective_hidden_scale_b,
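The first argument to ApplyTanh appears to tell the kernel how many integer bits its fixed-point input carries: an int16 cell state with real value q * 2^cell_state_scale spans roughly ±2^(15 + cell_state_scale), hence 15 + cell_state_scale integer bits. A float reference sketch of that reading (an assumption for illustration; the actual kernel is pure fixed point):

#include <algorithm>
#include <cmath>
#include <cstdint>

// Float reference for tanh on a power-of-two-scaled cell state:
// dequantize, apply tanh, requantize to Q0.15 (the scale the gates use).
int16_t TanhOnCellStateSketch(int16_t cell, int32_t cell_state_scale) {
  const float real = static_cast<float>(cell) *
                     std::ldexp(1.0f, cell_state_scale);   // q * 2^scale
  const long q = std::lrintf(std::tanh(real) * 32768.0f);  // to Q0.15
  return static_cast<int16_t>(std::min(std::max(q, -32768L), 32767L));
}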
@@ -1312,7 +1314,7 @@ inline void LstmStepInteger(
 // Scalar values:
 // quantized_cell_clip: quantized clip value for cell.
 // quantized_proj_clip: quantized clip value for projection.
-// cell_scale: the power of two scale for cell state.
+// cell_state_scale: the power of two scale for cell state.
 //
 // Zero points:
 // output_state_zp: zero point of output state.
@@ -1385,7 +1387,7 @@ void LstmStepInteger(
     const int32_t* intermediate_zp, int16_t quantized_cell_clip,
     int8_t quantized_proj_clip, int n_batch, int n_cell, int n_input,
     int n_output, int output_batch_leading_dim, int8_t* output_state_ptr,
-    int32_t output_state_zp, int16_t* cell_ptr, int8_t* output_ptr,
+    int32_t output_state_zp, int16_t* cell_state_ptr, int8_t* output_ptr,
     int8_t* scratch0, int8_t* scratch1, int16_t* scratch2, int16_t* scratch3,
     int16_t* scratch4, int16_t* scratch5, int16_t* scratch6,
     int16_t* scratch7) {
@@ -1486,20 +1488,21 @@ void LstmStepInteger(
                                input_gate_scratch);
 
   // New cell.
-  tensor_utils::CwiseMul(forget_gate_scratch, cell_ptr, n_batch, n_cell,
+  tensor_utils::CwiseMul(forget_gate_scratch, cell_state_ptr, n_batch, n_cell,
                          15 + 15 - 15, scratch6);
 
   tensor_utils::CwiseMul(input_gate_scratch, cell_gate_scratch, n_batch, n_cell,
                          15 + 15 - 15, scratch7);
 
-  tensor_utils::CwiseAdd(scratch6, scratch7, n_batch, n_cell, cell_ptr);
+  tensor_utils::CwiseAdd(scratch6, scratch7, n_batch, n_cell, cell_state_ptr);
 
   if (quantized_cell_clip > 0) {
-    tensor_utils::CwiseClipping(cell_ptr, quantized_cell_clip, n_batch, n_cell);
+    tensor_utils::CwiseClipping(cell_state_ptr, quantized_cell_clip, n_batch,
+                                n_cell);
   }
 
   // Cell to hidden.
-  tensor_utils::ApplyTanhFloat(cell_ptr, n_batch, n_cell, -15,
+  tensor_utils::ApplyTanhFloat(cell_state_ptr, n_batch, n_cell, -15,
                                forget_gate_scratch);
 
   std::vector<int16_t> hidden(n_batch * n_cell);
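In this fully 16-bit variant the shift is spelled 15 + 15 - 15 to document the arithmetic: two Q0.15 operands produce a Q0.30 product, and dropping 15 fractional bits returns Q0.15. A worked example (illustrative, with a truncating shift):

#include <cstdint>
#include <cstdio>

int main() {
  const int16_t a = 1 << 14;  // 0.5 in Q0.15
  const int16_t b = 1 << 14;  // 0.5 in Q0.15
  const int32_t product_q30 = int32_t{a} * b;                // 0.25 in Q0.30
  const int16_t result_q15 = product_q30 >> (15 + 15 - 15);  // back to Q0.15
  printf("%d\n", result_q15);  // prints 8192 == 0.25 * 2^15
}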