Create VerifiedModule in tests with HloText.

This makes sure that the verifier is run in addition to the parser.
The parser may allow something which is not actually valid HLO.
Fix the bugs in tests that were found due to this change.

PiperOrigin-RevId: 230885173
This commit is contained in:
Adrian Kuegel 2019-01-25 05:03:17 -08:00 committed by TensorFlower Gardener
parent 8a9c9c8f8f
commit 3296f42b0f
8 changed files with 41 additions and 87 deletions

View File

@ -2227,9 +2227,8 @@ TEST_F(AlgebraicSimplifierTest, TransposeIsReshape) {
ROOT reshaped_again = f32[10] reshape(f32[10,1,1] transposed)
}
)";
TF_ASSERT_OK_AND_ASSIGN(
auto module,
HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()));
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());

View File

@ -1855,8 +1855,7 @@ ENTRY %TokensShouldNotBeCopied () -> s32[] {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
HloRunner::CreateModuleFromString(
module_string, GetDebugOptionsForTest()));
ParseAndReturnVerifiedModule(module_string));
InsertCopies(module.get());
// There should be no copies added because tokens should not be copied.
@ -2119,8 +2118,7 @@ ENTRY TestComputation {
ROOT while.3 = (s32[], s32[], s32[], s32[], s32[]) while(arg_tuple.6), condition=cond_wrapper.v3.2, body=_functionalize_body_2__.v25
}
)";
auto module_or_status =
HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest());
auto module_or_status = ParseAndReturnVerifiedModule(hlo_string);
auto module = module_or_status.ConsumeValueOrDie();
InsertCopies(module.get());
}
@ -2220,8 +2218,7 @@ ENTRY TestComputation {
ROOT while.3 = (s32[], s32[], s32[], s32[], s32[]) while(arg_tuple.6), condition=cond_wrapper.v3.2, body=_functionalize_body_2__.v25
}
)";
auto module_or_status =
HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest());
auto module_or_status = ParseAndReturnVerifiedModule(hlo_string);
auto module = module_or_status.ConsumeValueOrDie();
InsertCopies(module.get());
}
@ -2238,7 +2235,7 @@ cond.inner {
body.inner {
param.body.inner = pred[] parameter(0)
ROOT neg = pred[] negate(param.body.inner)
ROOT not = pred[] not(param.body.inner)
}
cond.outer {
@ -2255,9 +2252,8 @@ ENTRY TestComputation {
ROOT while = pred[] while(entry_param), condition=cond.outer, body=body.outer
}
)";
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> module,
HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_string));
InsertCopies(module.get());
// There should only be a single copy inserted, and it's in the entry

View File

@ -73,15 +73,14 @@ ENTRY TestComputation {
abs = f32[] abs(arg)
add = f32[] add(arg, gte)
broadcast = f32[42] broadcast(add), dimensions={}
slice = f32[0] slice(broadcast), slice={[1:2]}
slice = f32[1] slice(broadcast), slice={[1:2]}
copy = f32[] copy(arg)
eq = pred[] equal-to(arg, gte)
neg = f32[] negate(arg)
ROOT convert = f64[] convert(f32[] arg)
})";
std::unique_ptr<HloModule> module =
HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest())
.ConsumeValueOrDie();
ParseAndReturnVerifiedModule(hlo_string).ConsumeValueOrDie();
ElementwiseTestVisitor visitor;
TF_EXPECT_OK(module->entry_computation()->Accept(&visitor));
}

View File

