Refactoring before supporting quantized weights in LSTM tests.

PiperOrigin-RevId: 238459158
Authored by A. Unique TensorFlower on 2019-03-14 09:37:56 -07:00; committed by TensorFlower Gardener
parent 22a0638433
commit 6859fefa3b


@@ -51,7 +51,7 @@ class SingleOpModelWithNNAPI : public SingleOpModel {
   }
  protected:
-  void SetData(int index, TensorType type, std::initializer_list<float> data) {
+  void SetData(int index, TensorType type, const std::vector<float>& data) {
     switch (type) {
       case TensorType_FLOAT32:
         PopulateTensor(index, data);
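The SetData() helper above handles only TensorType_FLOAT32 today; widening its third parameter from std::initializer_list<float> to const std::vector<float>& is what lets the LSTM weight setters below funnel through it. A minimal sketch of where this is presumably headed once quantized weights are supported (the TensorType_UINT8 branch is an assumption, not part of this commit, and it presumes the SymmetricQuantizeAndPopulate helper from the SingleOpModel base class):

  // Sketch only: the quantized branch below is hypothetical.
  void SetData(int index, TensorType type, const std::vector<float>& data) {
    switch (type) {
      case TensorType_FLOAT32:
        PopulateTensor(index, data);  // copy raw floats into the tensor
        break;
      case TensorType_UINT8:
        // Assumed follow-up: quantize the float data before populating.
        SymmetricQuantizeAndPopulate(index, data);
        break;
      default:
        FAIL() << "Type not supported: " << type;
        break;
    }
  }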
@@ -1887,8 +1887,8 @@ static std::initializer_list<float> rnn_bias = {
 class RNNOpModel : public SingleOpModelWithNNAPI {
  public:
   RNNOpModel(int batches, int units, int size,
-             const TensorType& weights = TensorType_FLOAT32,
-             const TensorType& recurrent_weights = TensorType_FLOAT32)
+             const TensorType weights = TensorType_FLOAT32,
+             const TensorType recurrent_weights = TensorType_FLOAT32)
       : batches_(batches), units_(units), input_size_(size) {
     input_ = AddInput(TensorType_FLOAT32);
     weights_ = AddInput(weights);
@@ -2246,11 +2246,12 @@ class LSTMOpModel : public SingleOpModelWithNNAPI {
               bool use_peephole, bool use_projection_weights,
               bool use_projection_bias, float cell_clip, float proj_clip,
               const std::vector<std::vector<int>>& input_shapes,
-              const TensorType& weight_type = TensorType_FLOAT32)
+              const TensorType weight_type)
       : n_batch_(n_batch),
         n_input_(n_input),
         n_cell_(n_cell),
-        n_output_(n_output) {
+        n_output_(n_output),
+        weight_type_(weight_type) {
     input_ = AddInput(TensorType_FLOAT32);
     if (use_cifg) {
@@ -2324,47 +2325,47 @@ class LSTMOpModel : public SingleOpModelWithNNAPI {
   }
   void SetInputToInputWeights(const std::vector<float>& f) {
-    PopulateTensor(input_to_input_weights_, f);
+    SetData(input_to_input_weights_, weight_type_, f);
   }
   void SetInputToForgetWeights(const std::vector<float>& f) {
-    PopulateTensor(input_to_forget_weights_, f);
+    SetData(input_to_forget_weights_, weight_type_, f);
   }
   void SetInputToCellWeights(const std::vector<float>& f) {
-    PopulateTensor(input_to_cell_weights_, f);
+    SetData(input_to_cell_weights_, weight_type_, f);
   }
   void SetInputToOutputWeights(const std::vector<float>& f) {
-    PopulateTensor(input_to_output_weights_, f);
+    SetData(input_to_output_weights_, weight_type_, f);
   }
   void SetRecurrentToInputWeights(const std::vector<float>& f) {
-    PopulateTensor(recurrent_to_input_weights_, f);
+    SetData(recurrent_to_input_weights_, weight_type_, f);
   }
   void SetRecurrentToForgetWeights(const std::vector<float>& f) {
-    PopulateTensor(recurrent_to_forget_weights_, f);
+    SetData(recurrent_to_forget_weights_, weight_type_, f);
   }
   void SetRecurrentToCellWeights(const std::vector<float>& f) {
-    PopulateTensor(recurrent_to_cell_weights_, f);
+    SetData(recurrent_to_cell_weights_, weight_type_, f);
   }
   void SetRecurrentToOutputWeights(const std::vector<float>& f) {
-    PopulateTensor(recurrent_to_output_weights_, f);
+    SetData(recurrent_to_output_weights_, weight_type_, f);
   }
   void SetCellToInputWeights(const std::vector<float>& f) {
-    PopulateTensor(cell_to_input_weights_, f);
+    SetData(cell_to_input_weights_, weight_type_, f);
   }
   void SetCellToForgetWeights(const std::vector<float>& f) {
-    PopulateTensor(cell_to_forget_weights_, f);
+    SetData(cell_to_forget_weights_, weight_type_, f);
   }
   void SetCellToOutputWeights(const std::vector<float>& f) {
-    PopulateTensor(cell_to_output_weights_, f);
+    SetData(cell_to_output_weights_, weight_type_, f);
   }
   void SetInputGateBias(const std::vector<float>& f) {
@@ -2384,7 +2385,7 @@ class LSTMOpModel : public SingleOpModelWithNNAPI {
   }
   void SetProjectionWeights(const std::vector<float>& f) {
-    PopulateTensor(projection_weights_, f);
+    SetData(projection_weights_, weight_type_, f);
   }
   void SetProjectionBias(const std::vector<float>& f) {
@@ -2437,6 +2438,9 @@ class LSTMOpModel : public SingleOpModelWithNNAPI {
   int n_input_;
   int n_cell_;
   int n_output_;
+ private:
+  const TensorType weight_type_;
 };
 class BaseLstmTest : public ::testing::Test {
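Because weight_type_ is now stored on the model, every weight setter routes through SetData(), and a test picks the weight representation once, at construction time. A usage sketch consistent with the constructor hunks above (parameter names follow the signature shown; treat it as illustrative rather than a verbatim excerpt from the test file):

  // Construct an LSTM test model whose weight tensors stay plain float.
  LSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
                   /*use_cifg=*/false, /*use_peephole=*/false,
                   /*use_projection_weights=*/false,
                   /*use_projection_bias=*/false,
                   /*cell_clip=*/0.0, /*proj_clip=*/0.0, input_shapes,
                   /*weight_type=*/TensorType_FLOAT32);
  // Each setter now dispatches on weight_type_ inside SetData().
  lstm.SetInputToInputWeights(input_to_input_weights_);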
@@ -2582,7 +2586,8 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
                        {0, 0},  // projection_weight tensor
                        {0},     // projection_bias tensor
-                   });
+                   },
+                   /*weight_type=*/TensorType_FLOAT32);
   lstm.SetInputToInputWeights(input_to_input_weights_);
   lstm.SetInputToCellWeights(input_to_cell_weights_);
@@ -2685,7 +2690,8 @@ TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
                        {0, 0},  // projection_weight tensor
                        {0},     // projection_bias tensor
-                   });
+                   },
+                   /*weight_type=*/TensorType_FLOAT32);
   lstm.SetInputToCellWeights(input_to_cell_weights_);
   lstm.SetInputToForgetWeights(input_to_forget_weights_);
@@ -3339,7 +3345,8 @@ TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) {
                        {n_output, n_cell},  // projection_weight tensor
                        {0},                 // projection_bias tensor
-                   });
+                   },
+                   /*weight_type=*/TensorType_FLOAT32);
   lstm.SetInputToInputWeights(input_to_input_weights_);
   lstm.SetInputToCellWeights(input_to_cell_weights_);
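With the type threaded through the constructor and cached in weight_type_, the stated follow-up (quantized weights) should only need to touch the trailing argument at each call site plus the SetData() switch. A hypothetical future hunk, assuming TensorType_UINT8 ends up as the quantized representation, would look like:

-                   /*weight_type=*/TensorType_FLOAT32);
+                   /*weight_type=*/TensorType_UINT8);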