Rename a few remaining cell_ptr and cell_scale to cell_state_ptr and cell_state_scale.

PiperOrigin-RevId: 317682665
Change-Id: I89ddd7893fd1478bb7e4b9ce8873d3c5e084deb1
Author: Robert David (2020-06-22 10:29:31 -07:00), committed by TensorFlower Gardener
parent 75a3975ab8
commit 3a22b091cc


@@ -976,7 +976,7 @@ inline void LstmStepHybrid(
 // Scalar values:
 // quantized_cell_clip: quantized clip value for cell.
 // quantized_proj_clip: quantized clip value for projection.
-// cell_scale: the power of two scale for cell state.
+// cell_state_scale: the power of two scale for cell state.
 //
 // Zero points:
 // output_state_zp: zero point of output state
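Note on the renamed comment above: a "power of two scale" means the int16 cell state maps back to a real value as real = q * 2^cell_state_scale. A minimal sketch of that convention (the helper name and the example scale of -11 are illustrative, not from this commit):

    #include <cmath>
    #include <cstdint>

    // Hypothetical helper: dequantize one cell-state element under the
    // real = q * 2^cell_state_scale convention described in the comment.
    float DequantizeCellState(int16_t q, int32_t cell_state_scale) {
      // cell_state_scale is typically negative, so this multiplies by a
      // fractional power of two (e.g. 2^-11).
      return std::ldexp(static_cast<float>(q), cell_state_scale);
    }
    // Example: DequantizeCellState(2048, -11) == 1.0f.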
@@ -1043,9 +1043,10 @@ inline void LstmStepInteger(
 int32_t layer_norm_output_scale_a, int32_t layer_norm_output_scale_b,
 const int32_t* input_gate_bias_ptr, const int32_t* forget_gate_bias_ptr,
 const int32_t* cell_gate_bias_ptr, const int32_t* output_gate_bias_ptr,
-int16_t quantized_cell_clip, int8_t quantized_proj_clip, int32_t cell_scale,
-int32_t input_variance_guard, int32_t forget_variance_guard,
-int32_t cell_variance_guard, int32_t output_variance_guard,
+int16_t quantized_cell_clip, int8_t quantized_proj_clip,
+int32_t cell_state_scale, int32_t input_variance_guard,
+int32_t forget_variance_guard, int32_t cell_variance_guard,
+int32_t output_variance_guard,
 const int32_t* input_to_forget_effective_bias,
 const int32_t* recurrent_to_forget_effective_bias,
 const int32_t* input_to_cell_effective_bias,
@@ -1056,7 +1057,7 @@ inline void LstmStepInteger(
 const int32_t* recurrent_to_input_effective_bias,
 const int32_t* projection_effective_bias, int n_batch, int n_cell,
 int n_input, int n_output, int8_t* output_state_ptr,
-int32_t output_state_zp, int16_t* cell_ptr, int8_t* output_ptr,
+int32_t output_state_zp, int16_t* cell_state_ptr, int8_t* output_ptr,
 int16_t* scratch0, int16_t* scratch1, int16_t* scratch2, int16_t* scratch3,
 int8_t* scratch4, int32_t* scratch5, CpuBackendContext* context) {
 ruy::profiler::ScopeLabel label("LstmStepInteger");
@@ -1106,7 +1107,7 @@ inline void LstmStepInteger(
 scratch5, forget_gate_scratch, context);
 if (use_peephole) {
 tensor_utils::VectorBatchVectorCwiseProductAccumulate(
-cell_to_forget_weight_ptr, n_output, cell_ptr, n_batch,
+cell_to_forget_weight_ptr, n_output, cell_state_ptr, n_batch,
 effective_cell_to_forget_scale_a, effective_cell_to_forget_scale_b,
 forget_gate_scratch);
 }
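The peephole call above accumulates an elementwise product of the cell-to-forget weights and the current cell state into the forget-gate accumulator, rescaled by the effective_cell_to_forget_scale_a/_b multiplier-shift pair. A float reference sketch of the accumulation (names are illustrative, and the integer rescaling is omitted):

    // Float reference for the peephole product-accumulate, per batch entry:
    // gate[b][i] += w[i] * c[b][i]. The integer kernel additionally rescales
    // each product with a quantized multiplier (scale_a) and shift (scale_b).
    void PeepholeAccumulate(const float* w, const float* c, int n_batch,
                            int n_cell, float* gate) {
      for (int b = 0; b < n_batch; ++b) {
        for (int i = 0; i < n_cell; ++i) {
          gate[b * n_cell + i] += w[i] * c[b * n_cell + i];
        }
      }
    }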
@@ -1160,7 +1161,7 @@ inline void LstmStepInteger(
 scratch5, input_gate_scratch, context);
 if (use_peephole) {
 tensor_utils::VectorBatchVectorCwiseProductAccumulate(
-cell_to_input_weight_ptr, n_output, cell_ptr, n_batch,
+cell_to_input_weight_ptr, n_output, cell_state_ptr, n_batch,
 effective_cell_to_input_scale_a, effective_cell_to_input_scale_b,
 input_gate_scratch);
 }
@@ -1175,18 +1176,19 @@ inline void LstmStepInteger(
 input_gate_scratch);
 }
-// New cell.
-tensor_utils::CwiseMul(forget_gate_scratch, cell_ptr, n_batch, n_cell, 15,
-forget_gate_scratch);
+// New cell state.
+tensor_utils::CwiseMul(forget_gate_scratch, cell_state_ptr, n_batch, n_cell,
+15, forget_gate_scratch);
 tensor_utils::CwiseMul(input_gate_scratch, cell_gate_scratch, n_batch, n_cell,
-30 + cell_scale, cell_gate_scratch);
+30 + cell_state_scale, cell_gate_scratch);
 tensor_utils::CwiseAdd(forget_gate_scratch, cell_gate_scratch, n_batch,
-n_cell, cell_ptr);
+n_cell, cell_state_ptr);
 if (quantized_cell_clip > 0) {
-tensor_utils::CwiseClipping(cell_ptr, quantized_cell_clip, n_batch, n_cell);
+tensor_utils::CwiseClipping(cell_state_ptr, quantized_cell_clip, n_batch,
+n_cell);
 }
 // Output gate.
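The hunk above is the integer cell-state update c_new = f * c_old + i * g. The shift amounts encode the fixed-point bookkeeping: the gates are int16 in Q0.15 (scale 2^-15), so f * c_old needs a right shift of 15 to land back on the cell-state scale, while i * g is a Q0.30 product and needs a shift of 30 + cell_state_scale. A per-element sketch (truncating shifts for brevity, whereas the kernel's CwiseMul rounds):

    #include <cstdint>

    // One element of the update, assuming Q0.15 gates and a power-of-two
    // cell-state scale (e.g. cell_state_scale == -11, so 30 + scale == 19).
    int16_t UpdateCellElement(int16_t f, int16_t c_old, int16_t i, int16_t g,
                              int32_t cell_state_scale,
                              int16_t quantized_cell_clip) {
      // Q0.15 * scale 2^s -> shift by 15 returns to scale 2^s.
      int32_t fc = (static_cast<int32_t>(f) * c_old) >> 15;
      // Q0.15 * Q0.15 = Q0.30 -> shift by 30 + s converts to scale 2^s.
      int32_t ig = static_cast<int32_t>(
          (static_cast<int64_t>(i) * g) >> (30 + cell_state_scale));
      int32_t c_new = fc + ig;
      if (quantized_cell_clip > 0) {  // symmetric clip, as in CwiseClipping
        if (c_new > quantized_cell_clip) c_new = quantized_cell_clip;
        if (c_new < -quantized_cell_clip) c_new = -quantized_cell_clip;
      }
      return static_cast<int16_t>(c_new);
    }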
@@ -1202,7 +1204,7 @@ inline void LstmStepInteger(
 scratch5, output_gate_scratch, context);
 if (use_peephole) {
 tensor_utils::VectorBatchVectorCwiseProductAccumulate(
-cell_to_output_weight_ptr, n_output, cell_ptr, n_batch,
+cell_to_output_weight_ptr, n_output, cell_state_ptr, n_batch,
 effective_cell_to_output_scale_a, effective_cell_to_output_scale_b,
 output_gate_scratch);
 }
@@ -1218,8 +1220,8 @@ inline void LstmStepInteger(
 output_gate_scratch);
 // Hidden.
-tensor_utils::ApplyTanh(15 + cell_scale, cell_ptr, n_batch, n_cell,
-input_gate_scratch);
+tensor_utils::ApplyTanh(15 + cell_state_scale, cell_state_ptr, n_batch,
+n_cell, input_gate_scratch);
 tensor_utils::CwiseMul(output_gate_scratch, input_gate_scratch,
 effective_hidden_scale_a, effective_hidden_scale_b,
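The `15 + cell_state_scale` argument above is the integer-bit count under which the int16 cell state equals its real value: an int16 with IB integer bits has scale 2^(IB - 15), and real = q * 2^cell_state_scale gives IB = 15 + cell_state_scale. The tanh result lands in Q0.15. A float reference for one element (illustrative only, not the library's fixed-point tanh implementation):

    #include <cmath>
    #include <cstdint>

    // Reference: tanh of one cell-state element, re-quantized to Q0.15.
    int16_t TanhOnCellState(int16_t q, int32_t cell_state_scale) {
      const float real = std::ldexp(static_cast<float>(q), cell_state_scale);
      int32_t out =
          static_cast<int32_t>(std::lround(std::tanh(real) * 32768.0f));
      if (out > 32767) out = 32767;  // saturate to the int16 range
      if (out < -32768) out = -32768;
      return static_cast<int16_t>(out);
    }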
@@ -1312,7 +1314,7 @@ inline void LstmStepInteger(
 // Scalar values:
 // quantized_cell_clip: quantized clip value for cell.
 // quantized_proj_clip: quantized clip value for projection.
-// cell_scale: the power of two scale for cell state.
+// cell_state_scale: the power of two scale for cell state.
 //
 // Zero points:
 // output_state_zp: zero point of output state.
@@ -1385,7 +1387,7 @@ void LstmStepInteger(
 const int32_t* intermediate_zp, int16_t quantized_cell_clip,
 int8_t quantized_proj_clip, int n_batch, int n_cell, int n_input,
 int n_output, int output_batch_leading_dim, int8_t* output_state_ptr,
-int32_t output_state_zp, int16_t* cell_ptr, int8_t* output_ptr,
+int32_t output_state_zp, int16_t* cell_state_ptr, int8_t* output_ptr,
 int8_t* scratch0, int8_t* scratch1, int16_t* scratch2, int16_t* scratch3,
 int16_t* scratch4, int16_t* scratch5, int16_t* scratch6,
 int16_t* scratch7) {
@@ -1486,20 +1488,21 @@ void LstmStepInteger(
 input_gate_scratch);
 // New cell.
-tensor_utils::CwiseMul(forget_gate_scratch, cell_ptr, n_batch, n_cell,
+tensor_utils::CwiseMul(forget_gate_scratch, cell_state_ptr, n_batch, n_cell,
 15 + 15 - 15, scratch6);
 tensor_utils::CwiseMul(input_gate_scratch, cell_gate_scratch, n_batch, n_cell,
 15 + 15 - 15, scratch7);
-tensor_utils::CwiseAdd(scratch6, scratch7, n_batch, n_cell, cell_ptr);
+tensor_utils::CwiseAdd(scratch6, scratch7, n_batch, n_cell, cell_state_ptr);
 if (quantized_cell_clip > 0) {
-tensor_utils::CwiseClipping(cell_ptr, quantized_cell_clip, n_batch, n_cell);
+tensor_utils::CwiseClipping(cell_state_ptr, quantized_cell_clip, n_batch,
+n_cell);
 }
 // Cell to hidden.
-tensor_utils::ApplyTanhFloat(cell_ptr, n_batch, n_cell, -15,
+tensor_utils::ApplyTanhFloat(cell_state_ptr, n_batch, n_cell, -15,
 forget_gate_scratch);
 std::vector<int16_t> hidden(n_batch * n_cell);
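In this second overload the cell state appears to stay on a fixed Q0.15 scale: `15 + 15 - 15` is the Q0.15 x Q0.15 = Q0.30 product brought back to Q0.15, and the -15 passed to ApplyTanhFloat is that same scale exponent (real = q * 2^-15). A sketch of the arithmetic under that assumed convention (helper names are illustrative):

    #include <cmath>
    #include <cstdint>

    // Q0.15 multiply matching the 15 + 15 - 15 shift above (truncating;
    // the kernel's CwiseMul rounds).
    int16_t MulQ15(int16_t a, int16_t b) {
      return static_cast<int16_t>((static_cast<int32_t>(a) * b) >> 15);
    }

    // Float-path tanh with a 2^-15 input scale, re-quantized to Q0.15.
    // |real| <= 1 here, so tanh(real) * 32768 stays well inside int16.
    int16_t TanhQ15(int16_t q) {
      const float real = std::ldexp(static_cast<float>(q), -15);
      return static_cast<int16_t>(std::lround(std::tanh(real) * 32768.0f));
    }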