Create VerifiedModule in tests with HloText.
This makes sure that the verifier is run in addition to the parser. The parser may allow something which is not actually valid HLO. Fix the bugs in tests that were found due to this change. PiperOrigin-RevId: 230885173
This commit is contained in:
parent
8a9c9c8f8f
commit
3296f42b0f
@ -2227,9 +2227,8 @@ TEST_F(AlgebraicSimplifierTest, TransposeIsReshape) {
|
|||||||
ROOT reshaped_again = f32[10] reshape(f32[10,1,1] transposed)
|
ROOT reshaped_again = f32[10] reshape(f32[10,1,1] transposed)
|
||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
TF_ASSERT_OK_AND_ASSIGN(
|
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||||
auto module,
|
ParseAndReturnVerifiedModule(hlo_string));
|
||||||
HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()));
|
|
||||||
|
|
||||||
HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
|
HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
|
||||||
EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
|
EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
|
||||||
|
@ -1855,8 +1855,7 @@ ENTRY %TokensShouldNotBeCopied () -> s32[] {
|
|||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||||
HloRunner::CreateModuleFromString(
|
ParseAndReturnVerifiedModule(module_string));
|
||||||
module_string, GetDebugOptionsForTest()));
|
|
||||||
InsertCopies(module.get());
|
InsertCopies(module.get());
|
||||||
|
|
||||||
// There should be no copies added because tokens should not be copied.
|
// There should be no copies added because tokens should not be copied.
|
||||||
@ -2119,8 +2118,7 @@ ENTRY TestComputation {
|
|||||||
ROOT while.3 = (s32[], s32[], s32[], s32[], s32[]) while(arg_tuple.6), condition=cond_wrapper.v3.2, body=_functionalize_body_2__.v25
|
ROOT while.3 = (s32[], s32[], s32[], s32[], s32[]) while(arg_tuple.6), condition=cond_wrapper.v3.2, body=_functionalize_body_2__.v25
|
||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
auto module_or_status =
|
auto module_or_status = ParseAndReturnVerifiedModule(hlo_string);
|
||||||
HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest());
|
|
||||||
auto module = module_or_status.ConsumeValueOrDie();
|
auto module = module_or_status.ConsumeValueOrDie();
|
||||||
InsertCopies(module.get());
|
InsertCopies(module.get());
|
||||||
}
|
}
|
||||||
@ -2220,8 +2218,7 @@ ENTRY TestComputation {
|
|||||||
ROOT while.3 = (s32[], s32[], s32[], s32[], s32[]) while(arg_tuple.6), condition=cond_wrapper.v3.2, body=_functionalize_body_2__.v25
|
ROOT while.3 = (s32[], s32[], s32[], s32[], s32[]) while(arg_tuple.6), condition=cond_wrapper.v3.2, body=_functionalize_body_2__.v25
|
||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
auto module_or_status =
|
auto module_or_status = ParseAndReturnVerifiedModule(hlo_string);
|
||||||
HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest());
|
|
||||||
auto module = module_or_status.ConsumeValueOrDie();
|
auto module = module_or_status.ConsumeValueOrDie();
|
||||||
InsertCopies(module.get());
|
InsertCopies(module.get());
|
||||||
}
|
}
|
||||||
@ -2238,7 +2235,7 @@ cond.inner {
|
|||||||
|
|
||||||
body.inner {
|
body.inner {
|
||||||
param.body.inner = pred[] parameter(0)
|
param.body.inner = pred[] parameter(0)
|
||||||
ROOT neg = pred[] negate(param.body.inner)
|
ROOT not = pred[] not(param.body.inner)
|
||||||
}
|
}
|
||||||
|
|
||||||
cond.outer {
|
cond.outer {
|
||||||
@ -2255,9 +2252,8 @@ ENTRY TestComputation {
|
|||||||
ROOT while = pred[] while(entry_param), condition=cond.outer, body=body.outer
|
ROOT while = pred[] while(entry_param), condition=cond.outer, body=body.outer
|
||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
TF_ASSERT_OK_AND_ASSIGN(
|
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||||
std::unique_ptr<HloModule> module,
|
ParseAndReturnVerifiedModule(hlo_string));
|
||||||
HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()));
|
|
||||||
InsertCopies(module.get());
|
InsertCopies(module.get());
|
||||||
|
|
||||||
// There should only be a single copy inserted, and it's in the entry
|
// There should only be a single copy inserted, and it's in the entry
|
||||||
|
@ -73,15 +73,14 @@ ENTRY TestComputation {
|
|||||||
abs = f32[] abs(arg)
|
abs = f32[] abs(arg)
|
||||||
add = f32[] add(arg, gte)
|
add = f32[] add(arg, gte)
|
||||||
broadcast = f32[42] broadcast(add), dimensions={}
|
broadcast = f32[42] broadcast(add), dimensions={}
|
||||||
slice = f32[0] slice(broadcast), slice={[1:2]}
|
slice = f32[1] slice(broadcast), slice={[1:2]}
|
||||||
copy = f32[] copy(arg)
|
copy = f32[] copy(arg)
|
||||||
eq = pred[] equal-to(arg, gte)
|
eq = pred[] equal-to(arg, gte)
|
||||||
neg = f32[] negate(arg)
|
neg = f32[] negate(arg)
|
||||||
ROOT convert = f64[] convert(f32[] arg)
|
ROOT convert = f64[] convert(f32[] arg)
|
||||||
})";
|
})";
|
||||||
std::unique_ptr<HloModule> module =
|
std::unique_ptr<HloModule> module =
|
||||||
HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest())
|
ParseAndReturnVerifiedModule(hlo_string).ConsumeValueOrDie();
|
||||||
.ConsumeValueOrDie();
|
|
||||||
ElementwiseTestVisitor visitor;
|
ElementwiseTestVisitor visitor;
|
||||||
TF_EXPECT_OK(module->entry_computation()->Accept(&visitor));
|
TF_EXPECT_OK(module->entry_computation()->Accept(&visitor));
|
||||||
}
|
}
|
||||||
|
@ -28,15 +28,7 @@ using ::testing::Eq;
|
|||||||
using ::testing::Not;
|
using ::testing::Not;
|
||||||
using ::testing::ResultOf;
|
using ::testing::ResultOf;
|
||||||
|
|
||||||
class HloElementTypeConverterTest : public HloTestBase {
|
using HloElementTypeConverterTest = HloTestBase;
|
||||||
public:
|
|
||||||
std::unique_ptr<HloModule> CreateModuleFromHloString(
|
|
||||||
const string& hlo_string) {
|
|
||||||
return HloRunner::CreateModuleFromString(hlo_string,
|
|
||||||
GetDebugOptionsForTest())
|
|
||||||
.ValueOrDie();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
TEST_F(HloElementTypeConverterTest, CustomCallsNotConverted) {
|
TEST_F(HloElementTypeConverterTest, CustomCallsNotConverted) {
|
||||||
const string& hlo_string = R"(
|
const string& hlo_string = R"(
|
||||||
@ -47,7 +39,7 @@ TEST_F(HloElementTypeConverterTest, CustomCallsNotConverted) {
|
|||||||
custom_call_target="foo"
|
custom_call_target="foo"
|
||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
auto module = CreateModuleFromHloString(hlo_string);
|
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
|
||||||
HloElementTypeConverter type_converter(BF16, F32);
|
HloElementTypeConverter type_converter(BF16, F32);
|
||||||
TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get()));
|
TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get()));
|
||||||
EXPECT_FALSE(converted);
|
EXPECT_FALSE(converted);
|
||||||
@ -63,7 +55,7 @@ TEST_F(HloElementTypeConverterTest, InfeedsOutfeedsNotConverted) {
|
|||||||
outfeed = token[] outfeed(infeed.data, token0)
|
outfeed = token[] outfeed(infeed.data, token0)
|
||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
auto module = CreateModuleFromHloString(hlo_string);
|
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
|
||||||
HloElementTypeConverter type_converter(BF16, F32);
|
HloElementTypeConverter type_converter(BF16, F32);
|
||||||
TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get()));
|
TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get()));
|
||||||
EXPECT_FALSE(converted);
|
EXPECT_FALSE(converted);
|
||||||
@ -73,17 +65,16 @@ TEST_F(HloElementTypeConverterTest, OperationsInNestedTuplesConverted) {
|
|||||||
const string& hlo_string = R"(
|
const string& hlo_string = R"(
|
||||||
HloModule NestedTuples
|
HloModule NestedTuples
|
||||||
ENTRY NestedTuples.v5 {
|
ENTRY NestedTuples.v5 {
|
||||||
constant.4 = bf16[] constant(42)
|
|
||||||
constant.2 = f32[2]{0} constant({1, 2})
|
constant.2 = f32[2]{0} constant({1, 2})
|
||||||
constant.3 = bf16[] constant(42)
|
constant.3 = bf16[2]{0} constant({42, 42})
|
||||||
add = bf16[] add(constant.2, constant.3)
|
add = bf16[2]{0} add(constant.2, constant.3)
|
||||||
tuple = (f32[2]{0}, bf16[]) tuple(constant.2, add)
|
tuple = (f32[2]{0}, bf16[2]{0}) tuple(constant.2, add)
|
||||||
constant.5 = bf16[2]{0} constant({22, 44})
|
constant.5 = bf16[2]{0} constant({22, 44})
|
||||||
ROOT tuple.1 = ((f32[2]{0}, bf16[]), bf16[2]{0}) tuple(tuple, constant.5)
|
ROOT tuple.1 = ((f32[2]{0}, bf16[2]{0}), bf16[2]{0}) tuple(tuple, constant.5)
|
||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
|
|
||||||
auto module = CreateModuleFromHloString(hlo_string);
|
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
|
||||||
HloElementTypeConverter type_converter(BF16, F32);
|
HloElementTypeConverter type_converter(BF16, F32);
|
||||||
TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get()));
|
TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get()));
|
||||||
EXPECT_TRUE(converted);
|
EXPECT_TRUE(converted);
|
||||||
@ -111,7 +102,7 @@ TEST_F(HloElementTypeConverterTest, BatchNormGradBF16Converted) {
|
|||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
|
|
||||||
auto module = CreateModuleFromHloString(hlo_string);
|
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
|
||||||
HloElementTypeConverter type_converter(BF16, F32);
|
HloElementTypeConverter type_converter(BF16, F32);
|
||||||
TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get()));
|
TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get()));
|
||||||
EXPECT_TRUE(converted);
|
EXPECT_TRUE(converted);
|
||||||
@ -135,7 +126,7 @@ ENTRY main {
|
|||||||
ROOT rng = bf16[1,1000,20]{2,1,0} rng(constant.3, constant.4), distribution=rng_uniform
|
ROOT rng = bf16[1,1000,20]{2,1,0} rng(constant.3, constant.4), distribution=rng_uniform
|
||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
auto module = CreateModuleFromHloString(hlo_string);
|
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
|
||||||
HloElementTypeConverter type_converter(BF16, F32);
|
HloElementTypeConverter type_converter(BF16, F32);
|
||||||
TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get()));
|
TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get()));
|
||||||
EXPECT_TRUE(converted);
|
EXPECT_TRUE(converted);
|
||||||
@ -161,7 +152,7 @@ ENTRY main {
|
|||||||
ROOT rng1 = bf16[1,1000,20]{2,1,0} rng(constant.3, constant.4), control-predecessors={%rng0}, distribution=rng_uniform
|
ROOT rng1 = bf16[1,1000,20]{2,1,0} rng(constant.3, constant.4), control-predecessors={%rng0}, distribution=rng_uniform
|
||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
auto module = CreateModuleFromHloString(hlo_string);
|
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
|
||||||
|
|
||||||
HloElementTypeConverter type_converter(BF16, F32);
|
HloElementTypeConverter type_converter(BF16, F32);
|
||||||
TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get()));
|
TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get()));
|
||||||
|
@ -224,9 +224,7 @@ XLA_TEST_F(ConstantsHloTest, DISABLED_ON_GPU(BitcastOfConstant)) {
|
|||||||
ROOT result = s32[] call(parameter.0, constant-as-scalar), to_apply=func
|
ROOT result = s32[] call(parameter.0, constant-as-scalar), to_apply=func
|
||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
auto module =
|
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
|
||||||
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
|
|
||||||
.ValueOrDie();
|
|
||||||
auto param = LiteralUtil::CreateR0<int32>(1);
|
auto param = LiteralUtil::CreateR0<int32>(1);
|
||||||
auto result = ExecuteNoHloPasses(std::move(module), {¶m});
|
auto result = ExecuteNoHloPasses(std::move(module), {¶m});
|
||||||
EXPECT_TRUE(LiteralTestUtil::Equal(param, result));
|
EXPECT_TRUE(LiteralTestUtil::Equal(param, result));
|
||||||
|
@ -208,9 +208,7 @@ XLA_TEST_F(MultiOutputFusionTest, FusionNodeIsRoot) {
|
|||||||
ROOT fusion = (s32[]) fusion(x), kind=kLoop, calls=fused_computation
|
ROOT fusion = (s32[]) fusion(x), kind=kLoop, calls=fused_computation
|
||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
auto module =
|
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
|
||||||
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
|
|
||||||
.ValueOrDie();
|
|
||||||
auto param = LiteralUtil::MakeTupleOwned(
|
auto param = LiteralUtil::MakeTupleOwned(
|
||||||
LiteralUtil::MakeTupleOwned(
|
LiteralUtil::MakeTupleOwned(
|
||||||
LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0<int32>(42)),
|
LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0<int32>(42)),
|
||||||
@ -241,9 +239,7 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) {
|
|||||||
const = f32[4] constant({0, 0, 0, 0})
|
const = f32[4] constant({0, 0, 0, 0})
|
||||||
ROOT select = f32[4] select(gte0, gte1, const)
|
ROOT select = f32[4] select(gte0, gte1, const)
|
||||||
})";
|
})";
|
||||||
auto module =
|
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
|
||||||
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
|
|
||||||
.ValueOrDie();
|
|
||||||
auto param = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0, -1.0});
|
auto param = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0, -1.0});
|
||||||
Literal result = ExecuteNoHloPasses(std::move(module), {¶m});
|
Literal result = ExecuteNoHloPasses(std::move(module), {¶m});
|
||||||
LiteralTestUtil::ExpectR1Equal<float>({0.0, 4.0, 9.0, 1.0}, result);
|
LiteralTestUtil::ExpectR1Equal<float>({0.0, 4.0, 9.0, 1.0}, result);
|
||||||
@ -273,9 +269,7 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) {
|
|||||||
p1 = f32[3] parameter(0)
|
p1 = f32[3] parameter(0)
|
||||||
ROOT map = f32[3] map(p1), to_apply=map_computation
|
ROOT map = f32[3] map(p1), to_apply=map_computation
|
||||||
})";
|
})";
|
||||||
auto module =
|
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
|
||||||
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
|
|
||||||
.ValueOrDie();
|
|
||||||
auto param = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0});
|
auto param = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0});
|
||||||
Literal result = ExecuteNoHloPasses(std::move(module), {¶m});
|
Literal result = ExecuteNoHloPasses(std::move(module), {¶m});
|
||||||
LiteralTestUtil::ExpectR1Equal<float>({0.0, 4.0, 9.0}, result);
|
LiteralTestUtil::ExpectR1Equal<float>({0.0, 4.0, 9.0}, result);
|
||||||
@ -315,9 +309,7 @@ XLA_TEST_F(MultiOutputFusionTest,
|
|||||||
ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p), kind=kInput,
|
ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p), kind=kInput,
|
||||||
calls=fused_reduce
|
calls=fused_reduce
|
||||||
})");
|
})");
|
||||||
auto module =
|
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
|
||||||
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
|
|
||||||
.ValueOrDie();
|
|
||||||
auto param =
|
auto param =
|
||||||
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
|
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
|
||||||
Literal result = ExecuteNoHloPasses(std::move(module), {¶m});
|
Literal result = ExecuteNoHloPasses(std::move(module), {¶m});
|
||||||
@ -346,9 +338,7 @@ XLA_TEST_F(MultiOutputFusionTest,
|
|||||||
ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p), kind=kInput,
|
ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p), kind=kInput,
|
||||||
calls=fused_reduce
|
calls=fused_reduce
|
||||||
})");
|
})");
|
||||||
auto module =
|
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
|
||||||
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
|
|
||||||
.ValueOrDie();
|
|
||||||
auto param =
|
auto param =
|
||||||
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
|
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
|
||||||
Literal result = ExecuteNoHloPasses(std::move(module), {¶m});
|
Literal result = ExecuteNoHloPasses(std::move(module), {¶m});
|
||||||
@ -378,9 +368,7 @@ XLA_TEST_F(MultiOutputFusionTest,
|
|||||||
ROOT fusion = (f32[2]{0}, f32[2]{0}, f32[2]{0}) fusion(p), kind=kInput,
|
ROOT fusion = (f32[2]{0}, f32[2]{0}, f32[2]{0}) fusion(p), kind=kInput,
|
||||||
calls=fused_reduce
|
calls=fused_reduce
|
||||||
})");
|
})");
|
||||||
auto module =
|
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
|
||||||
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
|
|
||||||
.ValueOrDie();
|
|
||||||
auto param =
|
auto param =
|
||||||
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
|
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
|
||||||
Literal result = ExecuteNoHloPasses(std::move(module), {¶m});
|
Literal result = ExecuteNoHloPasses(std::move(module), {¶m});
|
||||||
@ -410,9 +398,7 @@ XLA_TEST_F(MultiOutputFusionTest,
|
|||||||
ROOT fusion = (f32[2,2,2]{2,1,0}, f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p),
|
ROOT fusion = (f32[2,2,2]{2,1,0}, f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p),
|
||||||
kind=kInput, calls=fused_reduce
|
kind=kInput, calls=fused_reduce
|
||||||
})");
|
})");
|
||||||
auto module =
|
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
|
||||||
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
|
|
||||||
.ValueOrDie();
|
|
||||||
auto param =
|
auto param =
|
||||||
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
|
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
|
||||||
Literal result = ExecuteNoHloPasses(std::move(module), {¶m});
|
Literal result = ExecuteNoHloPasses(std::move(module), {¶m});
|
||||||
@ -443,9 +429,7 @@ XLA_TEST_F(MultiOutputFusionTest,
|
|||||||
ROOT fusion = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}, f32[2,2]{1,0}) fusion(p),
|
ROOT fusion = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}, f32[2,2]{1,0}) fusion(p),
|
||||||
kind=kInput, calls=fused_reduce
|
kind=kInput, calls=fused_reduce
|
||||||
})");
|
})");
|
||||||
auto module =
|
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
|
||||||
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
|
|
||||||
.ValueOrDie();
|
|
||||||
auto param =
|
auto param =
|
||||||
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
|
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
|
||||||
Literal result = ExecuteNoHloPasses(std::move(module), {¶m});
|
Literal result = ExecuteNoHloPasses(std::move(module), {¶m});
|
||||||
@ -478,9 +462,7 @@ XLA_TEST_F(MultiOutputFusionTest,
|
|||||||
ROOT fusion = (f32[2]{0}, f32[2,2,2]{2,1,0}, f32[2,2,2]{2,1,0}) fusion(p),
|
ROOT fusion = (f32[2]{0}, f32[2,2,2]{2,1,0}, f32[2,2,2]{2,1,0}) fusion(p),
|
||||||
kind=kInput, calls=fused_reduce
|
kind=kInput, calls=fused_reduce
|
||||||
})");
|
})");
|
||||||
auto module =
|
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
|
||||||
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
|
|
||||||
.ValueOrDie();
|
|
||||||
auto param =
|
auto param =
|
||||||
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
|
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
|
||||||
Literal result = ExecuteNoHloPasses(std::move(module), {¶m});
|
Literal result = ExecuteNoHloPasses(std::move(module), {¶m});
|
||||||
@ -513,9 +495,7 @@ XLA_TEST_F(MultiOutputFusionTest,
|
|||||||
ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p, i, j), kind=kInput,
|
ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p, i, j), kind=kInput,
|
||||||
calls=fused_reduce
|
calls=fused_reduce
|
||||||
})");
|
})");
|
||||||
auto module =
|
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
|
||||||
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
|
|
||||||
.ValueOrDie();
|
|
||||||
auto param =
|
auto param =
|
||||||
LiteralUtil::CreateR3<float>({{{0, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
|
LiteralUtil::CreateR3<float>({{{0, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
|
||||||
auto init1 = LiteralUtil::CreateR0<float>(5);
|
auto init1 = LiteralUtil::CreateR0<float>(5);
|
||||||
@ -549,9 +529,7 @@ XLA_TEST_F(MultiOutputFusionTest,
|
|||||||
ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}, f16[2,2,2]{2,1,0}) fusion(p),
|
ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}, f16[2,2,2]{2,1,0}) fusion(p),
|
||||||
kind=kInput, calls=fused_reduce
|
kind=kInput, calls=fused_reduce
|
||||||
})");
|
})");
|
||||||
auto module =
|
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
|
||||||
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
|
|
||||||
.ValueOrDie();
|
|
||||||
auto param = LiteralUtil::CreateR3<Eigen::half>(
|
auto param = LiteralUtil::CreateR3<Eigen::half>(
|
||||||
{{{Eigen::half(1), Eigen::half(2)}, {Eigen::half(3), Eigen::half(4)}},
|
{{{Eigen::half(1), Eigen::half(2)}, {Eigen::half(3), Eigen::half(4)}},
|
||||||
{{Eigen::half(5), Eigen::half(6)}, {Eigen::half(7), Eigen::half(8)}}});
|
{{Eigen::half(5), Eigen::half(6)}, {Eigen::half(7), Eigen::half(8)}}});
|
||||||
|
@ -525,9 +525,7 @@ XLA_TEST_F(TupleHloTest, BitcastAfterGTE) {
|
|||||||
ROOT tuple.4 = (f32[1,3]{1,0}) tuple(copy)
|
ROOT tuple.4 = (f32[1,3]{1,0}) tuple(copy)
|
||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
auto module =
|
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
|
||||||
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
|
|
||||||
.ValueOrDie();
|
|
||||||
auto param =
|
auto param =
|
||||||
LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1<float>({1, 2, 3}));
|
LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1<float>({1, 2, 3}));
|
||||||
auto result = ExecuteNoHloPasses(std::move(module), {¶m});
|
auto result = ExecuteNoHloPasses(std::move(module), {¶m});
|
||||||
@ -559,9 +557,7 @@ XLA_TEST_F(TupleHloTest,
|
|||||||
ROOT outfeed = token[] outfeed(tuple, token0)
|
ROOT outfeed = token[] outfeed(tuple, token0)
|
||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
auto module =
|
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
|
||||||
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
|
|
||||||
.ValueOrDie();
|
|
||||||
auto param0 = LiteralUtil::CreateR1<float>({1, 2});
|
auto param0 = LiteralUtil::CreateR1<float>({1, 2});
|
||||||
auto param1 = LiteralUtil::CreateR1<float>({2, 3});
|
auto param1 = LiteralUtil::CreateR1<float>({2, 3});
|
||||||
auto param4 = LiteralUtil::CreateR0<bool>(false);
|
auto param4 = LiteralUtil::CreateR0<bool>(false);
|
||||||
|
@ -36,9 +36,8 @@ ENTRY %entry {
|
|||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(
|
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
|
||||||
std::unique_ptr<HloModule> hlo_module,
|
ParseAndReturnVerifiedModule(hlo_string));
|
||||||
HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()));
|
|
||||||
|
|
||||||
{
|
{
|
||||||
auto extracted_module =
|
auto extracted_module =
|
||||||
@ -75,9 +74,8 @@ ENTRY %entry {
|
|||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(
|
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
|
||||||
std::unique_ptr<HloModule> hlo_module,
|
ParseAndReturnVerifiedModule(hlo_string));
|
||||||
HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()));
|
|
||||||
|
|
||||||
{
|
{
|
||||||
auto extracted_module =
|
auto extracted_module =
|
||||||
@ -120,9 +118,8 @@ ENTRY %entry {
|
|||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(
|
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
|
||||||
std::unique_ptr<HloModule> hlo_module,
|
ParseAndReturnVerifiedModule(hlo_string));
|
||||||
HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()));
|
|
||||||
|
|
||||||
{
|
{
|
||||||
auto extracted_module =
|
auto extracted_module =
|
||||||
|
Loading…
Reference in New Issue
Block a user