@ -28,15 +28,7 @@ using ::testing::Eq;
using ::testing::Not;
using ::testing::ResultOf;
class HloElementTypeConverterTest : public HloTestBase {
public:
std::unique_ptr<HloModule> CreateModuleFromHloString(
const string& hlo_string) {
return HloRunner::CreateModuleFromString(hlo_string,
GetDebugOptionsForTest())
.ValueOrDie();
}
};
using HloElementTypeConverterTest = HloTestBase;
TEST_F(HloElementTypeConverterTest, CustomCallsNotConverted) {
const string& hlo_string = R"(
@ -47,7 +39,7 @@ TEST_F(HloElementTypeConverterTest, CustomCallsNotConverted) {
custom_call_target="foo"
}
)";
auto module = CreateModuleFromHloString(hlo_string);
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
HloElementTypeConverter type_converter(BF16, F32);
TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get()));
EXPECT_FALSE(converted);
@ -63,7 +55,7 @@ TEST_F(HloElementTypeConverterTest, InfeedsOutfeedsNotConverted) {
outfeed = token[] outfeed(infeed.data, token0)
}
)";
auto module = CreateModuleFromHloString(hlo_string);
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
HloElementTypeConverter type_converter(BF16, F32);
TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get()));
EXPECT_FALSE(converted);
@ -73,17 +65,16 @@ TEST_F(HloElementTypeConverterTest, OperationsInNestedTuplesConverted) {
const string& hlo_string = R"(
HloModule NestedTuples
ENTRY NestedTuples.v5 {
constant.4 = bf16[] constant(42)
constant.2 = f32[2]{0} constant({1, 2})
constant.3 = bf16[] constant(42)
add = bf16[] add(constant.2, constant.3)
tuple = (f32[2]{0}, bf16[]) tuple(constant.2, add)
constant.3 = bf16[2]{0} constant({42, 42})
add = bf16[2]{0} add(constant.2, constant.3)
tuple = (f32[2]{0}, bf16[2]{0}) tuple(constant.2, add)
constant.5 = bf16[2]{0} constant({22, 44})
ROOT tuple.1 = ((f32[2]{0}, bf16[]), bf16[2]{0}) tuple(tuple, constant.5)
ROOT tuple.1 = ((f32[2]{0}, bf16[2]{0}), bf16[2]{0}) tuple(tuple, constant.5)
}
)";
auto module = CreateModuleFromHloString(hlo_string);
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
HloElementTypeConverter type_converter(BF16, F32);
TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get()));
EXPECT_TRUE(converted);
@ -111,7 +102,7 @@ TEST_F(HloElementTypeConverterTest, BatchNormGradBF16Converted) {
}
)";
auto module = CreateModuleFromHloString(hlo_string);
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
HloElementTypeConverter type_converter(BF16, F32);
TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get()));
EXPECT_TRUE(converted);
@ -135,7 +126,7 @@ ENTRY main {
ROOT rng = bf16[1,1000,20]{2,1,0} rng(constant.3, constant.4), distribution=rng_uniform
}
)";
auto module = CreateModuleFromHloString(hlo_string);
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
HloElementTypeConverter type_converter(BF16, F32);
TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get()));
EXPECT_TRUE(converted);
@ -161,7 +152,7 @@ ENTRY main {
ROOT rng1 = bf16[1,1000,20]{2,1,0} rng(constant.3, constant.4), control-predecessors={%rng0}, distribution=rng_uniform
}
)";
auto module = CreateModuleFromHloString(hlo_string);
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
HloElementTypeConverter type_converter(BF16, F32);
TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get()));

View File

@ -224,9 +224,7 @@ XLA_TEST_F(ConstantsHloTest, DISABLED_ON_GPU(BitcastOfConstant)) {
ROOT result = s32[] call(parameter.0, constant-as-scalar), to_apply=func
}
)";
auto module =
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
.ValueOrDie();
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
auto param = LiteralUtil::CreateR0<int32>(1);
auto result = ExecuteNoHloPasses(std::move(module), {&param});
EXPECT_TRUE(LiteralTestUtil::Equal(param, result));

View File

