Refactoring before supporting quantized weights in LSTM tests.
PiperOrigin-RevId: 238459158
This commit is contained in:
parent
22a0638433
commit
6859fefa3b
@ -51,7 +51,7 @@ class SingleOpModelWithNNAPI : public SingleOpModel {
|
||||
}
|
||||
|
||||
protected:
|
||||
void SetData(int index, TensorType type, std::initializer_list<float> data) {
|
||||
void SetData(int index, TensorType type, const std::vector<float>& data) {
|
||||
switch (type) {
|
||||
case TensorType_FLOAT32:
|
||||
PopulateTensor(index, data);
|
||||
@ -1887,8 +1887,8 @@ static std::initializer_list<float> rnn_bias = {
|
||||
class RNNOpModel : public SingleOpModelWithNNAPI {
|
||||
public:
|
||||
RNNOpModel(int batches, int units, int size,
|
||||
const TensorType& weights = TensorType_FLOAT32,
|
||||
const TensorType& recurrent_weights = TensorType_FLOAT32)
|
||||
const TensorType weights = TensorType_FLOAT32,
|
||||
const TensorType recurrent_weights = TensorType_FLOAT32)
|
||||
: batches_(batches), units_(units), input_size_(size) {
|
||||
input_ = AddInput(TensorType_FLOAT32);
|
||||
weights_ = AddInput(weights);
|
||||
@ -2246,11 +2246,12 @@ class LSTMOpModel : public SingleOpModelWithNNAPI {
|
||||
bool use_peephole, bool use_projection_weights,
|
||||
bool use_projection_bias, float cell_clip, float proj_clip,
|
||||
const std::vector<std::vector<int>>& input_shapes,
|
||||
const TensorType& weight_type = TensorType_FLOAT32)
|
||||
const TensorType weight_type)
|
||||
: n_batch_(n_batch),
|
||||
n_input_(n_input),
|
||||
n_cell_(n_cell),
|
||||
n_output_(n_output) {
|
||||
n_output_(n_output),
|
||||
weight_type_(weight_type) {
|
||||
input_ = AddInput(TensorType_FLOAT32);
|
||||
|
||||
if (use_cifg) {
|
||||
@ -2324,47 +2325,47 @@ class LSTMOpModel : public SingleOpModelWithNNAPI {
|
||||
}
|
||||
|
||||
void SetInputToInputWeights(const std::vector<float>& f) {
|
||||
PopulateTensor(input_to_input_weights_, f);
|
||||
SetData(input_to_input_weights_, weight_type_, f);
|
||||
}
|
||||
|
||||
void SetInputToForgetWeights(const std::vector<float>& f) {
|
||||
PopulateTensor(input_to_forget_weights_, f);
|
||||
SetData(input_to_forget_weights_, weight_type_, f);
|
||||
}
|
||||
|
||||
void SetInputToCellWeights(const std::vector<float>& f) {
|
||||
PopulateTensor(input_to_cell_weights_, f);
|
||||
SetData(input_to_cell_weights_, weight_type_, f);
|
||||
}
|
||||
|
||||
void SetInputToOutputWeights(const std::vector<float>& f) {
|
||||
PopulateTensor(input_to_output_weights_, f);
|
||||
SetData(input_to_output_weights_, weight_type_, f);
|
||||
}
|
||||
|
||||
void SetRecurrentToInputWeights(const std::vector<float>& f) {
|
||||
PopulateTensor(recurrent_to_input_weights_, f);
|
||||
SetData(recurrent_to_input_weights_, weight_type_, f);
|
||||
}
|
||||
|
||||
void SetRecurrentToForgetWeights(const std::vector<float>& f) {
|
||||
PopulateTensor(recurrent_to_forget_weights_, f);
|
||||
SetData(recurrent_to_forget_weights_, weight_type_, f);
|
||||
}
|
||||
|
||||
void SetRecurrentToCellWeights(const std::vector<float>& f) {
|
||||
PopulateTensor(recurrent_to_cell_weights_, f);
|
||||
SetData(recurrent_to_cell_weights_, weight_type_, f);
|
||||
}
|
||||
|
||||
void SetRecurrentToOutputWeights(const std::vector<float>& f) {
|
||||
PopulateTensor(recurrent_to_output_weights_, f);
|
||||
SetData(recurrent_to_output_weights_, weight_type_, f);
|
||||
}
|
||||
|
||||
void SetCellToInputWeights(const std::vector<float>& f) {
|
||||
PopulateTensor(cell_to_input_weights_, f);
|
||||
SetData(cell_to_input_weights_, weight_type_, f);
|
||||
}
|
||||
|
||||
void SetCellToForgetWeights(const std::vector<float>& f) {
|
||||
PopulateTensor(cell_to_forget_weights_, f);
|
||||
SetData(cell_to_forget_weights_, weight_type_, f);
|
||||
}
|
||||
|
||||
void SetCellToOutputWeights(const std::vector<float>& f) {
|
||||
PopulateTensor(cell_to_output_weights_, f);
|
||||
SetData(cell_to_output_weights_, weight_type_, f);
|
||||
}
|
||||
|
||||
void SetInputGateBias(const std::vector<float>& f) {
|
||||
@ -2384,7 +2385,7 @@ class LSTMOpModel : public SingleOpModelWithNNAPI {
|
||||
}
|
||||
|
||||
void SetProjectionWeights(const std::vector<float>& f) {
|
||||
PopulateTensor(projection_weights_, f);
|
||||
SetData(projection_weights_, weight_type_, f);
|
||||
}
|
||||
|
||||
void SetProjectionBias(const std::vector<float>& f) {
|
||||
@ -2437,6 +2438,9 @@ class LSTMOpModel : public SingleOpModelWithNNAPI {
|
||||
int n_input_;
|
||||
int n_cell_;
|
||||
int n_output_;
|
||||
|
||||
private:
|
||||
const TensorType weight_type_;
|
||||
};
|
||||
|
||||
class BaseLstmTest : public ::testing::Test {
|
||||
@ -2582,7 +2586,8 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
|
||||
|
||||
{0, 0}, // projection_weight tensor
|
||||
{0}, // projection_bias tensor
|
||||
});
|
||||
},
|
||||
/*weight_type=*/TensorType_FLOAT32);
|
||||
|
||||
lstm.SetInputToInputWeights(input_to_input_weights_);
|
||||
lstm.SetInputToCellWeights(input_to_cell_weights_);
|
||||
@ -2685,7 +2690,8 @@ TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
|
||||
|
||||
{0, 0}, // projection_weight tensor
|
||||
{0}, // projection_bias tensor
|
||||
});
|
||||
},
|
||||
/*weight_type=*/TensorType_FLOAT32);
|
||||
|
||||
lstm.SetInputToCellWeights(input_to_cell_weights_);
|
||||
lstm.SetInputToForgetWeights(input_to_forget_weights_);
|
||||
@ -3339,7 +3345,8 @@ TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) {
|
||||
|
||||
{n_output, n_cell}, // projection_weight tensor
|
||||
{0}, // projection_bias tensor
|
||||
});
|
||||
},
|
||||
/*weight_type=*/TensorType_FLOAT32);
|
||||
|
||||
lstm.SetInputToInputWeights(input_to_input_weights_);
|
||||
lstm.SetInputToCellWeights(input_to_cell_weights_);
|
||||
|
Loading…
Reference in New Issue
Block a user