Create VerifiedModule in tests with HloText.
This makes sure that the verifier is run in addition to the parser. The parser may allow something which is not actually valid HLO. Fix the bugs in tests that were found due to this change. PiperOrigin-RevId: 230885173
This commit is contained in:
parent
8a9c9c8f8f
commit
3296f42b0f
@ -2227,9 +2227,8 @@ TEST_F(AlgebraicSimplifierTest, TransposeIsReshape) {
|
||||
ROOT reshaped_again = f32[10] reshape(f32[10,1,1] transposed)
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
auto module,
|
||||
HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()));
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
|
||||
HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
|
||||
EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
|
||||
|
@ -1855,8 +1855,7 @@ ENTRY %TokensShouldNotBeCopied () -> s32[] {
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
HloRunner::CreateModuleFromString(
|
||||
module_string, GetDebugOptionsForTest()));
|
||||
ParseAndReturnVerifiedModule(module_string));
|
||||
InsertCopies(module.get());
|
||||
|
||||
// There should be no copies added because tokens should not be copied.
|
||||
@ -2119,8 +2118,7 @@ ENTRY TestComputation {
|
||||
ROOT while.3 = (s32[], s32[], s32[], s32[], s32[]) while(arg_tuple.6), condition=cond_wrapper.v3.2, body=_functionalize_body_2__.v25
|
||||
}
|
||||
)";
|
||||
auto module_or_status =
|
||||
HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest());
|
||||
auto module_or_status = ParseAndReturnVerifiedModule(hlo_string);
|
||||
auto module = module_or_status.ConsumeValueOrDie();
|
||||
InsertCopies(module.get());
|
||||
}
|
||||
@ -2220,8 +2218,7 @@ ENTRY TestComputation {
|
||||
ROOT while.3 = (s32[], s32[], s32[], s32[], s32[]) while(arg_tuple.6), condition=cond_wrapper.v3.2, body=_functionalize_body_2__.v25
|
||||
}
|
||||
)";
|
||||
auto module_or_status =
|
||||
HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest());
|
||||
auto module_or_status = ParseAndReturnVerifiedModule(hlo_string);
|
||||
auto module = module_or_status.ConsumeValueOrDie();
|
||||
InsertCopies(module.get());
|
||||
}
|
||||
@ -2238,7 +2235,7 @@ cond.inner {
|
||||
|
||||
body.inner {
|
||||
param.body.inner = pred[] parameter(0)
|
||||
ROOT neg = pred[] negate(param.body.inner)
|
||||
ROOT not = pred[] not(param.body.inner)
|
||||
}
|
||||
|
||||
cond.outer {
|
||||
@ -2255,9 +2252,8 @@ ENTRY TestComputation {
|
||||
ROOT while = pred[] while(entry_param), condition=cond.outer, body=body.outer
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
std::unique_ptr<HloModule> module,
|
||||
HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()));
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
InsertCopies(module.get());
|
||||
|
||||
// There should only be a single copy inserted, and it's in the entry
|
||||
|
@ -73,15 +73,14 @@ ENTRY TestComputation {
|
||||
abs = f32[] abs(arg)
|
||||
add = f32[] add(arg, gte)
|
||||
broadcast = f32[42] broadcast(add), dimensions={}
|
||||
slice = f32[0] slice(broadcast), slice={[1:2]}
|
||||
slice = f32[1] slice(broadcast), slice={[1:2]}
|
||||
copy = f32[] copy(arg)
|
||||
eq = pred[] equal-to(arg, gte)
|
||||
neg = f32[] negate(arg)
|
||||
ROOT convert = f64[] convert(f32[] arg)
|
||||
})";
|
||||
std::unique_ptr<HloModule> module =
|
||||
HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest())
|
||||
.ConsumeValueOrDie();
|
||||
ParseAndReturnVerifiedModule(hlo_string).ConsumeValueOrDie();
|
||||
ElementwiseTestVisitor visitor;
|
||||
TF_EXPECT_OK(module->entry_computation()->Accept(&visitor));
|
||||
}
|
||||
|
@ -28,15 +28,7 @@ using ::testing::Eq;
|
||||
using ::testing::Not;
|
||||
using ::testing::ResultOf;
|
||||
|
||||
class HloElementTypeConverterTest : public HloTestBase {
|
||||
public:
|
||||
std::unique_ptr<HloModule> CreateModuleFromHloString(
|
||||
const string& hlo_string) {
|
||||
return HloRunner::CreateModuleFromString(hlo_string,
|
||||
GetDebugOptionsForTest())
|
||||
.ValueOrDie();
|
||||
}
|
||||
};
|
||||
using HloElementTypeConverterTest = HloTestBase;
|
||||
|
||||
TEST_F(HloElementTypeConverterTest, CustomCallsNotConverted) {
|
||||
const string& hlo_string = R"(
|
||||
@ -47,7 +39,7 @@ TEST_F(HloElementTypeConverterTest, CustomCallsNotConverted) {
|
||||
custom_call_target="foo"
|
||||
}
|
||||
)";
|
||||
auto module = CreateModuleFromHloString(hlo_string);
|
||||
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
|
||||
HloElementTypeConverter type_converter(BF16, F32);
|
||||
TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get()));
|
||||
EXPECT_FALSE(converted);
|
||||
@ -63,7 +55,7 @@ TEST_F(HloElementTypeConverterTest, InfeedsOutfeedsNotConverted) {
|
||||
outfeed = token[] outfeed(infeed.data, token0)
|
||||
}
|
||||
)";
|
||||
auto module = CreateModuleFromHloString(hlo_string);
|
||||
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
|
||||
HloElementTypeConverter type_converter(BF16, F32);
|
||||
TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get()));
|
||||
EXPECT_FALSE(converted);
|
||||
@ -73,17 +65,16 @@ TEST_F(HloElementTypeConverterTest, OperationsInNestedTuplesConverted) {
|
||||
const string& hlo_string = R"(
|
||||
HloModule NestedTuples
|
||||
ENTRY NestedTuples.v5 {
|
||||
constant.4 = bf16[] constant(42)
|
||||
constant.2 = f32[2]{0} constant({1, 2})
|
||||
constant.3 = bf16[] constant(42)
|
||||
add = bf16[] add(constant.2, constant.3)
|
||||
tuple = (f32[2]{0}, bf16[]) tuple(constant.2, add)
|
||||
constant.3 = bf16[2]{0} constant({42, 42})
|
||||
add = bf16[2]{0} add(constant.2, constant.3)
|
||||
tuple = (f32[2]{0}, bf16[2]{0}) tuple(constant.2, add)
|
||||
constant.5 = bf16[2]{0} constant({22, 44})
|
||||
ROOT tuple.1 = ((f32[2]{0}, bf16[]), bf16[2]{0}) tuple(tuple, constant.5)
|
||||
ROOT tuple.1 = ((f32[2]{0}, bf16[2]{0}), bf16[2]{0}) tuple(tuple, constant.5)
|
||||
}
|
||||
)";
|
||||
|
||||
auto module = CreateModuleFromHloString(hlo_string);
|
||||
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
|
||||
HloElementTypeConverter type_converter(BF16, F32);
|
||||
TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get()));
|
||||
EXPECT_TRUE(converted);
|
||||
@ -111,7 +102,7 @@ TEST_F(HloElementTypeConverterTest, BatchNormGradBF16Converted) {
|
||||
}
|
||||
)";
|
||||
|
||||
auto module = CreateModuleFromHloString(hlo_string);
|
||||
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
|
||||
HloElementTypeConverter type_converter(BF16, F32);
|
||||
TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get()));
|
||||
EXPECT_TRUE(converted);
|
||||
@ -135,7 +126,7 @@ ENTRY main {
|
||||
ROOT rng = bf16[1,1000,20]{2,1,0} rng(constant.3, constant.4), distribution=rng_uniform
|
||||
}
|
||||
)";
|
||||
auto module = CreateModuleFromHloString(hlo_string);
|
||||
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
|
||||
HloElementTypeConverter type_converter(BF16, F32);
|
||||
TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get()));
|
||||
EXPECT_TRUE(converted);
|
||||
@ -161,7 +152,7 @@ ENTRY main {
|
||||
ROOT rng1 = bf16[1,1000,20]{2,1,0} rng(constant.3, constant.4), control-predecessors={%rng0}, distribution=rng_uniform
|
||||
}
|
||||
)";
|
||||
auto module = CreateModuleFromHloString(hlo_string);
|
||||
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
|
||||
|
||||
HloElementTypeConverter type_converter(BF16, F32);
|
||||
TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get()));
|
||||
|
@ -224,9 +224,7 @@ XLA_TEST_F(ConstantsHloTest, DISABLED_ON_GPU(BitcastOfConstant)) {
|
||||
ROOT result = s32[] call(parameter.0, constant-as-scalar), to_apply=func
|
||||
}
|
||||
)";
|
||||
auto module =
|
||||
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
|
||||
.ValueOrDie();
|
||||
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
|
||||
auto param = LiteralUtil::CreateR0<int32>(1);
|
||||
auto result = ExecuteNoHloPasses(std::move(module), {¶m});
|
||||
EXPECT_TRUE(LiteralTestUtil::Equal(param, result));
|
||||
|
@ -208,9 +208,7 @@ XLA_TEST_F(MultiOutputFusionTest, FusionNodeIsRoot) {
|
||||
ROOT fusion = (s32[]) fusion(x), kind=kLoop, calls=fused_computation
|
||||
}
|
||||
)";
|
||||
auto module =
|
||||
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
|
||||
.ValueOrDie();
|
||||
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
|
||||
auto param = LiteralUtil::MakeTupleOwned(
|
||||
LiteralUtil::MakeTupleOwned(
|
||||
LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0<int32>(42)),
|
||||
@ -241,9 +239,7 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) {
|
||||
const = f32[4] constant({0, 0, 0, 0})
|
||||
ROOT select = f32[4] select(gte0, gte1, const)
|
||||
})";
|
||||
auto module =
|
||||
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
|
||||
.ValueOrDie();
|
||||
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
|
||||
auto param = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0, -1.0});
|
||||
Literal result = ExecuteNoHloPasses(std::move(module), {¶m});
|
||||
LiteralTestUtil::ExpectR1Equal<float>({0.0, 4.0, 9.0, 1.0}, result);
|
||||
@ -273,9 +269,7 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) {
|
||||
p1 = f32[3] parameter(0)
|
||||
ROOT map = f32[3] map(p1), to_apply=map_computation
|
||||
})";
|
||||
auto module =
|
||||
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
|
||||
.ValueOrDie();
|
||||
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
|
||||
auto param = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0});
|
||||
Literal result = ExecuteNoHloPasses(std::move(module), {¶m});
|
||||
LiteralTestUtil::ExpectR1Equal<float>({0.0, 4.0, 9.0}, result);
|
||||
@ -315,9 +309,7 @@ XLA_TEST_F(MultiOutputFusionTest,
|
||||
ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p), kind=kInput,
|
||||
calls=fused_reduce
|
||||
})");
|
||||
auto module =
|
||||
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
|
||||
.ValueOrDie();
|
||||
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
|
||||
auto param =
|
||||
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
|
||||
Literal result = ExecuteNoHloPasses(std::move(module), {¶m});
|
||||
@ -346,9 +338,7 @@ XLA_TEST_F(MultiOutputFusionTest,
|
||||
ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p), kind=kInput,
|
||||
calls=fused_reduce
|
||||
})");
|
||||
auto module =
|
||||
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
|
||||
.ValueOrDie();
|
||||
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
|
||||
auto param =
|
||||
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
|
||||
Literal result = ExecuteNoHloPasses(std::move(module), {¶m});
|
||||
@ -378,9 +368,7 @@ XLA_TEST_F(MultiOutputFusionTest,
|
||||
ROOT fusion = (f32[2]{0}, f32[2]{0}, f32[2]{0}) fusion(p), kind=kInput,
|
||||
calls=fused_reduce
|
||||
})");
|
||||
auto module =
|
||||
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
|
||||
.ValueOrDie();
|
||||
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
|
||||
auto param =
|
||||
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
|
||||
Literal result = ExecuteNoHloPasses(std::move(module), {¶m});
|
||||
@ -410,9 +398,7 @@ XLA_TEST_F(MultiOutputFusionTest,
|
||||
ROOT fusion = (f32[2,2,2]{2,1,0}, f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p),
|
||||
kind=kInput, calls=fused_reduce
|
||||
})");
|
||||
auto module =
|
||||
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
|
||||
.ValueOrDie();
|
||||
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
|
||||
auto param =
|
||||
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
|
||||
Literal result = ExecuteNoHloPasses(std::move(module), {¶m});
|
||||
@ -443,9 +429,7 @@ XLA_TEST_F(MultiOutputFusionTest,
|
||||
ROOT fusion = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}, f32[2,2]{1,0}) fusion(p),
|
||||
kind=kInput, calls=fused_reduce
|
||||
})");
|
||||
auto module =
|
||||
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
|
||||
.ValueOrDie();
|
||||
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
|
||||
auto param =
|
||||
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
|
||||
Literal result = ExecuteNoHloPasses(std::move(module), {¶m});
|
||||
@ -478,9 +462,7 @@ XLA_TEST_F(MultiOutputFusionTest,
|
||||
ROOT fusion = (f32[2]{0}, f32[2,2,2]{2,1,0}, f32[2,2,2]{2,1,0}) fusion(p),
|
||||
kind=kInput, calls=fused_reduce
|
||||
})");
|
||||
auto module =
|
||||
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
|
||||
.ValueOrDie();
|
||||
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
|
||||
auto param =
|
||||
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
|
||||
Literal result = ExecuteNoHloPasses(std::move(module), {¶m});
|
||||
@ -513,9 +495,7 @@ XLA_TEST_F(MultiOutputFusionTest,
|
||||
ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p, i, j), kind=kInput,
|
||||
calls=fused_reduce
|
||||
})");
|
||||
auto module =
|
||||
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
|
||||
.ValueOrDie();
|
||||
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
|
||||
auto param =
|
||||
LiteralUtil::CreateR3<float>({{{0, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
|
||||
auto init1 = LiteralUtil::CreateR0<float>(5);
|
||||
@ -549,9 +529,7 @@ XLA_TEST_F(MultiOutputFusionTest,
|
||||
ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}, f16[2,2,2]{2,1,0}) fusion(p),
|
||||
kind=kInput, calls=fused_reduce
|
||||
})");
|
||||
auto module =
|
||||
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
|
||||
.ValueOrDie();
|
||||
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
|
||||
auto param = LiteralUtil::CreateR3<Eigen::half>(
|
||||
{{{Eigen::half(1), Eigen::half(2)}, {Eigen::half(3), Eigen::half(4)}},
|
||||
{{Eigen::half(5), Eigen::half(6)}, {Eigen::half(7), Eigen::half(8)}}});
|
||||
|
@ -525,9 +525,7 @@ XLA_TEST_F(TupleHloTest, BitcastAfterGTE) {
|
||||
ROOT tuple.4 = (f32[1,3]{1,0}) tuple(copy)
|
||||
}
|
||||
)";
|
||||
auto module =
|
||||
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
|
||||
.ValueOrDie();
|
||||
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
|
||||
auto param =
|
||||
LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1<float>({1, 2, 3}));
|
||||
auto result = ExecuteNoHloPasses(std::move(module), {¶m});
|
||||
@ -559,9 +557,7 @@ XLA_TEST_F(TupleHloTest,
|
||||
ROOT outfeed = token[] outfeed(tuple, token0)
|
||||
}
|
||||
)";
|
||||
auto module =
|
||||
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
|
||||
.ValueOrDie();
|
||||
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
|
||||
auto param0 = LiteralUtil::CreateR1<float>({1, 2});
|
||||
auto param1 = LiteralUtil::CreateR1<float>({2, 3});
|
||||
auto param4 = LiteralUtil::CreateR0<bool>(false);
|
||||
|
@ -36,9 +36,8 @@ ENTRY %entry {
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
std::unique_ptr<HloModule> hlo_module,
|
||||
HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()));
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
|
||||
{
|
||||
auto extracted_module =
|
||||
@ -75,9 +74,8 @@ ENTRY %entry {
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
std::unique_ptr<HloModule> hlo_module,
|
||||
HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()));
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
|
||||
{
|
||||
auto extracted_module =
|
||||
@ -120,9 +118,8 @@ ENTRY %entry {
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
std::unique_ptr<HloModule> hlo_module,
|
||||
HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()));
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
|
||||
{
|
||||
auto extracted_module =
|
||||
|
Loading…
Reference in New Issue
Block a user