Refactoring before supporting quantized weights in LSTM tests.
PiperOrigin-RevId: 238459158
This commit is contained in:
parent
22a0638433
commit
6859fefa3b
@ -51,7 +51,7 @@ class SingleOpModelWithNNAPI : public SingleOpModel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
void SetData(int index, TensorType type, std::initializer_list<float> data) {
|
void SetData(int index, TensorType type, const std::vector<float>& data) {
|
||||||
switch (type) {
|
switch (type) {
|
||||||
case TensorType_FLOAT32:
|
case TensorType_FLOAT32:
|
||||||
PopulateTensor(index, data);
|
PopulateTensor(index, data);
|
||||||
@ -1887,8 +1887,8 @@ static std::initializer_list<float> rnn_bias = {
|
|||||||
class RNNOpModel : public SingleOpModelWithNNAPI {
|
class RNNOpModel : public SingleOpModelWithNNAPI {
|
||||||
public:
|
public:
|
||||||
RNNOpModel(int batches, int units, int size,
|
RNNOpModel(int batches, int units, int size,
|
||||||
const TensorType& weights = TensorType_FLOAT32,
|
const TensorType weights = TensorType_FLOAT32,
|
||||||
const TensorType& recurrent_weights = TensorType_FLOAT32)
|
const TensorType recurrent_weights = TensorType_FLOAT32)
|
||||||
: batches_(batches), units_(units), input_size_(size) {
|
: batches_(batches), units_(units), input_size_(size) {
|
||||||
input_ = AddInput(TensorType_FLOAT32);
|
input_ = AddInput(TensorType_FLOAT32);
|
||||||
weights_ = AddInput(weights);
|
weights_ = AddInput(weights);
|
||||||
@ -2246,11 +2246,12 @@ class LSTMOpModel : public SingleOpModelWithNNAPI {
|
|||||||
bool use_peephole, bool use_projection_weights,
|
bool use_peephole, bool use_projection_weights,
|
||||||
bool use_projection_bias, float cell_clip, float proj_clip,
|
bool use_projection_bias, float cell_clip, float proj_clip,
|
||||||
const std::vector<std::vector<int>>& input_shapes,
|
const std::vector<std::vector<int>>& input_shapes,
|
||||||
const TensorType& weight_type = TensorType_FLOAT32)
|
const TensorType weight_type)
|
||||||
: n_batch_(n_batch),
|
: n_batch_(n_batch),
|
||||||
n_input_(n_input),
|
n_input_(n_input),
|
||||||
n_cell_(n_cell),
|
n_cell_(n_cell),
|
||||||
n_output_(n_output) {
|
n_output_(n_output),
|
||||||
|
weight_type_(weight_type) {
|
||||||
input_ = AddInput(TensorType_FLOAT32);
|
input_ = AddInput(TensorType_FLOAT32);
|
||||||
|
|
||||||
if (use_cifg) {
|
if (use_cifg) {
|
||||||
@ -2324,47 +2325,47 @@ class LSTMOpModel : public SingleOpModelWithNNAPI {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void SetInputToInputWeights(const std::vector<float>& f) {
|
void SetInputToInputWeights(const std::vector<float>& f) {
|
||||||
PopulateTensor(input_to_input_weights_, f);
|
SetData(input_to_input_weights_, weight_type_, f);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetInputToForgetWeights(const std::vector<float>& f) {
|
void SetInputToForgetWeights(const std::vector<float>& f) {
|
||||||
PopulateTensor(input_to_forget_weights_, f);
|
SetData(input_to_forget_weights_, weight_type_, f);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetInputToCellWeights(const std::vector<float>& f) {
|
void SetInputToCellWeights(const std::vector<float>& f) {
|
||||||
PopulateTensor(input_to_cell_weights_, f);
|
SetData(input_to_cell_weights_, weight_type_, f);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetInputToOutputWeights(const std::vector<float>& f) {
|
void SetInputToOutputWeights(const std::vector<float>& f) {
|
||||||
PopulateTensor(input_to_output_weights_, f);
|
SetData(input_to_output_weights_, weight_type_, f);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetRecurrentToInputWeights(const std::vector<float>& f) {
|
void SetRecurrentToInputWeights(const std::vector<float>& f) {
|
||||||
PopulateTensor(recurrent_to_input_weights_, f);
|
SetData(recurrent_to_input_weights_, weight_type_, f);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetRecurrentToForgetWeights(const std::vector<float>& f) {
|
void SetRecurrentToForgetWeights(const std::vector<float>& f) {
|
||||||
PopulateTensor(recurrent_to_forget_weights_, f);
|
SetData(recurrent_to_forget_weights_, weight_type_, f);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetRecurrentToCellWeights(const std::vector<float>& f) {
|
void SetRecurrentToCellWeights(const std::vector<float>& f) {
|
||||||
PopulateTensor(recurrent_to_cell_weights_, f);
|
SetData(recurrent_to_cell_weights_, weight_type_, f);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetRecurrentToOutputWeights(const std::vector<float>& f) {
|
void SetRecurrentToOutputWeights(const std::vector<float>& f) {
|
||||||
PopulateTensor(recurrent_to_output_weights_, f);
|
SetData(recurrent_to_output_weights_, weight_type_, f);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetCellToInputWeights(const std::vector<float>& f) {
|
void SetCellToInputWeights(const std::vector<float>& f) {
|
||||||
PopulateTensor(cell_to_input_weights_, f);
|
SetData(cell_to_input_weights_, weight_type_, f);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetCellToForgetWeights(const std::vector<float>& f) {
|
void SetCellToForgetWeights(const std::vector<float>& f) {
|
||||||
PopulateTensor(cell_to_forget_weights_, f);
|
SetData(cell_to_forget_weights_, weight_type_, f);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetCellToOutputWeights(const std::vector<float>& f) {
|
void SetCellToOutputWeights(const std::vector<float>& f) {
|
||||||
PopulateTensor(cell_to_output_weights_, f);
|
SetData(cell_to_output_weights_, weight_type_, f);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetInputGateBias(const std::vector<float>& f) {
|
void SetInputGateBias(const std::vector<float>& f) {
|
||||||
@ -2384,7 +2385,7 @@ class LSTMOpModel : public SingleOpModelWithNNAPI {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void SetProjectionWeights(const std::vector<float>& f) {
|
void SetProjectionWeights(const std::vector<float>& f) {
|
||||||
PopulateTensor(projection_weights_, f);
|
SetData(projection_weights_, weight_type_, f);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetProjectionBias(const std::vector<float>& f) {
|
void SetProjectionBias(const std::vector<float>& f) {
|
||||||
@ -2437,6 +2438,9 @@ class LSTMOpModel : public SingleOpModelWithNNAPI {
|
|||||||
int n_input_;
|
int n_input_;
|
||||||
int n_cell_;
|
int n_cell_;
|
||||||
int n_output_;
|
int n_output_;
|
||||||
|
|
||||||
|
private:
|
||||||
|
const TensorType weight_type_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class BaseLstmTest : public ::testing::Test {
|
class BaseLstmTest : public ::testing::Test {
|
||||||
@ -2582,7 +2586,8 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
|
|||||||
|
|
||||||
{0, 0}, // projection_weight tensor
|
{0, 0}, // projection_weight tensor
|
||||||
{0}, // projection_bias tensor
|
{0}, // projection_bias tensor
|
||||||
});
|
},
|
||||||
|
/*weight_type=*/TensorType_FLOAT32);
|
||||||
|
|
||||||
lstm.SetInputToInputWeights(input_to_input_weights_);
|
lstm.SetInputToInputWeights(input_to_input_weights_);
|
||||||
lstm.SetInputToCellWeights(input_to_cell_weights_);
|
lstm.SetInputToCellWeights(input_to_cell_weights_);
|
||||||
@ -2685,7 +2690,8 @@ TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
|
|||||||
|
|
||||||
{0, 0}, // projection_weight tensor
|
{0, 0}, // projection_weight tensor
|
||||||
{0}, // projection_bias tensor
|
{0}, // projection_bias tensor
|
||||||
});
|
},
|
||||||
|
/*weight_type=*/TensorType_FLOAT32);
|
||||||
|
|
||||||
lstm.SetInputToCellWeights(input_to_cell_weights_);
|
lstm.SetInputToCellWeights(input_to_cell_weights_);
|
||||||
lstm.SetInputToForgetWeights(input_to_forget_weights_);
|
lstm.SetInputToForgetWeights(input_to_forget_weights_);
|
||||||
@ -3339,7 +3345,8 @@ TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) {
|
|||||||
|
|
||||||
{n_output, n_cell}, // projection_weight tensor
|
{n_output, n_cell}, // projection_weight tensor
|
||||||
{0}, // projection_bias tensor
|
{0}, // projection_bias tensor
|
||||||
});
|
},
|
||||||
|
/*weight_type=*/TensorType_FLOAT32);
|
||||||
|
|
||||||
lstm.SetInputToInputWeights(input_to_input_weights_);
|
lstm.SetInputToInputWeights(input_to_input_weights_);
|
||||||
lstm.SetInputToCellWeights(input_to_cell_weights_);
|
lstm.SetInputToCellWeights(input_to_cell_weights_);
|
||||||
|
Loading…
Reference in New Issue
Block a user