Use VerifiedHloModule in hlo_parser_test where applicable.

Don't enable it for tests with sparse layout (this fails
verification with "Sparse arrays are not yet fully supported").
Fix bugs in HLO that were discoved with this.

PiperOrigin-RevId: 275807649
Change-Id: Id4021d689c071aaedf957a7a435328641cf85402
This commit is contained in:
Adrian Kuegel 2019-10-21 03:06:30 -07:00 committed by TensorFlower Gardener
parent 94e7e37a60
commit 2bfc99464c
2 changed files with 84 additions and 53 deletions

View File

@ -4110,12 +4110,15 @@ tf_cc_test(
":hlo_parser", ":hlo_parser",
":pattern_matcher", ":pattern_matcher",
":pattern_matcher_gmock", ":pattern_matcher_gmock",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:window_util",
"//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:verified_hlo_module",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:test", "//tensorflow/core:test",
"//tensorflow/core:test_main", # fixdeps: keep "//tensorflow/core:test_main", # fixdeps: keep
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
], ],
) )

View File

@ -15,7 +15,10 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/hlo_parser.h"
#include <memory>
#include <string> #include <string>
#include "absl/memory/memory.h"
#include "absl/strings/match.h" #include "absl/strings/match.h"
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"
@ -23,6 +26,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h" #include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h" #include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/verified_hlo_module.h"
#include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/status_test_util.h"
@ -37,6 +42,7 @@ using absl::string_view;
struct TestData { struct TestData {
string test_name; string test_name;
string module_string; string module_string;
bool enable_verification = true;
}; };
string TestDataToString(const ::testing::TestParamInfo<TestData>& data) { string TestDataToString(const ::testing::TestParamInfo<TestData>& data) {
@ -451,10 +457,10 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2
"ConvolutionR2", "ConvolutionR2",
R"(HloModule ConvolveR2_module R"(HloModule ConvolveR2_module
ENTRY %ConvolveR2.v3 (input: f32[1,2], filter: f32[1,1]) -> f32[1,2] { ENTRY %ConvolveR2.v3 (input: f32[1,2], filter: f32[2,2]) -> f32[1,2] {
%input = f32[1,2]{1,0} parameter(0) %input = f32[1,2]{1,0} parameter(0)
%filter = f32[1,1]{1,0} parameter(1) %filter = f32[2,2]{1,0} parameter(1)
ROOT %convolution = f32[1,2]{0,1} convolution(f32[1,2]{1,0} %input, f32[1,1]{1,0} %filter), dim_labels=bf_io->bf ROOT %convolution = f32[1,2]{0,1} convolution(f32[1,2]{1,0} %input, f32[2,2]{1,0} %filter), dim_labels=bf_io->bf
} }
)" )"
@ -780,10 +786,10 @@ ENTRY %Irfft3d (input: c64[5,64,128,33]) -> f32[5,64,128,64] {
"Pad", "Pad",
R"(HloModule Pad1DS3Array_module R"(HloModule Pad1DS3Array_module
ENTRY %Pad1DS3Array.v3 () -> f32[8] { ENTRY %Pad1DS3Array.v3 () -> f32[7] {
%constant = f32[3]{0} constant({1, 2, 3}) %constant = f32[3]{0} constant({1, 2, 3})
%constant.1 = f32[] constant(0.1) %constant.1 = f32[] constant(0.1)
ROOT %pad = f32[8]{0} pad(f32[3]{0} %constant, f32[] %constant.1), padding=3_1 ROOT %pad = f32[7]{0} pad(f32[3]{0} %constant, f32[] %constant.1), padding=3_1
} }
)" )"
@ -806,10 +812,10 @@ ENTRY %PadHasInterior.v3 (input: f32[1,25,7,7]) -> f32[1,25,17,11] {
"PadHasNegativePadding", "PadHasNegativePadding",
R"(HloModule PadHasNegativePadding_module R"(HloModule PadHasNegativePadding_module
ENTRY %PadHasNegativePadding (input: f32[1,25,7,7,10]) -> f32[1,15,6,3,29] { ENTRY %PadHasNegativePadding (input: f32[1,25,7,7,10]) -> f32[1,15,6,3,35] {
%input = f32[1,25,7,7,10]{4,3,2,1,0} parameter(0) %input = f32[1,25,7,7,10]{4,3,2,1,0} parameter(0)
%constant = f32[] constant(-5.123) %constant = f32[] constant(-5.123)
ROOT %pad = f32[1,15,6,3,29]{4,3,2,1,0} pad(f32[1,25,7,7,10]{4,3,2,1,0} %input, f32[] %constant), padding=0_0_0x0_-10_0x0_-1_0x-2_-2_0x-1_-1_3 ROOT %pad = f32[1,15,6,3,35]{4,3,2,1,0} pad(f32[1,25,7,7,10]{4,3,2,1,0} %input, f32[] %constant), padding=0_0_0x0_-10_0x0_-1_0x-2_-2_0x-1_-1_3
} }
)" )"
@ -842,7 +848,8 @@ ENTRY %sparse () -> f32[2,3,4] {
ROOT %foo = f32[2,3,4]sparse{10} constant({[0, 1, 2]: 1, [1, 2, 2]: 2, [1, 2, 3]: 3}) ROOT %foo = f32[2,3,4]sparse{10} constant({[0, 1, 2]: 1, [1, 2, 2]: 2, [1, 2, 3]: 3})
} }
)" )",
/*enable_verification=*/false
}, },
{ {
"SparseC128", "SparseC128",
@ -852,7 +859,8 @@ ENTRY %sparse () -> c128[2,3,4] {
ROOT %foo = c128[2,3,4]sparse{10} constant({[0, 1, 2]: (1, 0), [1, 2, 2]: (2, 5), [1, 2, 3]: (3, 10)}) ROOT %foo = c128[2,3,4]sparse{10} constant({[0, 1, 2]: (1, 0), [1, 2, 2]: (2, 5), [1, 2, 3]: (3, 10)})
} }
)" )",
/*enable_verification=*/false
}, },
{ {
"SparseEmpty", "SparseEmpty",
@ -862,7 +870,8 @@ ENTRY %sparse_f32_empty () -> f32[2,3,4] {
ROOT %foo = f32[2,3,4]sparse{10} constant({}) ROOT %foo = f32[2,3,4]sparse{10} constant({})
} }
)" )",
/*enable_verification=*/false,
}, },
{ {
"SparseR1", "SparseR1",
@ -872,7 +881,8 @@ ENTRY %sparse_f32_r1 () -> f32[9] {
ROOT %foo = f32[9]sparse{10} constant({1: 2, 3: 4, 5: 6}) ROOT %foo = f32[9]sparse{10} constant({1: 2, 3: 4, 5: 6})
} }
)" )",
/*enable_verification=*/false,
}, },
{ {
"Gather", "Gather",
@ -1152,7 +1162,7 @@ max_argmax {
ENTRY reduce_entry { ENTRY reduce_entry {
values = f32[1024]{0} parameter(0) values = f32[1024]{0} parameter(0)
indices = f32[1024]{0} parameter(1) indices = s32[1024]{0} parameter(1)
init_value = f32[] constant(-inf) init_value = f32[] constant(-inf)
init_index = s32[] constant(-1) init_index = s32[] constant(-1)
ROOT result = (f32[], s32[]) reduce(values, indices, init_value, init_index), dimensions={0}, to_apply=max_argmax ROOT result = (f32[], s32[]) reduce(values, indices, init_value, init_index), dimensions={0}, to_apply=max_argmax
@ -1410,8 +1420,8 @@ R"(HloModule dot
ENTRY dot { ENTRY dot {
a = f32[2,10]{1,0} parameter(0) a = f32[2,10]{1,0} parameter(0)
b = f32[10,3]{1,0} parameter(1) b = f32[10,2]{1,0} parameter(1)
ROOT dot = f32[2,3]{1,0} dot(a, b), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_contracting_dims={0} ROOT dot = f32[2]{0} dot(a, b), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_batch_dims={1}, rhs_contracting_dims={0}
} }
)" )"
@ -1490,7 +1500,7 @@ R"(HloModule AllToAll
ENTRY AllToAll { ENTRY AllToAll {
input = f32[128,32]{0,1} parameter(0) input = f32[128,32]{0,1} parameter(0)
ROOT a2a = f32[128,32]{0,1} all-to-all(input), replica_groups={} ROOT a2a = (f32[128,32]{0,1}) all-to-all(input), replica_groups={}
} }
)" )"
@ -1501,8 +1511,9 @@ ENTRY AllToAll {
R"(HloModule AllToAllWithSubgroups R"(HloModule AllToAllWithSubgroups
ENTRY AllToAllWithSubgroups { ENTRY AllToAllWithSubgroups {
input = f32[128,32]{0,1} parameter(0) p0 = f32[128,32]{0,1} parameter(0)
ROOT a2a = f32[128,32]{0,1} all-to-all(input), replica_groups={{1,2},{3,0}} p1 = f32[128,32]{0,1} parameter(1)
ROOT a2a = (f32[128,32]{0,1}, f32[128,32]{0,1}) all-to-all(p0, p1), replica_groups={{1,2},{3,0}}
} }
)" )"
@ -1684,13 +1695,24 @@ class HloParameterizedParserTest
: public ::testing::Test, : public ::testing::Test,
public ::testing::WithParamInterface<TestData> { public ::testing::WithParamInterface<TestData> {
protected: protected:
// Expects "ToString(ParseAndReturnUnverifiedModule(string)) == string", that // Expects "ToString(ParseHloModule(string)) == string", that is, parses the
// is, parses the string, asserts that it succeeded, stringifies the parsed // string, asserts that it succeeded, stringifies the parsed module, and
// module, and checks that it equals the original string. // checks that it equals the original string.
void ExpectEqual() { void ExpectEqual() {
std::unique_ptr<HloModule> module;
const string& original = GetParam().module_string; const string& original = GetParam().module_string;
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, if (GetParam().enable_verification) {
ParseAndReturnUnverifiedModule(original)); auto verified_module = absl::make_unique<VerifiedHloModule>(
GetParam().test_name, HloModuleConfig(),
/*verifier_layout_sensitive=*/false,
/*allow_mixed_precision_in_hlo_verifier=*/true,
ShapeUtil::ByteSizeOfElements);
TF_ASSERT_OK(ParseHloString(original, verified_module.get()));
TF_ASSERT_OK(verified_module->Verify());
module = std::move(verified_module);
} else {
TF_ASSERT_OK_AND_ASSIGN(module, ParseAndReturnUnverifiedModule(original));
}
if (proto_round_trip) { if (proto_round_trip) {
TF_ASSERT_OK_AND_ASSIGN(module, HloModule::CreateFromProto( TF_ASSERT_OK_AND_ASSIGN(module, HloModule::CreateFromProto(
module->ToProto(), module->config())); module->ToProto(), module->config()));
@ -1738,6 +1760,18 @@ class HloParserTest : public ::testing::Test {
EXPECT_TRUE(absl::StrContains(s, expected)) EXPECT_TRUE(absl::StrContains(s, expected))
<< "'" << s << "' does not contain '" << expected << "'"; << "'" << s << "' does not contain '" << expected << "'";
} }
StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule(
absl::string_view hlo_text) {
auto module = absl::make_unique<VerifiedHloModule>(
::testing::UnitTest::GetInstance()->current_test_info()->name(),
HloModuleConfig(),
/*verifier_layout_sensitive=*/false,
/*allow_mixed_precision_in_hlo_verifier=*/true,
ShapeUtil::ByteSizeOfElements);
TF_RETURN_IF_ERROR(ParseHloString(hlo_text, module.get()));
TF_RETURN_IF_ERROR(module->Verify());
return std::move(module);
}
}; };
TEST_F(HloParserTest, Empty) { TEST_F(HloParserTest, Empty) {
@ -1813,7 +1847,7 @@ ENTRY %SelectScalarS32True.v4 () -> s32[] {
} }
)"; )";
auto result = ParseAndReturnUnverifiedModule(original); auto result = ParseAndReturnVerifiedModule(original);
TF_EXPECT_OK(result.status()); TF_EXPECT_OK(result.status());
// Constant instructions have no name. The string will be parsed successfully // Constant instructions have no name. The string will be parsed successfully
// but the constant names will not be exactly the same. // but the constant names will not be exactly the same.
@ -1824,7 +1858,7 @@ TEST_F(HloParserTest, ConfigurationField) {
ENTRY %configuration_test() -> s32[] { ENTRY %configuration_test() -> s32[] {
%constant = s32[] constant(42), backend_config="foo bar" %constant = s32[] constant(42), backend_config="foo bar"
})"; })";
auto result = ParseAndReturnUnverifiedModule(original); auto result = ParseAndReturnVerifiedModule(original);
TF_ASSERT_OK(result.status()); TF_ASSERT_OK(result.status());
EXPECT_EQ("foo bar", result.ValueOrDie() EXPECT_EQ("foo bar", result.ValueOrDie()
->entry_computation() ->entry_computation()
@ -1896,7 +1930,7 @@ TEST_F(HloParserTest, ConstantBf16NoOverflow) {
ENTRY test { ENTRY test {
ROOT c = bf16[] constant(-65505) ROOT c = bf16[] constant(-65505)
})"; })";
EXPECT_EQ(Status::OK(), ParseAndReturnUnverifiedModule(original).status()); EXPECT_EQ(Status::OK(), ParseAndReturnVerifiedModule(original).status());
} }
TEST_F(HloParserTest, ConstantBf16Overflow) { TEST_F(HloParserTest, ConstantBf16Overflow) {
@ -2004,7 +2038,7 @@ ENTRY %ConstantWithExp.v4 () -> f32[] {
} }
)"; )";
auto result = ParseAndReturnUnverifiedModule(original); auto result = ParseAndReturnVerifiedModule(original);
TF_EXPECT_OK(result.status()); TF_EXPECT_OK(result.status());
// The string will be parsed successfully but the output strings are not // The string will be parsed successfully but the output strings are not
// exactly the same, because "3e2" is parsed into value 300 and will be // exactly the same, because "3e2" is parsed into value 300 and will be
@ -2012,14 +2046,14 @@ ENTRY %ConstantWithExp.v4 () -> f32[] {
} }
TEST_F(HloParserTest, ShortConstant) { TEST_F(HloParserTest, ShortConstant) {
const string original = R"(HloModule ShortCOnstant_module const string original = R"(HloModule ShortConstant_module
ENTRY %ShortConstant.v4 () -> f32[67,89] { ENTRY %ShortConstant.v4 () -> f32[67,89] {
ROOT %constant.1 = f32[67,89]{1,0} constant({...}) ROOT %constant.1 = f32[67,89]{1,0} constant({...})
} }
)"; )";
auto result = ParseAndReturnUnverifiedModule(original); auto result = ParseAndReturnVerifiedModule(original);
TF_EXPECT_OK(result.status()); TF_EXPECT_OK(result.status());
EXPECT_EQ(result.ValueOrDie()->ToString(HloPrintOptions()), original); EXPECT_EQ(result.ValueOrDie()->ToString(HloPrintOptions()), original);
} }
@ -2027,15 +2061,15 @@ ENTRY %ShortConstant.v4 () -> f32[67,89] {
TEST_F(HloParserTest, AttibutesAnyOrder) { TEST_F(HloParserTest, AttibutesAnyOrder) {
const string original = R"(HloModule any_order_module const string original = R"(HloModule any_order_module
ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] { ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,4,1] {
%input = f32[1,2,1]{2,1,0} parameter(0) %input = f32[1,2,1]{2,1,0} parameter(0)
%copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input) %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
%filter = f32[1,1,1]{2,1,0} parameter(1) %filter = f32[1,1,1]{2,1,0} parameter(1)
ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), feature_group_count=1, sharding={maximal device=1}, backend_config="foo", dim_labels=b0f_0io->b0f, window={pad=1_1 size=2} ROOT %convolution = f32[1,4,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), feature_group_count=1, sharding={maximal device=1}, backend_config="foo", dim_labels=b0f_0io->b0f, window={pad=1_1 size=1}
} }
)"; )";
TF_EXPECT_OK(ParseAndReturnUnverifiedModule(original).status()); TF_EXPECT_OK(ParseAndReturnVerifiedModule(original).status());
} }
TEST_F(HloParserTest, InvalidDimLabels) { TEST_F(HloParserTest, InvalidDimLabels) {
@ -2127,7 +2161,7 @@ ENTRY %slice.v2 (p0: f32[3,3,4,4]) -> f32[3,3,2,4] {
} }
)"; )";
TF_EXPECT_OK(ParseAndReturnUnverifiedModule(original).status()); TF_EXPECT_OK(ParseAndReturnVerifiedModule(original).status());
} }
TEST_F(HloParserTest, PaddingConfigIsNotWindowPad) { TEST_F(HloParserTest, PaddingConfigIsNotWindowPad) {
@ -2154,7 +2188,7 @@ ENTRY %test_comma.v4 () -> f32[] {
} }
)"; )";
TF_EXPECT_OK(ParseAndReturnUnverifiedModule(original).status()); TF_EXPECT_OK(ParseAndReturnVerifiedModule(original).status());
} }
TEST_F(HloParserTest, ComputationShapeDoesNotMatchRootShape) { TEST_F(HloParserTest, ComputationShapeDoesNotMatchRootShape) {
@ -2184,7 +2218,7 @@ ENTRY %Reduce (input: f32[8,16,256]) -> f32[8,16] {
ROOT reduce = f32[8,16]{0,1} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3 ROOT reduce = f32[8,16]{0,1} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3
})"; })";
auto module = ParseAndReturnUnverifiedModule(original); auto module = ParseAndReturnVerifiedModule(original);
TF_ASSERT_OK(module.status()); TF_ASSERT_OK(module.status());
auto program_layout = module.ValueOrDie()->entry_computation_layout(); auto program_layout = module.ValueOrDie()->entry_computation_layout();
ASSERT_EQ(program_layout.parameter_count(), 1); ASSERT_EQ(program_layout.parameter_count(), 1);
@ -2207,7 +2241,7 @@ c1 {
c2 { c2 {
const2 = f32[1]{0} constant({67890}) const2 = f32[1]{0} constant({67890})
})"; })";
auto module = ParseAndReturnUnverifiedModule(original); auto module = ParseAndReturnVerifiedModule(original);
TF_ASSERT_OK(module.status()); TF_ASSERT_OK(module.status());
EXPECT_EQ(module.ValueOrDie()->entry_computation()->name(), "c2"); EXPECT_EQ(module.ValueOrDie()->entry_computation()->name(), "c2");
} }
@ -2218,7 +2252,7 @@ ENTRY consts {
first = f32[1]{0} constant({12345}) first = f32[1]{0} constant({12345})
last = f32[1]{0} constant({67890}) last = f32[1]{0} constant({67890})
})"; })";
auto module = ParseAndReturnUnverifiedModule(original); auto module = ParseAndReturnVerifiedModule(original);
TF_ASSERT_OK(module.status()); TF_ASSERT_OK(module.status());
EXPECT_EQ( EXPECT_EQ(
module.ValueOrDie()->entry_computation()->root_instruction()->name(), module.ValueOrDie()->entry_computation()->root_instruction()->name(),
@ -2238,7 +2272,7 @@ ENTRY /*comment*/ c1 {
/* something else */ /* something else */
)"; )";
auto module = ParseAndReturnUnverifiedModule(original); auto module = ParseAndReturnVerifiedModule(original);
TF_ASSERT_OK(module.status()); TF_ASSERT_OK(module.status());
} }
@ -2257,7 +2291,7 @@ d
*/ */
})"; })";
auto module = ParseAndReturnUnverifiedModule(original); auto module = ParseAndReturnVerifiedModule(original);
TF_ASSERT_OK(module.status()); TF_ASSERT_OK(module.status());
} }
@ -2281,7 +2315,7 @@ ENTRY c1 {
// Foo bar // Foo bar
ROOT const1 = f32[1]{0} constant({12345}) // Something else ROOT const1 = f32[1]{0} constant({12345}) // Something else
})"; })";
auto module = ParseAndReturnUnverifiedModule(original); auto module = ParseAndReturnVerifiedModule(original);
TF_ASSERT_OK(module.status()); TF_ASSERT_OK(module.status());
} }
@ -2289,7 +2323,7 @@ TEST_F(HloParserTest, SlashSlashCommentMsDosEolFormat) {
const string original = const string original =
"HloModule slash_slash_comment:\r\n// Garbage\r\nENTRY c1 {\r\n// Foo " "HloModule slash_slash_comment:\r\n// Garbage\r\nENTRY c1 {\r\n// Foo "
"bar\r\nROOT const1 = f32[1]{0} constant({12345}) // Something else\r\n}"; "bar\r\nROOT const1 = f32[1]{0} constant({12345}) // Something else\r\n}";
auto module = ParseAndReturnUnverifiedModule(original); auto module = ParseAndReturnVerifiedModule(original);
TF_ASSERT_OK(module.status()); TF_ASSERT_OK(module.status());
} }
@ -2297,7 +2331,7 @@ TEST_F(HloParserTest, SlashSlashCommentMacEolFormat) {
const string original = const string original =
"HloModule slash_slash_comment:\r// Garbage\rENTRY c1 {\r// Foo " "HloModule slash_slash_comment:\r// Garbage\rENTRY c1 {\r// Foo "
"bar\rROOT const1 = f32[1]{0} constant({12345}) // Something else\r}"; "bar\rROOT const1 = f32[1]{0} constant({12345}) // Something else\r}";
auto module = ParseAndReturnUnverifiedModule(original); auto module = ParseAndReturnVerifiedModule(original);
TF_ASSERT_OK(module.status()); TF_ASSERT_OK(module.status());
} }
@ -2375,8 +2409,7 @@ ENTRY ReduceR3ToR2 {
ROOT result = f32[8,16]{1,0} reduce(p0, p1), dimensions={2}, to_apply=add ROOT result = f32[8,16]{1,0} reduce(p0, p1), dimensions={2}, to_apply=add
} }
)"; )";
TF_ASSERT_OK_AND_ASSIGN(auto module, TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(original));
ParseAndReturnUnverifiedModule(original));
ASSERT_NE(module->entry_computation(), nullptr); ASSERT_NE(module->entry_computation(), nullptr);
EXPECT_THAT(module->entry_computation()->root_instruction(), EXPECT_THAT(module->entry_computation()->root_instruction(),
GmockMatch(m::Reduce())); GmockMatch(m::Reduce()));
@ -2649,8 +2682,7 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
} }
)"; )";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
ParseAndReturnUnverifiedModule(text));
ASSERT_FALSE(module->has_schedule()); ASSERT_FALSE(module->has_schedule());
} }
@ -2667,8 +2699,7 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
} }
)"; )";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
ParseAndReturnUnverifiedModule(text));
ASSERT_FALSE(module->has_schedule()); ASSERT_FALSE(module->has_schedule());
} }
@ -2685,8 +2716,7 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
} }
)"; )";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
ParseAndReturnUnverifiedModule(text));
ASSERT_TRUE(module->has_schedule()); ASSERT_TRUE(module->has_schedule());
TF_ASSERT_OK(module->schedule().Verify()); TF_ASSERT_OK(module->schedule().Verify());
EXPECT_EQ(module->schedule().sequences().size(), 1); EXPECT_EQ(module->schedule().sequences().size(), 1);
@ -2714,8 +2744,7 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
} }
)"; )";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
ParseAndReturnUnverifiedModule(text));
ASSERT_TRUE(module->has_schedule()); ASSERT_TRUE(module->has_schedule());
TF_ASSERT_OK(module->schedule().Verify()); TF_ASSERT_OK(module->schedule().Verify());
EXPECT_EQ(module->schedule().sequences().size(), 1); EXPECT_EQ(module->schedule().sequences().size(), 1);
@ -2767,8 +2796,7 @@ ENTRY entry {
ROOT root = f32[ 1, 2,3, 4, 5]{0, 1, 2,3, 4 } parameter(0) ROOT root = f32[ 1, 2,3, 4, 5]{0, 1, 2,3, 4 } parameter(0)
} }
)"; )";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
ParseAndReturnUnverifiedModule(text));
} }
TEST_F(HloParserTest, ShapeMismatchInOperand) { TEST_F(HloParserTest, ShapeMismatchInOperand) {