Use verified modules in fusion_test.
Also disable layout assignment, because it would clear the layout of fusion computations without recomputing it. Also enable a test for the CPU backend which passes. PiperOrigin-RevId: 238196941
This commit is contained in:
parent
4bf7522056
commit
5392565fb8
@ -89,7 +89,7 @@ class FusionTest : public HloTestBase {
|
||||
}
|
||||
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto hlo_module = CreateNewUnverifiedModule();
|
||||
auto hlo_module = CreateNewVerifiedModule();
|
||||
|
||||
auto prim_type = primitive_util::NativeToPrimitiveType<T>();
|
||||
|
||||
@ -141,6 +141,11 @@ class FusionTest : public HloTestBase {
|
||||
absl::Span<const float> xs);
|
||||
bool ComputeElementwiseAnswerCompare(ComparisonDirection direction,
|
||||
absl::Span<const float> xs);
|
||||
DebugOptions GetDebugOptionsForTest() override {
|
||||
DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest();
|
||||
debug_options.add_xla_disable_hlo_passes("layout-assignment");
|
||||
return debug_options;
|
||||
}
|
||||
};
|
||||
|
||||
float FusionTest::ComputeElementwiseAnswerFloat(HloOpcode opcode,
|
||||
@ -234,7 +239,7 @@ XLA_TEST_F(FusionTest, Test) {
|
||||
|
||||
EXPECT_TRUE(LiteralTestUtil::Near(
|
||||
LiteralUtil::CreateR2<float>({{0.5}, {2.72}}),
|
||||
ExecuteNoHloPasses(std::move(hlo_module), {}), ErrorSpec(1e-4)));
|
||||
ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
|
||||
}
|
||||
|
||||
// Test whether we emit appropriate code for parameters of fusion instructions.
|
||||
@ -242,7 +247,7 @@ XLA_TEST_F(FusionTest, Parameter) {
|
||||
// Build a computation and fuse part of it so the fusion instruction has an
|
||||
// operand parameter.
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto hlo_module = CreateNewUnverifiedModule();
|
||||
auto hlo_module = CreateNewVerifiedModule();
|
||||
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
LiteralUtil::CreateR2<float>({{1.0, 2.0, 3.0}})));
|
||||
auto copy1 = builder.AddInstruction(HloInstruction::CreateUnary(
|
||||
@ -277,7 +282,7 @@ XLA_TEST_F(FusionTest, RandomizedParallelPartition) {
|
||||
ShapeUtil::MakeShapeWithLayout(F32, {rand_dim0_size, dim1_size}, {1, 0});
|
||||
// Build simple fusion computation: y = x^2 (elementwise).
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto hlo_module = CreateNewUnverifiedModule();
|
||||
auto hlo_module = CreateNewVerifiedModule();
|
||||
|
||||
auto two = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
|
||||
@ -301,7 +306,7 @@ XLA_TEST_F(FusionTest, RandomizedParallelPartition) {
|
||||
|
||||
XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto hlo_module = CreateNewUnverifiedModule();
|
||||
auto hlo_module = CreateNewVerifiedModule();
|
||||
auto const_vector = builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0})));
|
||||
auto const_array = builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
@ -325,7 +330,7 @@ XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) {
|
||||
|
||||
XLA_TEST_F(FusionTest, ReshapeToScalar) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto hlo_module = CreateNewUnverifiedModule();
|
||||
auto hlo_module = CreateNewVerifiedModule();
|
||||
auto single_element_array = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR2<int32>({{5}})));
|
||||
auto reshape = builder.AddInstruction(HloInstruction::CreateReshape(
|
||||
@ -340,7 +345,7 @@ XLA_TEST_F(FusionTest, ReshapeToScalar) {
|
||||
|
||||
XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto hlo_module = CreateNewUnverifiedModule();
|
||||
auto hlo_module = CreateNewVerifiedModule();
|
||||
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}})));
|
||||
auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape(
|
||||
@ -355,7 +360,7 @@ XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) {
|
||||
|
||||
XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto hlo_module = CreateNewUnverifiedModule();
|
||||
auto hlo_module = CreateNewVerifiedModule();
|
||||
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
LiteralUtil::CreateR3<int32>({{{1, 2, 3}, {4, 5, 6}}})));
|
||||
auto reshape1 = builder.AddInstruction(
|
||||
@ -370,7 +375,7 @@ XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) {
|
||||
|
||||
XLA_TEST_F(FusionTest, Reshape_1by1by1_) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto hlo_module = CreateNewUnverifiedModule();
|
||||
auto hlo_module = CreateNewVerifiedModule();
|
||||
auto const0 = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR3<int32>({{{7}}})));
|
||||
auto reshape1 = builder.AddInstruction(
|
||||
@ -385,7 +390,7 @@ XLA_TEST_F(FusionTest, Reshape_1by1by1_) {
|
||||
|
||||
XLA_TEST_F(FusionTest, Reshape__1by1by1) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto hlo_module = CreateNewUnverifiedModule();
|
||||
auto hlo_module = CreateNewVerifiedModule();
|
||||
auto const0 = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(7)));
|
||||
auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape(
|
||||
@ -400,7 +405,7 @@ XLA_TEST_F(FusionTest, Reshape__1by1by1) {
|
||||
|
||||
XLA_TEST_F(FusionTest, Reshape__) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto hlo_module = CreateNewUnverifiedModule();
|
||||
auto hlo_module = CreateNewVerifiedModule();
|
||||
auto const0 = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(7)));
|
||||
auto reshape1 = builder.AddInstruction(
|
||||
@ -415,7 +420,7 @@ XLA_TEST_F(FusionTest, Reshape__) {
|
||||
|
||||
XLA_TEST_F(FusionTest, Reshape_3by3_3by3) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto hlo_module = CreateNewUnverifiedModule();
|
||||
auto hlo_module = CreateNewVerifiedModule();
|
||||
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}})));
|
||||
auto reshape1 = builder.AddInstruction(
|
||||
@ -430,7 +435,7 @@ XLA_TEST_F(FusionTest, Reshape_3by3_3by3) {
|
||||
|
||||
XLA_TEST_F(FusionTest, Transpose_2by3) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto hlo_module = CreateNewUnverifiedModule();
|
||||
auto hlo_module = CreateNewVerifiedModule();
|
||||
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}})));
|
||||
auto reshape1 = builder.AddInstruction(HloInstruction::CreateTranspose(
|
||||
@ -445,7 +450,7 @@ XLA_TEST_F(FusionTest, Transpose_2by3) {
|
||||
|
||||
XLA_TEST_F(FusionTest, Transpose_3by3) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto hlo_module = CreateNewUnverifiedModule();
|
||||
auto hlo_module = CreateNewVerifiedModule();
|
||||
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}})));
|
||||
auto reshape1 = builder.AddInstruction(HloInstruction::CreateTranspose(
|
||||
@ -460,7 +465,7 @@ XLA_TEST_F(FusionTest, Transpose_3by3) {
|
||||
|
||||
XLA_TEST_F(FusionTest, Reverse) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto hlo_module = CreateNewUnverifiedModule();
|
||||
auto hlo_module = CreateNewVerifiedModule();
|
||||
auto const0 = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({1, 2, 3})));
|
||||
auto reverse1 = builder.AddInstruction(HloInstruction::CreateReverse(
|
||||
@ -476,7 +481,7 @@ XLA_TEST_F(FusionTest, Reverse) {
|
||||
|
||||
XLA_TEST_F(FusionTest, ReverseNegate) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto hlo_module = CreateNewUnverifiedModule();
|
||||
auto hlo_module = CreateNewVerifiedModule();
|
||||
auto const0 = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({1, 2, 3})));
|
||||
auto reverse1 = builder.AddInstruction(HloInstruction::CreateReverse(
|
||||
@ -494,7 +499,7 @@ XLA_TEST_F(FusionTest, ReverseNegate) {
|
||||
|
||||
XLA_TEST_F(FusionTest, BroadcastNegate) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto hlo_module = CreateNewUnverifiedModule();
|
||||
auto hlo_module = CreateNewVerifiedModule();
|
||||
auto const0 = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
|
||||
auto broadcast1 = builder.AddInstruction(HloInstruction::CreateBroadcast(
|
||||
@ -512,7 +517,7 @@ XLA_TEST_F(FusionTest, BroadcastNegate) {
|
||||
|
||||
XLA_TEST_F(FusionTest, SliceNegate) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto hlo_module = CreateNewUnverifiedModule();
|
||||
auto hlo_module = CreateNewVerifiedModule();
|
||||
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
LiteralUtil::CreateR1<int32>({1, 2, 3, 4})));
|
||||
auto slice1 = builder.AddInstruction(HloInstruction::CreateSlice(
|
||||
@ -530,7 +535,7 @@ XLA_TEST_F(FusionTest, SliceNegate) {
|
||||
|
||||
XLA_TEST_F(FusionTest, DynamicSliceNegate) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto hlo_module = CreateNewUnverifiedModule();
|
||||
auto hlo_module = CreateNewVerifiedModule();
|
||||
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
LiteralUtil::CreateR1<int32>({1, 2, 3, 4})));
|
||||
auto const1 = builder.AddInstruction(
|
||||
@ -552,7 +557,7 @@ XLA_TEST_F(FusionTest, DynamicSliceNegate) {
|
||||
|
||||
XLA_TEST_F(FusionTest, ReshapeNegate) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto hlo_module = CreateNewUnverifiedModule();
|
||||
auto hlo_module = CreateNewVerifiedModule();
|
||||
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
LiteralUtil::CreateR1<int32>({1, 2, 3, 4})));
|
||||
auto reshape1 = builder.AddInstruction(
|
||||
@ -570,7 +575,7 @@ XLA_TEST_F(FusionTest, ReshapeNegate) {
|
||||
|
||||
XLA_TEST_F(FusionTest, TransposeNegate) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto hlo_module = CreateNewUnverifiedModule();
|
||||
auto hlo_module = CreateNewVerifiedModule();
|
||||
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}})));
|
||||
auto transpose1 = builder.AddInstruction(HloInstruction::CreateTranspose(
|
||||
@ -598,7 +603,7 @@ std::unique_ptr<HloComputation> MakeReduceTestComputation() {
|
||||
}
|
||||
|
||||
XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) {
|
||||
auto hlo_module = CreateNewUnverifiedModule();
|
||||
auto hlo_module = CreateNewVerifiedModule();
|
||||
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
@ -617,8 +622,8 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) {
|
||||
ExecuteAndTransfer(std::move(hlo_module), {})));
|
||||
}
|
||||
|
||||
XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) {
|
||||
auto hlo_module = CreateNewUnverifiedModule();
|
||||
XLA_TEST_F(FusionTest, ReduceImplicitBroadcast) {
|
||||
auto hlo_module = CreateNewVerifiedModule();
|
||||
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
@ -641,7 +646,7 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) {
|
||||
|
||||
XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto hlo_module = CreateNewUnverifiedModule();
|
||||
auto hlo_module = CreateNewVerifiedModule();
|
||||
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
LiteralUtil::CreateR2<int32>({{2, 3, 5}, {7, 11, 13}, {17, 19, 23}})));
|
||||
auto const1 = builder.AddInstruction(
|
||||
@ -693,7 +698,7 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) {
|
||||
// into a fusion, it should remain shared, rather than being duplicated
|
||||
// within the fusion.
|
||||
XLA_TEST_F(FusionTest, SharedConstant) {
|
||||
auto hlo_module = CreateNewUnverifiedModule();
|
||||
auto hlo_module = CreateNewVerifiedModule();
|
||||
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto const0 = builder.AddInstruction(
|
||||
|
Loading…
Reference in New Issue
Block a user