@ -208,9 +208,7 @@ XLA_TEST_F(MultiOutputFusionTest, FusionNodeIsRoot) {
ROOT fusion = (s32[]) fusion(x), kind=kLoop, calls=fused_computation
}
)";
auto module =
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
.ValueOrDie();
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
auto param = LiteralUtil::MakeTupleOwned(
LiteralUtil::MakeTupleOwned(
LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0<int32>(42)),
@ -241,9 +239,7 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) {
const = f32[4] constant({0, 0, 0, 0})
ROOT select = f32[4] select(gte0, gte1, const)
})";
auto module =
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
.ValueOrDie();
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
auto param = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0, -1.0});
Literal result = ExecuteNoHloPasses(std::move(module), {&param});
LiteralTestUtil::ExpectR1Equal<float>({0.0, 4.0, 9.0, 1.0}, result);
@ -273,9 +269,7 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) {
p1 = f32[3] parameter(0)
ROOT map = f32[3] map(p1), to_apply=map_computation
})";
auto module =
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
.ValueOrDie();
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
auto param = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0});
Literal result = ExecuteNoHloPasses(std::move(module), {&param});
LiteralTestUtil::ExpectR1Equal<float>({0.0, 4.0, 9.0}, result);
@ -315,9 +309,7 @@ XLA_TEST_F(MultiOutputFusionTest,
ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p), kind=kInput,
calls=fused_reduce
})");
auto module =
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
.ValueOrDie();
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
auto param =
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
Literal result = ExecuteNoHloPasses(std::move(module), {&param});
@ -346,9 +338,7 @@ XLA_TEST_F(MultiOutputFusionTest,
ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p), kind=kInput,
calls=fused_reduce
})");
auto module =
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
.ValueOrDie();
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
auto param =
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
Literal result = ExecuteNoHloPasses(std::move(module), {&param});
@ -378,9 +368,7 @@ XLA_TEST_F(MultiOutputFusionTest,
ROOT fusion = (f32[2]{0}, f32[2]{0}, f32[2]{0}) fusion(p), kind=kInput,
calls=fused_reduce
})");
auto module =
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
.ValueOrDie();
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
auto param =
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
Literal result = ExecuteNoHloPasses(std::move(module), {&param});
@ -410,9 +398,7 @@ XLA_TEST_F(MultiOutputFusionTest,
ROOT fusion = (f32[2,2,2]{2,1,0}, f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p),
kind=kInput, calls=fused_reduce
})");
auto module =
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
.ValueOrDie();
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
auto param =
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
Literal result = ExecuteNoHloPasses(std::move(module), {&param});
@ -443,9 +429,7 @@ XLA_TEST_F(MultiOutputFusionTest,
ROOT fusion = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}, f32[2,2]{1,0}) fusion(p),
kind=kInput, calls=fused_reduce
})");
auto module =
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
.ValueOrDie();
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
auto param =
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
Literal result = ExecuteNoHloPasses(std::move(module), {&param});
@ -478,9 +462,7 @@ XLA_TEST_F(MultiOutputFusionTest,
ROOT fusion = (f32[2]{0}, f32[2,2,2]{2,1,0}, f32[2,2,2]{2,1,0}) fusion(p),
kind=kInput, calls=fused_reduce
})");
auto module =
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
.ValueOrDie();
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
auto param =
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
Literal result = ExecuteNoHloPasses(std::move(module), {&param});
@ -513,9 +495,7 @@ XLA_TEST_F(MultiOutputFusionTest,
ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p, i, j), kind=kInput,
calls=fused_reduce
})");
auto module =
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
.ValueOrDie();
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
auto param =
LiteralUtil::CreateR3<float>({{{0, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
auto init1 = LiteralUtil::CreateR0<float>(5);
@ -549,9 +529,7 @@ XLA_TEST_F(MultiOutputFusionTest,
ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}, f16[2,2,2]{2,1,0}) fusion(p),
kind=kInput, calls=fused_reduce
})");
auto module =
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
.ValueOrDie();
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
auto param = LiteralUtil::CreateR3<Eigen::half>(
{{{Eigen::half(1), Eigen::half(2)}, {Eigen::half(3), Eigen::half(4)}},
{{Eigen::half(5), Eigen::half(6)}, {Eigen::half(7), Eigen::half(8)}}});

View File

@ -525,9 +525,7 @@ XLA_TEST_F(TupleHloTest, BitcastAfterGTE) {
ROOT tuple.4 = (f32[1,3]{1,0}) tuple(copy)
}
)";
auto module =
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
.ValueOrDie();
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
auto param =
LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1<float>({1, 2, 3}));
auto result = ExecuteNoHloPasses(std::move(module), {&param});
@ -559,9 +557,7 @@ XLA_TEST_F(TupleHloTest,
ROOT outfeed = token[] outfeed(tuple, token0)
}
)";
auto module =
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
.ValueOrDie();
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
auto param0 = LiteralUtil::CreateR1<float>({1, 2});
auto param1 = LiteralUtil::CreateR1<float>({2, 3});
auto param4 = LiteralUtil::CreateR0<bool>(false);

View File

@ -36,9 +36,8 @@ ENTRY %entry {
}
)";
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> hlo_module,
HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
ParseAndReturnVerifiedModule(hlo_string));
{
auto extracted_module =
@ -75,9 +74,8 @@ ENTRY %entry {
}
)";
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> hlo_module,
HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
ParseAndReturnVerifiedModule(hlo_string));
{
auto extracted_module =
@ -120,9 +118,8 @@ ENTRY %entry {
}
)";
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> hlo_module,
HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
ParseAndReturnVerifiedModule(hlo_string));
{
auto extracted_module =