Use VerifiedHloModule in hlo_parser_test where applicable.
Don't enable it for tests with sparse layout (this fails verification with "Sparse arrays are not yet fully supported"). Fix bugs in HLO that were discoved with this. PiperOrigin-RevId: 275807649 Change-Id: Id4021d689c071aaedf957a7a435328641cf85402
This commit is contained in:
parent
94e7e37a60
commit
2bfc99464c
@ -4110,12 +4110,15 @@ tf_cc_test(
|
|||||||
":hlo_parser",
|
":hlo_parser",
|
||||||
":pattern_matcher",
|
":pattern_matcher",
|
||||||
":pattern_matcher_gmock",
|
":pattern_matcher_gmock",
|
||||||
|
"//tensorflow/compiler/xla:shape_util",
|
||||||
"//tensorflow/compiler/xla:test_helpers",
|
"//tensorflow/compiler/xla:test_helpers",
|
||||||
"//tensorflow/compiler/xla:window_util",
|
"//tensorflow/compiler/xla:window_util",
|
||||||
"//tensorflow/compiler/xla:xla_data_proto",
|
"//tensorflow/compiler/xla:xla_data_proto",
|
||||||
|
"//tensorflow/compiler/xla/tests:verified_hlo_module",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
"//tensorflow/core:test_main", # fixdeps: keep
|
"//tensorflow/core:test_main", # fixdeps: keep
|
||||||
|
"@com_google_absl//absl/memory",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -15,7 +15,10 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/compiler/xla/service/hlo_parser.h"
|
#include "tensorflow/compiler/xla/service/hlo_parser.h"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
|
#include "absl/memory/memory.h"
|
||||||
#include "absl/strings/match.h"
|
#include "absl/strings/match.h"
|
||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
#include "absl/strings/string_view.h"
|
#include "absl/strings/string_view.h"
|
||||||
@ -23,6 +26,8 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
|
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
|
||||||
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
|
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
|
||||||
#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
|
#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
|
||||||
|
#include "tensorflow/compiler/xla/shape_util.h"
|
||||||
|
#include "tensorflow/compiler/xla/tests/verified_hlo_module.h"
|
||||||
#include "tensorflow/compiler/xla/window_util.h"
|
#include "tensorflow/compiler/xla/window_util.h"
|
||||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
@ -37,6 +42,7 @@ using absl::string_view;
|
|||||||
struct TestData {
|
struct TestData {
|
||||||
string test_name;
|
string test_name;
|
||||||
string module_string;
|
string module_string;
|
||||||
|
bool enable_verification = true;
|
||||||
};
|
};
|
||||||
|
|
||||||
string TestDataToString(const ::testing::TestParamInfo<TestData>& data) {
|
string TestDataToString(const ::testing::TestParamInfo<TestData>& data) {
|
||||||
@ -451,10 +457,10 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2
|
|||||||
"ConvolutionR2",
|
"ConvolutionR2",
|
||||||
R"(HloModule ConvolveR2_module
|
R"(HloModule ConvolveR2_module
|
||||||
|
|
||||||
ENTRY %ConvolveR2.v3 (input: f32[1,2], filter: f32[1,1]) -> f32[1,2] {
|
ENTRY %ConvolveR2.v3 (input: f32[1,2], filter: f32[2,2]) -> f32[1,2] {
|
||||||
%input = f32[1,2]{1,0} parameter(0)
|
%input = f32[1,2]{1,0} parameter(0)
|
||||||
%filter = f32[1,1]{1,0} parameter(1)
|
%filter = f32[2,2]{1,0} parameter(1)
|
||||||
ROOT %convolution = f32[1,2]{0,1} convolution(f32[1,2]{1,0} %input, f32[1,1]{1,0} %filter), dim_labels=bf_io->bf
|
ROOT %convolution = f32[1,2]{0,1} convolution(f32[1,2]{1,0} %input, f32[2,2]{1,0} %filter), dim_labels=bf_io->bf
|
||||||
}
|
}
|
||||||
|
|
||||||
)"
|
)"
|
||||||
@ -780,10 +786,10 @@ ENTRY %Irfft3d (input: c64[5,64,128,33]) -> f32[5,64,128,64] {
|
|||||||
"Pad",
|
"Pad",
|
||||||
R"(HloModule Pad1DS3Array_module
|
R"(HloModule Pad1DS3Array_module
|
||||||
|
|
||||||
ENTRY %Pad1DS3Array.v3 () -> f32[8] {
|
ENTRY %Pad1DS3Array.v3 () -> f32[7] {
|
||||||
%constant = f32[3]{0} constant({1, 2, 3})
|
%constant = f32[3]{0} constant({1, 2, 3})
|
||||||
%constant.1 = f32[] constant(0.1)
|
%constant.1 = f32[] constant(0.1)
|
||||||
ROOT %pad = f32[8]{0} pad(f32[3]{0} %constant, f32[] %constant.1), padding=3_1
|
ROOT %pad = f32[7]{0} pad(f32[3]{0} %constant, f32[] %constant.1), padding=3_1
|
||||||
}
|
}
|
||||||
|
|
||||||
)"
|
)"
|
||||||
@ -806,10 +812,10 @@ ENTRY %PadHasInterior.v3 (input: f32[1,25,7,7]) -> f32[1,25,17,11] {
|
|||||||
"PadHasNegativePadding",
|
"PadHasNegativePadding",
|
||||||
R"(HloModule PadHasNegativePadding_module
|
R"(HloModule PadHasNegativePadding_module
|
||||||
|
|
||||||
ENTRY %PadHasNegativePadding (input: f32[1,25,7,7,10]) -> f32[1,15,6,3,29] {
|
ENTRY %PadHasNegativePadding (input: f32[1,25,7,7,10]) -> f32[1,15,6,3,35] {
|
||||||
%input = f32[1,25,7,7,10]{4,3,2,1,0} parameter(0)
|
%input = f32[1,25,7,7,10]{4,3,2,1,0} parameter(0)
|
||||||
%constant = f32[] constant(-5.123)
|
%constant = f32[] constant(-5.123)
|
||||||
ROOT %pad = f32[1,15,6,3,29]{4,3,2,1,0} pad(f32[1,25,7,7,10]{4,3,2,1,0} %input, f32[] %constant), padding=0_0_0x0_-10_0x0_-1_0x-2_-2_0x-1_-1_3
|
ROOT %pad = f32[1,15,6,3,35]{4,3,2,1,0} pad(f32[1,25,7,7,10]{4,3,2,1,0} %input, f32[] %constant), padding=0_0_0x0_-10_0x0_-1_0x-2_-2_0x-1_-1_3
|
||||||
}
|
}
|
||||||
|
|
||||||
)"
|
)"
|
||||||
@ -842,7 +848,8 @@ ENTRY %sparse () -> f32[2,3,4] {
|
|||||||
ROOT %foo = f32[2,3,4]sparse{10} constant({[0, 1, 2]: 1, [1, 2, 2]: 2, [1, 2, 3]: 3})
|
ROOT %foo = f32[2,3,4]sparse{10} constant({[0, 1, 2]: 1, [1, 2, 2]: 2, [1, 2, 3]: 3})
|
||||||
}
|
}
|
||||||
|
|
||||||
)"
|
)",
|
||||||
|
/*enable_verification=*/false
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"SparseC128",
|
"SparseC128",
|
||||||
@ -852,7 +859,8 @@ ENTRY %sparse () -> c128[2,3,4] {
|
|||||||
ROOT %foo = c128[2,3,4]sparse{10} constant({[0, 1, 2]: (1, 0), [1, 2, 2]: (2, 5), [1, 2, 3]: (3, 10)})
|
ROOT %foo = c128[2,3,4]sparse{10} constant({[0, 1, 2]: (1, 0), [1, 2, 2]: (2, 5), [1, 2, 3]: (3, 10)})
|
||||||
}
|
}
|
||||||
|
|
||||||
)"
|
)",
|
||||||
|
/*enable_verification=*/false
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"SparseEmpty",
|
"SparseEmpty",
|
||||||
@ -862,7 +870,8 @@ ENTRY %sparse_f32_empty () -> f32[2,3,4] {
|
|||||||
ROOT %foo = f32[2,3,4]sparse{10} constant({})
|
ROOT %foo = f32[2,3,4]sparse{10} constant({})
|
||||||
}
|
}
|
||||||
|
|
||||||
)"
|
)",
|
||||||
|
/*enable_verification=*/false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"SparseR1",
|
"SparseR1",
|
||||||
@ -872,7 +881,8 @@ ENTRY %sparse_f32_r1 () -> f32[9] {
|
|||||||
ROOT %foo = f32[9]sparse{10} constant({1: 2, 3: 4, 5: 6})
|
ROOT %foo = f32[9]sparse{10} constant({1: 2, 3: 4, 5: 6})
|
||||||
}
|
}
|
||||||
|
|
||||||
)"
|
)",
|
||||||
|
/*enable_verification=*/false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"Gather",
|
"Gather",
|
||||||
@ -1152,7 +1162,7 @@ max_argmax {
|
|||||||
|
|
||||||
ENTRY reduce_entry {
|
ENTRY reduce_entry {
|
||||||
values = f32[1024]{0} parameter(0)
|
values = f32[1024]{0} parameter(0)
|
||||||
indices = f32[1024]{0} parameter(1)
|
indices = s32[1024]{0} parameter(1)
|
||||||
init_value = f32[] constant(-inf)
|
init_value = f32[] constant(-inf)
|
||||||
init_index = s32[] constant(-1)
|
init_index = s32[] constant(-1)
|
||||||
ROOT result = (f32[], s32[]) reduce(values, indices, init_value, init_index), dimensions={0}, to_apply=max_argmax
|
ROOT result = (f32[], s32[]) reduce(values, indices, init_value, init_index), dimensions={0}, to_apply=max_argmax
|
||||||
@ -1410,8 +1420,8 @@ R"(HloModule dot
|
|||||||
|
|
||||||
ENTRY dot {
|
ENTRY dot {
|
||||||
a = f32[2,10]{1,0} parameter(0)
|
a = f32[2,10]{1,0} parameter(0)
|
||||||
b = f32[10,3]{1,0} parameter(1)
|
b = f32[10,2]{1,0} parameter(1)
|
||||||
ROOT dot = f32[2,3]{1,0} dot(a, b), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_contracting_dims={0}
|
ROOT dot = f32[2]{0} dot(a, b), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_batch_dims={1}, rhs_contracting_dims={0}
|
||||||
}
|
}
|
||||||
|
|
||||||
)"
|
)"
|
||||||
@ -1490,7 +1500,7 @@ R"(HloModule AllToAll
|
|||||||
|
|
||||||
ENTRY AllToAll {
|
ENTRY AllToAll {
|
||||||
input = f32[128,32]{0,1} parameter(0)
|
input = f32[128,32]{0,1} parameter(0)
|
||||||
ROOT a2a = f32[128,32]{0,1} all-to-all(input), replica_groups={}
|
ROOT a2a = (f32[128,32]{0,1}) all-to-all(input), replica_groups={}
|
||||||
}
|
}
|
||||||
|
|
||||||
)"
|
)"
|
||||||
@ -1501,8 +1511,9 @@ ENTRY AllToAll {
|
|||||||
R"(HloModule AllToAllWithSubgroups
|
R"(HloModule AllToAllWithSubgroups
|
||||||
|
|
||||||
ENTRY AllToAllWithSubgroups {
|
ENTRY AllToAllWithSubgroups {
|
||||||
input = f32[128,32]{0,1} parameter(0)
|
p0 = f32[128,32]{0,1} parameter(0)
|
||||||
ROOT a2a = f32[128,32]{0,1} all-to-all(input), replica_groups={{1,2},{3,0}}
|
p1 = f32[128,32]{0,1} parameter(1)
|
||||||
|
ROOT a2a = (f32[128,32]{0,1}, f32[128,32]{0,1}) all-to-all(p0, p1), replica_groups={{1,2},{3,0}}
|
||||||
}
|
}
|
||||||
|
|
||||||
)"
|
)"
|
||||||
@ -1684,13 +1695,24 @@ class HloParameterizedParserTest
|
|||||||
: public ::testing::Test,
|
: public ::testing::Test,
|
||||||
public ::testing::WithParamInterface<TestData> {
|
public ::testing::WithParamInterface<TestData> {
|
||||||
protected:
|
protected:
|
||||||
// Expects "ToString(ParseAndReturnUnverifiedModule(string)) == string", that
|
// Expects "ToString(ParseHloModule(string)) == string", that is, parses the
|
||||||
// is, parses the string, asserts that it succeeded, stringifies the parsed
|
// string, asserts that it succeeded, stringifies the parsed module, and
|
||||||
// module, and checks that it equals the original string.
|
// checks that it equals the original string.
|
||||||
void ExpectEqual() {
|
void ExpectEqual() {
|
||||||
|
std::unique_ptr<HloModule> module;
|
||||||
const string& original = GetParam().module_string;
|
const string& original = GetParam().module_string;
|
||||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
if (GetParam().enable_verification) {
|
||||||
ParseAndReturnUnverifiedModule(original));
|
auto verified_module = absl::make_unique<VerifiedHloModule>(
|
||||||
|
GetParam().test_name, HloModuleConfig(),
|
||||||
|
/*verifier_layout_sensitive=*/false,
|
||||||
|
/*allow_mixed_precision_in_hlo_verifier=*/true,
|
||||||
|
ShapeUtil::ByteSizeOfElements);
|
||||||
|
TF_ASSERT_OK(ParseHloString(original, verified_module.get()));
|
||||||
|
TF_ASSERT_OK(verified_module->Verify());
|
||||||
|
module = std::move(verified_module);
|
||||||
|
} else {
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(module, ParseAndReturnUnverifiedModule(original));
|
||||||
|
}
|
||||||
if (proto_round_trip) {
|
if (proto_round_trip) {
|
||||||
TF_ASSERT_OK_AND_ASSIGN(module, HloModule::CreateFromProto(
|
TF_ASSERT_OK_AND_ASSIGN(module, HloModule::CreateFromProto(
|
||||||
module->ToProto(), module->config()));
|
module->ToProto(), module->config()));
|
||||||
@ -1738,6 +1760,18 @@ class HloParserTest : public ::testing::Test {
|
|||||||
EXPECT_TRUE(absl::StrContains(s, expected))
|
EXPECT_TRUE(absl::StrContains(s, expected))
|
||||||
<< "'" << s << "' does not contain '" << expected << "'";
|
<< "'" << s << "' does not contain '" << expected << "'";
|
||||||
}
|
}
|
||||||
|
StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule(
|
||||||
|
absl::string_view hlo_text) {
|
||||||
|
auto module = absl::make_unique<VerifiedHloModule>(
|
||||||
|
::testing::UnitTest::GetInstance()->current_test_info()->name(),
|
||||||
|
HloModuleConfig(),
|
||||||
|
/*verifier_layout_sensitive=*/false,
|
||||||
|
/*allow_mixed_precision_in_hlo_verifier=*/true,
|
||||||
|
ShapeUtil::ByteSizeOfElements);
|
||||||
|
TF_RETURN_IF_ERROR(ParseHloString(hlo_text, module.get()));
|
||||||
|
TF_RETURN_IF_ERROR(module->Verify());
|
||||||
|
return std::move(module);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(HloParserTest, Empty) {
|
TEST_F(HloParserTest, Empty) {
|
||||||
@ -1813,7 +1847,7 @@ ENTRY %SelectScalarS32True.v4 () -> s32[] {
|
|||||||
}
|
}
|
||||||
|
|
||||||
)";
|
)";
|
||||||
auto result = ParseAndReturnUnverifiedModule(original);
|
auto result = ParseAndReturnVerifiedModule(original);
|
||||||
TF_EXPECT_OK(result.status());
|
TF_EXPECT_OK(result.status());
|
||||||
// Constant instructions have no name. The string will be parsed successfully
|
// Constant instructions have no name. The string will be parsed successfully
|
||||||
// but the constant names will not be exactly the same.
|
// but the constant names will not be exactly the same.
|
||||||
@ -1824,7 +1858,7 @@ TEST_F(HloParserTest, ConfigurationField) {
|
|||||||
ENTRY %configuration_test() -> s32[] {
|
ENTRY %configuration_test() -> s32[] {
|
||||||
%constant = s32[] constant(42), backend_config="foo bar"
|
%constant = s32[] constant(42), backend_config="foo bar"
|
||||||
})";
|
})";
|
||||||
auto result = ParseAndReturnUnverifiedModule(original);
|
auto result = ParseAndReturnVerifiedModule(original);
|
||||||
TF_ASSERT_OK(result.status());
|
TF_ASSERT_OK(result.status());
|
||||||
EXPECT_EQ("foo bar", result.ValueOrDie()
|
EXPECT_EQ("foo bar", result.ValueOrDie()
|
||||||
->entry_computation()
|
->entry_computation()
|
||||||
@ -1896,7 +1930,7 @@ TEST_F(HloParserTest, ConstantBf16NoOverflow) {
|
|||||||
ENTRY test {
|
ENTRY test {
|
||||||
ROOT c = bf16[] constant(-65505)
|
ROOT c = bf16[] constant(-65505)
|
||||||
})";
|
})";
|
||||||
EXPECT_EQ(Status::OK(), ParseAndReturnUnverifiedModule(original).status());
|
EXPECT_EQ(Status::OK(), ParseAndReturnVerifiedModule(original).status());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(HloParserTest, ConstantBf16Overflow) {
|
TEST_F(HloParserTest, ConstantBf16Overflow) {
|
||||||
@ -2004,7 +2038,7 @@ ENTRY %ConstantWithExp.v4 () -> f32[] {
|
|||||||
}
|
}
|
||||||
|
|
||||||
)";
|
)";
|
||||||
auto result = ParseAndReturnUnverifiedModule(original);
|
auto result = ParseAndReturnVerifiedModule(original);
|
||||||
TF_EXPECT_OK(result.status());
|
TF_EXPECT_OK(result.status());
|
||||||
// The string will be parsed successfully but the output strings are not
|
// The string will be parsed successfully but the output strings are not
|
||||||
// exactly the same, because "3e2" is parsed into value 300 and will be
|
// exactly the same, because "3e2" is parsed into value 300 and will be
|
||||||
@ -2012,14 +2046,14 @@ ENTRY %ConstantWithExp.v4 () -> f32[] {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(HloParserTest, ShortConstant) {
|
TEST_F(HloParserTest, ShortConstant) {
|
||||||
const string original = R"(HloModule ShortCOnstant_module
|
const string original = R"(HloModule ShortConstant_module
|
||||||
|
|
||||||
ENTRY %ShortConstant.v4 () -> f32[67,89] {
|
ENTRY %ShortConstant.v4 () -> f32[67,89] {
|
||||||
ROOT %constant.1 = f32[67,89]{1,0} constant({...})
|
ROOT %constant.1 = f32[67,89]{1,0} constant({...})
|
||||||
}
|
}
|
||||||
|
|
||||||
)";
|
)";
|
||||||
auto result = ParseAndReturnUnverifiedModule(original);
|
auto result = ParseAndReturnVerifiedModule(original);
|
||||||
TF_EXPECT_OK(result.status());
|
TF_EXPECT_OK(result.status());
|
||||||
EXPECT_EQ(result.ValueOrDie()->ToString(HloPrintOptions()), original);
|
EXPECT_EQ(result.ValueOrDie()->ToString(HloPrintOptions()), original);
|
||||||
}
|
}
|
||||||
@ -2027,15 +2061,15 @@ ENTRY %ShortConstant.v4 () -> f32[67,89] {
|
|||||||
TEST_F(HloParserTest, AttibutesAnyOrder) {
|
TEST_F(HloParserTest, AttibutesAnyOrder) {
|
||||||
const string original = R"(HloModule any_order_module
|
const string original = R"(HloModule any_order_module
|
||||||
|
|
||||||
ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] {
|
ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,4,1] {
|
||||||
%input = f32[1,2,1]{2,1,0} parameter(0)
|
%input = f32[1,2,1]{2,1,0} parameter(0)
|
||||||
%copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
|
%copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
|
||||||
%filter = f32[1,1,1]{2,1,0} parameter(1)
|
%filter = f32[1,1,1]{2,1,0} parameter(1)
|
||||||
ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), feature_group_count=1, sharding={maximal device=1}, backend_config="foo", dim_labels=b0f_0io->b0f, window={pad=1_1 size=2}
|
ROOT %convolution = f32[1,4,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), feature_group_count=1, sharding={maximal device=1}, backend_config="foo", dim_labels=b0f_0io->b0f, window={pad=1_1 size=1}
|
||||||
}
|
}
|
||||||
|
|
||||||
)";
|
)";
|
||||||
TF_EXPECT_OK(ParseAndReturnUnverifiedModule(original).status());
|
TF_EXPECT_OK(ParseAndReturnVerifiedModule(original).status());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(HloParserTest, InvalidDimLabels) {
|
TEST_F(HloParserTest, InvalidDimLabels) {
|
||||||
@ -2127,7 +2161,7 @@ ENTRY %slice.v2 (p0: f32[3,3,4,4]) -> f32[3,3,2,4] {
|
|||||||
}
|
}
|
||||||
|
|
||||||
)";
|
)";
|
||||||
TF_EXPECT_OK(ParseAndReturnUnverifiedModule(original).status());
|
TF_EXPECT_OK(ParseAndReturnVerifiedModule(original).status());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(HloParserTest, PaddingConfigIsNotWindowPad) {
|
TEST_F(HloParserTest, PaddingConfigIsNotWindowPad) {
|
||||||
@ -2154,7 +2188,7 @@ ENTRY %test_comma.v4 () -> f32[] {
|
|||||||
}
|
}
|
||||||
|
|
||||||
)";
|
)";
|
||||||
TF_EXPECT_OK(ParseAndReturnUnverifiedModule(original).status());
|
TF_EXPECT_OK(ParseAndReturnVerifiedModule(original).status());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(HloParserTest, ComputationShapeDoesNotMatchRootShape) {
|
TEST_F(HloParserTest, ComputationShapeDoesNotMatchRootShape) {
|
||||||
@ -2184,7 +2218,7 @@ ENTRY %Reduce (input: f32[8,16,256]) -> f32[8,16] {
|
|||||||
ROOT reduce = f32[8,16]{0,1} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3
|
ROOT reduce = f32[8,16]{0,1} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3
|
||||||
})";
|
})";
|
||||||
|
|
||||||
auto module = ParseAndReturnUnverifiedModule(original);
|
auto module = ParseAndReturnVerifiedModule(original);
|
||||||
TF_ASSERT_OK(module.status());
|
TF_ASSERT_OK(module.status());
|
||||||
auto program_layout = module.ValueOrDie()->entry_computation_layout();
|
auto program_layout = module.ValueOrDie()->entry_computation_layout();
|
||||||
ASSERT_EQ(program_layout.parameter_count(), 1);
|
ASSERT_EQ(program_layout.parameter_count(), 1);
|
||||||
@ -2207,7 +2241,7 @@ c1 {
|
|||||||
c2 {
|
c2 {
|
||||||
const2 = f32[1]{0} constant({67890})
|
const2 = f32[1]{0} constant({67890})
|
||||||
})";
|
})";
|
||||||
auto module = ParseAndReturnUnverifiedModule(original);
|
auto module = ParseAndReturnVerifiedModule(original);
|
||||||
TF_ASSERT_OK(module.status());
|
TF_ASSERT_OK(module.status());
|
||||||
EXPECT_EQ(module.ValueOrDie()->entry_computation()->name(), "c2");
|
EXPECT_EQ(module.ValueOrDie()->entry_computation()->name(), "c2");
|
||||||
}
|
}
|
||||||
@ -2218,7 +2252,7 @@ ENTRY consts {
|
|||||||
first = f32[1]{0} constant({12345})
|
first = f32[1]{0} constant({12345})
|
||||||
last = f32[1]{0} constant({67890})
|
last = f32[1]{0} constant({67890})
|
||||||
})";
|
})";
|
||||||
auto module = ParseAndReturnUnverifiedModule(original);
|
auto module = ParseAndReturnVerifiedModule(original);
|
||||||
TF_ASSERT_OK(module.status());
|
TF_ASSERT_OK(module.status());
|
||||||
EXPECT_EQ(
|
EXPECT_EQ(
|
||||||
module.ValueOrDie()->entry_computation()->root_instruction()->name(),
|
module.ValueOrDie()->entry_computation()->root_instruction()->name(),
|
||||||
@ -2238,7 +2272,7 @@ ENTRY /*comment*/ c1 {
|
|||||||
/* something else */
|
/* something else */
|
||||||
|
|
||||||
)";
|
)";
|
||||||
auto module = ParseAndReturnUnverifiedModule(original);
|
auto module = ParseAndReturnVerifiedModule(original);
|
||||||
TF_ASSERT_OK(module.status());
|
TF_ASSERT_OK(module.status());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2257,7 +2291,7 @@ d
|
|||||||
|
|
||||||
*/
|
*/
|
||||||
})";
|
})";
|
||||||
auto module = ParseAndReturnUnverifiedModule(original);
|
auto module = ParseAndReturnVerifiedModule(original);
|
||||||
TF_ASSERT_OK(module.status());
|
TF_ASSERT_OK(module.status());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2281,7 +2315,7 @@ ENTRY c1 {
|
|||||||
// Foo bar
|
// Foo bar
|
||||||
ROOT const1 = f32[1]{0} constant({12345}) // Something else
|
ROOT const1 = f32[1]{0} constant({12345}) // Something else
|
||||||
})";
|
})";
|
||||||
auto module = ParseAndReturnUnverifiedModule(original);
|
auto module = ParseAndReturnVerifiedModule(original);
|
||||||
TF_ASSERT_OK(module.status());
|
TF_ASSERT_OK(module.status());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2289,7 +2323,7 @@ TEST_F(HloParserTest, SlashSlashCommentMsDosEolFormat) {
|
|||||||
const string original =
|
const string original =
|
||||||
"HloModule slash_slash_comment:\r\n// Garbage\r\nENTRY c1 {\r\n// Foo "
|
"HloModule slash_slash_comment:\r\n// Garbage\r\nENTRY c1 {\r\n// Foo "
|
||||||
"bar\r\nROOT const1 = f32[1]{0} constant({12345}) // Something else\r\n}";
|
"bar\r\nROOT const1 = f32[1]{0} constant({12345}) // Something else\r\n}";
|
||||||
auto module = ParseAndReturnUnverifiedModule(original);
|
auto module = ParseAndReturnVerifiedModule(original);
|
||||||
TF_ASSERT_OK(module.status());
|
TF_ASSERT_OK(module.status());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2297,7 +2331,7 @@ TEST_F(HloParserTest, SlashSlashCommentMacEolFormat) {
|
|||||||
const string original =
|
const string original =
|
||||||
"HloModule slash_slash_comment:\r// Garbage\rENTRY c1 {\r// Foo "
|
"HloModule slash_slash_comment:\r// Garbage\rENTRY c1 {\r// Foo "
|
||||||
"bar\rROOT const1 = f32[1]{0} constant({12345}) // Something else\r}";
|
"bar\rROOT const1 = f32[1]{0} constant({12345}) // Something else\r}";
|
||||||
auto module = ParseAndReturnUnverifiedModule(original);
|
auto module = ParseAndReturnVerifiedModule(original);
|
||||||
TF_ASSERT_OK(module.status());
|
TF_ASSERT_OK(module.status());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2375,8 +2409,7 @@ ENTRY ReduceR3ToR2 {
|
|||||||
ROOT result = f32[8,16]{1,0} reduce(p0, p1), dimensions={2}, to_apply=add
|
ROOT result = f32[8,16]{1,0} reduce(p0, p1), dimensions={2}, to_apply=add
|
||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(original));
|
||||||
ParseAndReturnUnverifiedModule(original));
|
|
||||||
ASSERT_NE(module->entry_computation(), nullptr);
|
ASSERT_NE(module->entry_computation(), nullptr);
|
||||||
EXPECT_THAT(module->entry_computation()->root_instruction(),
|
EXPECT_THAT(module->entry_computation()->root_instruction(),
|
||||||
GmockMatch(m::Reduce()));
|
GmockMatch(m::Reduce()));
|
||||||
@ -2649,8 +2682,7 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
|
|||||||
ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
|
ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
|
||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
|
||||||
ParseAndReturnUnverifiedModule(text));
|
|
||||||
ASSERT_FALSE(module->has_schedule());
|
ASSERT_FALSE(module->has_schedule());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2667,8 +2699,7 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
|
|||||||
ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
|
ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
|
||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
|
||||||
ParseAndReturnUnverifiedModule(text));
|
|
||||||
ASSERT_FALSE(module->has_schedule());
|
ASSERT_FALSE(module->has_schedule());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2685,8 +2716,7 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
|
|||||||
ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
|
ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
|
||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
|
||||||
ParseAndReturnUnverifiedModule(text));
|
|
||||||
ASSERT_TRUE(module->has_schedule());
|
ASSERT_TRUE(module->has_schedule());
|
||||||
TF_ASSERT_OK(module->schedule().Verify());
|
TF_ASSERT_OK(module->schedule().Verify());
|
||||||
EXPECT_EQ(module->schedule().sequences().size(), 1);
|
EXPECT_EQ(module->schedule().sequences().size(), 1);
|
||||||
@ -2714,8 +2744,7 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
|
|||||||
ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
|
ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
|
||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
|
||||||
ParseAndReturnUnverifiedModule(text));
|
|
||||||
ASSERT_TRUE(module->has_schedule());
|
ASSERT_TRUE(module->has_schedule());
|
||||||
TF_ASSERT_OK(module->schedule().Verify());
|
TF_ASSERT_OK(module->schedule().Verify());
|
||||||
EXPECT_EQ(module->schedule().sequences().size(), 1);
|
EXPECT_EQ(module->schedule().sequences().size(), 1);
|
||||||
@ -2767,8 +2796,7 @@ ENTRY entry {
|
|||||||
ROOT root = f32[ 1, 2,3, 4, 5]{0, 1, 2,3, 4 } parameter(0)
|
ROOT root = f32[ 1, 2,3, 4, 5]{0, 1, 2,3, 4 } parameter(0)
|
||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
|
||||||
ParseAndReturnUnverifiedModule(text));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(HloParserTest, ShapeMismatchInOperand) {
|
TEST_F(HloParserTest, ShapeMismatchInOperand) {
|
||||||
|
Loading…
Reference in New Issue
Block a user