Refactoring before supporting quantized weights in LSTM tests.

PiperOrigin-RevId: 238459158
This commit is contained in:
A. Unique TensorFlower 2019-03-14 09:37:56 -07:00 committed by TensorFlower Gardener
parent 22a0638433
commit 6859fefa3b

View File

@ -51,7 +51,7 @@ class SingleOpModelWithNNAPI : public SingleOpModel {
} }
protected: protected:
void SetData(int index, TensorType type, std::initializer_list<float> data) { void SetData(int index, TensorType type, const std::vector<float>& data) {
switch (type) { switch (type) {
case TensorType_FLOAT32: case TensorType_FLOAT32:
PopulateTensor(index, data); PopulateTensor(index, data);
@ -1887,8 +1887,8 @@ static std::initializer_list<float> rnn_bias = {
class RNNOpModel : public SingleOpModelWithNNAPI { class RNNOpModel : public SingleOpModelWithNNAPI {
public: public:
RNNOpModel(int batches, int units, int size, RNNOpModel(int batches, int units, int size,
const TensorType& weights = TensorType_FLOAT32, const TensorType weights = TensorType_FLOAT32,
const TensorType& recurrent_weights = TensorType_FLOAT32) const TensorType recurrent_weights = TensorType_FLOAT32)
: batches_(batches), units_(units), input_size_(size) { : batches_(batches), units_(units), input_size_(size) {
input_ = AddInput(TensorType_FLOAT32); input_ = AddInput(TensorType_FLOAT32);
weights_ = AddInput(weights); weights_ = AddInput(weights);
@ -2246,11 +2246,12 @@ class LSTMOpModel : public SingleOpModelWithNNAPI {
bool use_peephole, bool use_projection_weights, bool use_peephole, bool use_projection_weights,
bool use_projection_bias, float cell_clip, float proj_clip, bool use_projection_bias, float cell_clip, float proj_clip,
const std::vector<std::vector<int>>& input_shapes, const std::vector<std::vector<int>>& input_shapes,
const TensorType& weight_type = TensorType_FLOAT32) const TensorType weight_type)
: n_batch_(n_batch), : n_batch_(n_batch),
n_input_(n_input), n_input_(n_input),
n_cell_(n_cell), n_cell_(n_cell),
n_output_(n_output) { n_output_(n_output),
weight_type_(weight_type) {
input_ = AddInput(TensorType_FLOAT32); input_ = AddInput(TensorType_FLOAT32);
if (use_cifg) { if (use_cifg) {
@ -2324,47 +2325,47 @@ class LSTMOpModel : public SingleOpModelWithNNAPI {
} }
void SetInputToInputWeights(const std::vector<float>& f) { void SetInputToInputWeights(const std::vector<float>& f) {
PopulateTensor(input_to_input_weights_, f); SetData(input_to_input_weights_, weight_type_, f);
} }
void SetInputToForgetWeights(const std::vector<float>& f) { void SetInputToForgetWeights(const std::vector<float>& f) {
PopulateTensor(input_to_forget_weights_, f); SetData(input_to_forget_weights_, weight_type_, f);
} }
void SetInputToCellWeights(const std::vector<float>& f) { void SetInputToCellWeights(const std::vector<float>& f) {
PopulateTensor(input_to_cell_weights_, f); SetData(input_to_cell_weights_, weight_type_, f);
} }
void SetInputToOutputWeights(const std::vector<float>& f) { void SetInputToOutputWeights(const std::vector<float>& f) {
PopulateTensor(input_to_output_weights_, f); SetData(input_to_output_weights_, weight_type_, f);
} }
void SetRecurrentToInputWeights(const std::vector<float>& f) { void SetRecurrentToInputWeights(const std::vector<float>& f) {
PopulateTensor(recurrent_to_input_weights_, f); SetData(recurrent_to_input_weights_, weight_type_, f);
} }
void SetRecurrentToForgetWeights(const std::vector<float>& f) { void SetRecurrentToForgetWeights(const std::vector<float>& f) {
PopulateTensor(recurrent_to_forget_weights_, f); SetData(recurrent_to_forget_weights_, weight_type_, f);
} }
void SetRecurrentToCellWeights(const std::vector<float>& f) { void SetRecurrentToCellWeights(const std::vector<float>& f) {
PopulateTensor(recurrent_to_cell_weights_, f); SetData(recurrent_to_cell_weights_, weight_type_, f);
} }
void SetRecurrentToOutputWeights(const std::vector<float>& f) { void SetRecurrentToOutputWeights(const std::vector<float>& f) {
PopulateTensor(recurrent_to_output_weights_, f); SetData(recurrent_to_output_weights_, weight_type_, f);
} }
void SetCellToInputWeights(const std::vector<float>& f) { void SetCellToInputWeights(const std::vector<float>& f) {
PopulateTensor(cell_to_input_weights_, f); SetData(cell_to_input_weights_, weight_type_, f);
} }
void SetCellToForgetWeights(const std::vector<float>& f) { void SetCellToForgetWeights(const std::vector<float>& f) {
PopulateTensor(cell_to_forget_weights_, f); SetData(cell_to_forget_weights_, weight_type_, f);
} }
void SetCellToOutputWeights(const std::vector<float>& f) { void SetCellToOutputWeights(const std::vector<float>& f) {
PopulateTensor(cell_to_output_weights_, f); SetData(cell_to_output_weights_, weight_type_, f);
} }
void SetInputGateBias(const std::vector<float>& f) { void SetInputGateBias(const std::vector<float>& f) {
@ -2384,7 +2385,7 @@ class LSTMOpModel : public SingleOpModelWithNNAPI {
} }
void SetProjectionWeights(const std::vector<float>& f) { void SetProjectionWeights(const std::vector<float>& f) {
PopulateTensor(projection_weights_, f); SetData(projection_weights_, weight_type_, f);
} }
void SetProjectionBias(const std::vector<float>& f) { void SetProjectionBias(const std::vector<float>& f) {
@ -2437,6 +2438,9 @@ class LSTMOpModel : public SingleOpModelWithNNAPI {
int n_input_; int n_input_;
int n_cell_; int n_cell_;
int n_output_; int n_output_;
private:
const TensorType weight_type_;
}; };
class BaseLstmTest : public ::testing::Test { class BaseLstmTest : public ::testing::Test {
@ -2582,7 +2586,8 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
{0, 0}, // projection_weight tensor {0, 0}, // projection_weight tensor
{0}, // projection_bias tensor {0}, // projection_bias tensor
}); },
/*weight_type=*/TensorType_FLOAT32);
lstm.SetInputToInputWeights(input_to_input_weights_); lstm.SetInputToInputWeights(input_to_input_weights_);
lstm.SetInputToCellWeights(input_to_cell_weights_); lstm.SetInputToCellWeights(input_to_cell_weights_);
@ -2685,7 +2690,8 @@ TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
{0, 0}, // projection_weight tensor {0, 0}, // projection_weight tensor
{0}, // projection_bias tensor {0}, // projection_bias tensor
}); },
/*weight_type=*/TensorType_FLOAT32);
lstm.SetInputToCellWeights(input_to_cell_weights_); lstm.SetInputToCellWeights(input_to_cell_weights_);
lstm.SetInputToForgetWeights(input_to_forget_weights_); lstm.SetInputToForgetWeights(input_to_forget_weights_);
@ -3339,7 +3345,8 @@ TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) {
{n_output, n_cell}, // projection_weight tensor {n_output, n_cell}, // projection_weight tensor
{0}, // projection_bias tensor {0}, // projection_bias tensor
}); },
/*weight_type=*/TensorType_FLOAT32);
lstm.SetInputToInputWeights(input_to_input_weights_); lstm.SetInputToInputWeights(input_to_input_weights_);
lstm.SetInputToCellWeights(input_to_cell_weights_); lstm.SetInputToCellWeights(input_to_cell_weights_);