Fix gru kernel test for msan.
PiperOrigin-RevId: 247150783
This commit is contained in:
parent
59f8d54323
commit
34b0836381
@ -31,11 +31,13 @@ using ::testing::ElementsAreArray;
|
||||
|
||||
class GRUOpModel : public SingleOpModel {
|
||||
public:
|
||||
explicit GRUOpModel(const std::vector<std::vector<int>>& input_shapes,
|
||||
const TensorType& weight_type = TensorType_FLOAT32) {
|
||||
explicit GRUOpModel(int n_batch, int n_input, int n_output,
|
||||
const std::vector<std::vector<int>>& input_shapes,
|
||||
const TensorType& weight_type = TensorType_FLOAT32)
|
||||
: n_batch_(n_batch), n_input_(n_input), n_output_(n_output) {
|
||||
input_ = AddInput(TensorType_FLOAT32);
|
||||
input_state_ =
|
||||
AddInput(TensorData{TensorType_FLOAT32, {n_batch_, n_output_}}, true);
|
||||
AddInput(TensorData{TensorType_FLOAT32, {n_batch, n_output}}, true);
|
||||
gate_weight_ = AddInput(TensorType_FLOAT32);
|
||||
gate_bias_ = AddInput(TensorType_FLOAT32);
|
||||
candidate_weight_ = AddInput(TensorType_FLOAT32);
|
||||
@ -100,7 +102,8 @@ TEST(GRUTest, SimpleTest) {
|
||||
const int n_input = 2;
|
||||
const int n_output = 3;
|
||||
|
||||
GRUOpModel m({{n_time, n_batch, n_input},
|
||||
GRUOpModel m(n_batch, n_input, n_output,
|
||||
{{n_time, n_batch, n_input},
|
||||
{n_batch, n_output},
|
||||
{2 * n_output, n_input + n_output},
|
||||
{2 * n_output},
|
||||
|
Loading…
Reference in New Issue
Block a user