Fix gru kernel test for msan.

PiperOrigin-RevId: 247150783
This commit is contained in:
Renjie Liu 2019-05-07 22:09:20 -07:00 committed by TensorFlower Gardener
parent 59f8d54323
commit 34b0836381

View File

@ -31,11 +31,13 @@ using ::testing::ElementsAreArray;
class GRUOpModel : public SingleOpModel { class GRUOpModel : public SingleOpModel {
public: public:
explicit GRUOpModel(const std::vector<std::vector<int>>& input_shapes, explicit GRUOpModel(int n_batch, int n_input, int n_output,
const TensorType& weight_type = TensorType_FLOAT32) { const std::vector<std::vector<int>>& input_shapes,
const TensorType& weight_type = TensorType_FLOAT32)
: n_batch_(n_batch), n_input_(n_input), n_output_(n_output) {
input_ = AddInput(TensorType_FLOAT32); input_ = AddInput(TensorType_FLOAT32);
input_state_ = input_state_ =
AddInput(TensorData{TensorType_FLOAT32, {n_batch_, n_output_}}, true); AddInput(TensorData{TensorType_FLOAT32, {n_batch, n_output}}, true);
gate_weight_ = AddInput(TensorType_FLOAT32); gate_weight_ = AddInput(TensorType_FLOAT32);
gate_bias_ = AddInput(TensorType_FLOAT32); gate_bias_ = AddInput(TensorType_FLOAT32);
candidate_weight_ = AddInput(TensorType_FLOAT32); candidate_weight_ = AddInput(TensorType_FLOAT32);
@ -100,7 +102,8 @@ TEST(GRUTest, SimpleTest) {
const int n_input = 2; const int n_input = 2;
const int n_output = 3; const int n_output = 3;
GRUOpModel m({{n_time, n_batch, n_input}, GRUOpModel m(n_batch, n_input, n_output,
{{n_time, n_batch, n_input},
{n_batch, n_output}, {n_batch, n_output},
{2 * n_output, n_input + n_output}, {2 * n_output, n_input + n_output},
{2 * n_output}, {2 * n_output},