Use VerifiedHloModule in tests that already have valid HLO.
PiperOrigin-RevId: 275825992 Change-Id: I4b2d6e5d565f763285bd1e9b16976a6a6db0354f
This commit is contained in:
parent
2bfc99464c
commit
c5bb3d5baf
@ -532,10 +532,10 @@ tf_cc_test(
|
||||
srcs = ["pattern_matcher_test.cc"],
|
||||
deps = [
|
||||
":hlo",
|
||||
":hlo_parser",
|
||||
":pattern_matcher",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:test",
|
||||
"@com_google_absl//absl/strings",
|
||||
@ -588,7 +588,6 @@ tf_cc_test(
|
||||
srcs = ["hlo_reachability_test.cc"],
|
||||
deps = [
|
||||
":hlo",
|
||||
":hlo_parser",
|
||||
":hlo_reachability",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla:test_helpers",
|
||||
@ -1424,7 +1423,6 @@ tf_cc_test(
|
||||
":hlo_dce",
|
||||
":hlo_memory_scheduler",
|
||||
":hlo_ordering",
|
||||
":hlo_parser",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
@ -2672,7 +2670,6 @@ tf_cc_test(
|
||||
srcs = ["hlo_replication_analysis_test.cc"],
|
||||
deps = [
|
||||
":hlo",
|
||||
":hlo_parser",
|
||||
":hlo_replication_analysis",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
@ -3913,10 +3910,14 @@ tf_cc_test(
|
||||
srcs = ["tuple_util_test.cc"],
|
||||
deps = [
|
||||
":hlo_matchers",
|
||||
":hlo_module_config",
|
||||
":hlo_parser",
|
||||
":tuple_util",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla/tests:verified_hlo_module",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"@com_google_absl//absl/memory",
|
||||
],
|
||||
)
|
||||
|
||||
@ -3946,6 +3947,7 @@ tf_cc_test(
|
||||
":while_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
],
|
||||
@ -4004,7 +4006,6 @@ tf_cc_test(
|
||||
srcs = ["while_loop_constant_sinking_test.cc"],
|
||||
deps = [
|
||||
":hlo_matchers",
|
||||
":hlo_parser",
|
||||
":while_loop_constant_sinking",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
|
@ -213,11 +213,10 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
|
||||
ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(text));
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
|
||||
ASSERT_FALSE(module->has_schedule());
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
std::unique_ptr<HloModule> module_copy,
|
||||
auto module_copy,
|
||||
HloModule::CreateFromProto(module->ToProto(), module->config()));
|
||||
ASSERT_FALSE(module_copy->has_schedule());
|
||||
}
|
||||
@ -235,11 +234,10 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
|
||||
ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(text));
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
|
||||
ASSERT_TRUE(module->has_schedule());
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
std::unique_ptr<HloModule> module_copy,
|
||||
auto module_copy,
|
||||
HloModule::CreateFromProto(module->ToProto(), module->config()));
|
||||
ASSERT_TRUE(module_copy->has_schedule());
|
||||
TF_ASSERT_OK(module_copy->schedule().Verify());
|
||||
@ -272,8 +270,7 @@ ENTRY ReduceR3ToR2.v3 {
|
||||
ROOT reduce = f32[8,16]{1,0} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(text));
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
|
||||
|
||||
// Perform various transformations on the graph:
|
||||
//
|
||||
@ -306,7 +303,7 @@ ENTRY ReduceR3ToR2.v3 {
|
||||
// Serialize and deserialize and verify that the instruction and computations
|
||||
// unique ids are the same.
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
std::unique_ptr<HloModule> module_copy,
|
||||
auto module_copy,
|
||||
HloModule::CreateFromProto(module->ToProto(), module->config()));
|
||||
|
||||
// The module IDs should *not* be the same because module ids must be globally
|
||||
@ -366,8 +363,7 @@ TEST_F(HloModuleTest, VerifyReplaceComputationsWithSortOp) {
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(text));
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
|
||||
|
||||
// Create a replacement computation
|
||||
HloComputation* new_comp;
|
||||
|
@ -18,7 +18,6 @@ limitations under the License.
|
||||
#include <set>
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_parser.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/test_helpers.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
@ -208,7 +207,7 @@ TEST_F(HloReachabilityTest, ChannelReachability) {
|
||||
}
|
||||
|
||||
TEST_F(HloReachabilityTest, ReplaceInstructions) {
|
||||
auto module = ParseAndReturnUnverifiedModule(R"(
|
||||
auto module = ParseAndReturnVerifiedModule(R"(
|
||||
HloModule test
|
||||
|
||||
ENTRY entry {
|
||||
|
@ -22,7 +22,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_parser.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
@ -57,8 +56,8 @@ ENTRY entry {
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnUnverifiedModule(module_str));
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(module_str));
|
||||
auto param = module->entry_computation()->parameter_instruction(0);
|
||||
param->set_parameter_replicated_at_leaf_buffers(
|
||||
absl::Span<const bool>{false, true});
|
||||
@ -106,8 +105,8 @@ ENTRY entry {
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnUnverifiedModule(module_str));
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(module_str));
|
||||
auto param = module->entry_computation()->parameter_instruction(0);
|
||||
param->set_parameter_replicated_at_leaf_buffers(
|
||||
absl::Span<const bool>{true, false});
|
||||
@ -158,8 +157,8 @@ ENTRY SimpleWhileLoop {
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnUnverifiedModule(module_str));
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(module_str));
|
||||
auto param = module->entry_computation()->parameter_instruction(0);
|
||||
param->set_parameter_replicated_at_leaf_buffers(
|
||||
absl::Span<const bool>{true, true});
|
||||
@ -207,8 +206,8 @@ ENTRY WhileLoopParameterAliasingNonReplicatedOutput {
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnUnverifiedModule(module_str));
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(module_str));
|
||||
auto param = module->entry_computation()->parameter_instruction(0);
|
||||
param->set_parameter_replicated_at_leaf_buffers(
|
||||
absl::Span<const bool>{true, true});
|
||||
@ -253,8 +252,8 @@ ENTRY WhileLoopDifferentCondition {
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnUnverifiedModule(module_str));
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(module_str));
|
||||
auto param = module->entry_computation()->parameter_instruction(0);
|
||||
param->set_parameter_replicated_at_leaf_buffers(
|
||||
absl::Span<const bool>{true, true});
|
||||
@ -302,8 +301,8 @@ ENTRY entry {
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnUnverifiedModule(module_str));
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(module_str));
|
||||
auto param = module->entry_computation()->parameter_instruction(0);
|
||||
param->set_parameter_replicated_at_leaf_buffers(
|
||||
absl::Span<const bool>{true, true, true, true, false, true, true});
|
||||
@ -366,8 +365,8 @@ ENTRY entry {
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnUnverifiedModule(module_str));
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(module_str));
|
||||
auto param = module->entry_computation()->parameter_instruction(0);
|
||||
param->set_parameter_replicated_at_leaf_buffers(
|
||||
absl::Span<const bool>{true, true, true, true, true, true});
|
||||
@ -404,8 +403,8 @@ ENTRY entry {
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnUnverifiedModule(module_str));
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(module_str));
|
||||
auto param = module->entry_computation()->parameter_instruction(0);
|
||||
param->set_parameter_replicated_at_leaf_buffers(
|
||||
absl::Span<const bool>{true, false, true, true, true});
|
||||
@ -430,8 +429,8 @@ ENTRY entry {
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnUnverifiedModule(module_str));
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(module_str));
|
||||
auto param = module->entry_computation()->parameter_instruction(0);
|
||||
param->set_parameter_replicated_at_leaf_buffers(
|
||||
absl::Span<const bool>{true, true, true, true, false});
|
||||
|
@ -25,7 +25,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_parser.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
@ -53,7 +52,7 @@ ENTRY main {
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnUnverifiedModule(module_str));
|
||||
ParseAndReturnVerifiedModule(module_str));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
HloSchedule schedule,
|
||||
ScheduleModule(module.get(), [](const BufferValue& buffer) {
|
||||
@ -87,7 +86,7 @@ ENTRY main {
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnUnverifiedModule(module_str));
|
||||
ParseAndReturnVerifiedModule(module_str));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
HloSchedule schedule,
|
||||
ScheduleModule(module.get(), [](const BufferValue& buffer) {
|
||||
@ -136,7 +135,7 @@ ENTRY main {
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnUnverifiedModule(module_str));
|
||||
ParseAndReturnVerifiedModule(module_str));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
HloSchedule schedule,
|
||||
ScheduleModule(module.get(), [](const BufferValue& buffer) {
|
||||
@ -180,7 +179,7 @@ ENTRY main {
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnUnverifiedModule(module_str));
|
||||
ParseAndReturnVerifiedModule(module_str));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
HloSchedule schedule,
|
||||
ScheduleModule(module.get(), [](const BufferValue& buffer) {
|
||||
@ -241,7 +240,7 @@ ENTRY %WhileLoop () -> s32[] {
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnUnverifiedModule(module_str));
|
||||
ParseAndReturnVerifiedModule(module_str));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
HloSchedule schedule,
|
||||
ScheduleModule(module.get(), [](const BufferValue& buffer) {
|
||||
@ -310,7 +309,7 @@ ENTRY %WhileLoop () -> s32[] {
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnUnverifiedModule(module_str));
|
||||
ParseAndReturnVerifiedModule(module_str));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
HloSchedule schedule,
|
||||
ScheduleModule(module.get(), [](const BufferValue& buffer) {
|
||||
|
@ -47,7 +47,7 @@ class InstructionFusionForTesting : public InstructionFusion {
|
||||
};
|
||||
|
||||
TEST_F(InstructionFusionTest, FuseInstructions) {
|
||||
auto module = ParseAndReturnUnverifiedModule(R"(
|
||||
auto module = ParseAndReturnVerifiedModule(R"(
|
||||
HloModule test_module
|
||||
ENTRY entry_computation {
|
||||
p0 = f32[4,3]{1,0} parameter(0)
|
||||
@ -67,7 +67,7 @@ TEST_F(InstructionFusionTest, FuseInstructions) {
|
||||
}
|
||||
|
||||
TEST_F(InstructionFusionTest, FuseIntoFusionInstruction) {
|
||||
auto module = ParseAndReturnUnverifiedModule(R"(
|
||||
auto module = ParseAndReturnVerifiedModule(R"(
|
||||
HloModule test_module
|
||||
fused_computation {
|
||||
p1 = f32[4,3] parameter(0)
|
||||
@ -90,7 +90,7 @@ TEST_F(InstructionFusionTest, FuseIntoFusionInstruction) {
|
||||
}
|
||||
|
||||
TEST_F(InstructionFusionTest, FuseInstructionsIntoMultiOutput) {
|
||||
auto module = ParseAndReturnUnverifiedModule(R"(
|
||||
auto module = ParseAndReturnVerifiedModule(R"(
|
||||
HloModule test_module
|
||||
ENTRY entry_computation {
|
||||
p0 = f32[4,3]{1,0} parameter(0)
|
||||
@ -196,7 +196,7 @@ static int Count(const HloModule& module, HloOpcode op) {
|
||||
}
|
||||
|
||||
TEST_F(InstructionFusionTest, FuseCheapNonDuplicatableOps) {
|
||||
auto module = ParseAndReturnUnverifiedModule(R"(
|
||||
auto module = ParseAndReturnVerifiedModule(R"(
|
||||
HloModule test_module
|
||||
ENTRY OutputFusion {
|
||||
p0 = f32[4,3]{1,0} parameter(0)
|
||||
@ -433,7 +433,7 @@ TEST_F(InstructionFusionTest, AllowBinarySameValueOperandsDuplication) {
|
||||
}
|
||||
|
||||
TEST_F(InstructionFusionTest, FuseDiamondGraphsNoDuplication) {
|
||||
auto module = ParseAndReturnUnverifiedModule(R"(
|
||||
auto module = ParseAndReturnVerifiedModule(R"(
|
||||
HloModule test_module
|
||||
ENTRY Test {
|
||||
p0 = f32[100] parameter(0)
|
||||
@ -458,7 +458,7 @@ TEST_F(InstructionFusionTest, FuseDiamondGraphsNoDuplication) {
|
||||
}
|
||||
|
||||
TEST_F(InstructionFusionTest, FuseDiamondGraphsAllowDuplication) {
|
||||
auto module = ParseAndReturnUnverifiedModule(R"(
|
||||
auto module = ParseAndReturnVerifiedModule(R"(
|
||||
HloModule test_module
|
||||
ENTRY Test {
|
||||
p0 = f32[100] parameter(0)
|
||||
@ -484,7 +484,7 @@ TEST_F(InstructionFusionTest, FuseDiamondGraphsAllowDuplication) {
|
||||
|
||||
TEST_F(InstructionFusionTest,
|
||||
WideningConvertsAreAlwaysDuplicableIntoConsumers) {
|
||||
auto module = ParseAndReturnUnverifiedModule(R"(
|
||||
auto module = ParseAndReturnVerifiedModule(R"(
|
||||
HloModule test_module
|
||||
ENTRY Test {
|
||||
p0 = f16[100] parameter(0)
|
||||
|
@ -14,19 +14,21 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_parser.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
namespace m = match;
|
||||
using PatternMatcherTest = HloTestBase;
|
||||
|
||||
TEST(PatternMatcherTest, AddOp) {
|
||||
TEST_F(PatternMatcherTest, AddOp) {
|
||||
constexpr char kModuleStr[] = R"(HloModule two_plus_two_module
|
||||
ENTRY %two_plus_two_computation () -> f32[] {
|
||||
%two = f32[] constant(2)
|
||||
@ -34,7 +36,7 @@ TEST(PatternMatcherTest, AddOp) {
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
|
||||
ParseAndReturnUnverifiedModule(kModuleStr));
|
||||
ParseAndReturnVerifiedModule(kModuleStr));
|
||||
|
||||
const HloInstruction* matched_inst;
|
||||
HloInstruction* matched_operand;
|
||||
@ -66,7 +68,7 @@ TEST(PatternMatcherTest, AddOp) {
|
||||
match::Multiply(&matched_inst, match::Op(), match::Op())));
|
||||
}
|
||||
|
||||
TEST(PatternMatcherTest, ScalarShape) {
|
||||
TEST_F(PatternMatcherTest, ScalarShape) {
|
||||
auto scalar_shape = ShapeUtil::MakeShape(F32, {});
|
||||
Shape* matched_shape;
|
||||
EXPECT_TRUE(Match(&scalar_shape, match::Shape(&matched_shape).IsScalar()));
|
||||
@ -81,7 +83,7 @@ TEST(PatternMatcherTest, ScalarShape) {
|
||||
match::Shape().WithSubshape({0}, match::Shape()).WithElementType(F32)));
|
||||
}
|
||||
|
||||
TEST(PatternMatcherTest, DenseArrayShape) {
|
||||
TEST_F(PatternMatcherTest, DenseArrayShape) {
|
||||
auto array_shape = ShapeUtil::MakeShape(F32, {2, 3, 4});
|
||||
Shape* matched_shape;
|
||||
EXPECT_TRUE(Match(&array_shape, match::Shape(&matched_shape).IsArray()));
|
||||
@ -104,7 +106,7 @@ TEST(PatternMatcherTest, DenseArrayShape) {
|
||||
EXPECT_EQ(matched_layout, &array_shape.layout());
|
||||
}
|
||||
|
||||
TEST(PatternMatcherTest, SparseArrayShape) {
|
||||
TEST_F(PatternMatcherTest, SparseArrayShape) {
|
||||
auto array_shape = ShapeUtil::MakeShapeWithSparseLayout(F32, {2, 3, 4}, 10);
|
||||
Shape* matched_shape;
|
||||
EXPECT_TRUE(Match(&array_shape, match::Shape(&matched_shape).IsArray()));
|
||||
@ -127,7 +129,7 @@ TEST(PatternMatcherTest, SparseArrayShape) {
|
||||
EXPECT_EQ(matched_layout, &array_shape.layout());
|
||||
}
|
||||
|
||||
TEST(PatternMatcherTest, TupleShape) {
|
||||
TEST_F(PatternMatcherTest, TupleShape) {
|
||||
auto tuple_shape = ShapeUtil::MakeTupleShape({
|
||||
ShapeUtil::MakeShape(F32, {1, 2, 3}),
|
||||
ShapeUtil::MakeShape(S32, {4, 5}),
|
||||
@ -175,7 +177,7 @@ TEST(PatternMatcherTest, TupleShape) {
|
||||
Match(&tuple_shape, match::Shape().WithSubshape({0, 0}, match::Shape())));
|
||||
}
|
||||
|
||||
TEST(PatternMatcherTest, FusionKind) {
|
||||
TEST_F(PatternMatcherTest, FusionKind) {
|
||||
constexpr char kModuleStr[] = R"(
|
||||
HloModule test_module
|
||||
|
||||
@ -188,7 +190,7 @@ TEST(PatternMatcherTest, FusionKind) {
|
||||
ROOT fusion = f32[] fusion(p0), kind=kLoop, calls=fused_computation
|
||||
})";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
|
||||
ParseAndReturnUnverifiedModule(kModuleStr));
|
||||
ParseAndReturnVerifiedModule(kModuleStr));
|
||||
|
||||
auto* root = hlo_module->entry_computation()->root_instruction();
|
||||
EXPECT_TRUE(Match(
|
||||
@ -199,7 +201,7 @@ TEST(PatternMatcherTest, FusionKind) {
|
||||
HloInstruction::FusionKind::kLoop)));
|
||||
}
|
||||
|
||||
TEST(PatternMatcherTest, GetTupleElement) {
|
||||
TEST_F(PatternMatcherTest, GetTupleElement) {
|
||||
constexpr char kModuleStr[] = R"(
|
||||
HloModule test_module
|
||||
|
||||
@ -208,7 +210,7 @@ TEST(PatternMatcherTest, GetTupleElement) {
|
||||
ROOT gte = f32[] get-tuple-element(p0), index=1
|
||||
})";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
|
||||
ParseAndReturnUnverifiedModule(kModuleStr));
|
||||
ParseAndReturnVerifiedModule(kModuleStr));
|
||||
|
||||
auto* root = hlo_module->entry_computation()->root_instruction();
|
||||
EXPECT_FALSE(Match(root, match::Op().WithTupleIndex(0)));
|
||||
@ -218,11 +220,11 @@ TEST(PatternMatcherTest, GetTupleElement) {
|
||||
EXPECT_TRUE(Match(root, match::GetTupleElement(match::Op(), 1)));
|
||||
}
|
||||
|
||||
TEST(PatternMatcherTest, AnyOf) {
|
||||
TEST_F(PatternMatcherTest, AnyOf) {
|
||||
constexpr char kModuleStr[] = R"(
|
||||
HloModule test_module ENTRY test { ROOT constant = f16[] constant(1) })";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
|
||||
ParseAndReturnUnverifiedModule(kModuleStr));
|
||||
ParseAndReturnVerifiedModule(kModuleStr));
|
||||
auto* root = hlo_module->entry_computation()->root_instruction();
|
||||
|
||||
EXPECT_TRUE(
|
||||
@ -236,7 +238,7 @@ TEST(PatternMatcherTest, AnyOf) {
|
||||
match::ConstantScalar(2))));
|
||||
}
|
||||
|
||||
TEST(PatternMatcherTest, ConstantScalar) {
|
||||
TEST_F(PatternMatcherTest, ConstantScalar) {
|
||||
using match::ConstantEffectiveScalar;
|
||||
using match::ConstantScalar;
|
||||
using match::Op;
|
||||
@ -253,7 +255,7 @@ TEST(PatternMatcherTest, ConstantScalar) {
|
||||
ROOT tuple = (s32[], s32[1,1], s32[1,2], f32[], f32[]) tuple(a,b,c,d,e)
|
||||
})";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
|
||||
ParseAndReturnUnverifiedModule(kModuleStr));
|
||||
ParseAndReturnVerifiedModule(kModuleStr));
|
||||
auto* root = hlo_module->entry_computation()->root_instruction();
|
||||
|
||||
const HloInstruction* a = root->operand(0);
|
||||
@ -308,7 +310,7 @@ TEST(PatternMatcherTest, ConstantScalar) {
|
||||
EXPECT_EQ(instr, a);
|
||||
}
|
||||
|
||||
TEST(PatternMatcherTest, MultiplyAnyOrder) {
|
||||
TEST_F(PatternMatcherTest, MultiplyAnyOrder) {
|
||||
using match::ConstantScalar;
|
||||
using match::MultiplyAnyOrder;
|
||||
|
||||
@ -320,7 +322,7 @@ TEST(PatternMatcherTest, MultiplyAnyOrder) {
|
||||
ROOT multiply = f16[] multiply(lhs, rhs)
|
||||
})";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
|
||||
ParseAndReturnUnverifiedModule(kModuleStr));
|
||||
ParseAndReturnVerifiedModule(kModuleStr));
|
||||
auto* root = hlo_module->entry_computation()->root_instruction();
|
||||
const HloInstruction* instr;
|
||||
|
||||
@ -339,7 +341,7 @@ TEST(PatternMatcherTest, MultiplyAnyOrder) {
|
||||
.IsNonConstant()));
|
||||
}
|
||||
|
||||
TEST(PatternMatcherTest, AnyOfShortCircuit) {
|
||||
TEST_F(PatternMatcherTest, AnyOfShortCircuit) {
|
||||
using match::AnyOf;
|
||||
using match::Multiply;
|
||||
using match::Op;
|
||||
@ -352,7 +354,7 @@ TEST(PatternMatcherTest, AnyOfShortCircuit) {
|
||||
ROOT multiply = f16[] multiply(lhs, rhs)
|
||||
})";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
|
||||
ParseAndReturnUnverifiedModule(kModuleStr));
|
||||
ParseAndReturnVerifiedModule(kModuleStr));
|
||||
auto* root = hlo_module->entry_computation()->root_instruction();
|
||||
|
||||
{
|
||||
@ -375,7 +377,7 @@ TEST(PatternMatcherTest, AnyOfShortCircuit) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(PatternMatcherTest, AllOf) {
|
||||
TEST_F(PatternMatcherTest, AllOf) {
|
||||
using match::AllOf;
|
||||
using match::Broadcast;
|
||||
using match::Constant;
|
||||
@ -384,7 +386,7 @@ TEST(PatternMatcherTest, AllOf) {
|
||||
constexpr char kModuleStr[] = R"(
|
||||
HloModule test_module ENTRY test { ROOT constant = f16[] constant(1) })";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
|
||||
ParseAndReturnUnverifiedModule(kModuleStr));
|
||||
ParseAndReturnVerifiedModule(kModuleStr));
|
||||
auto* root = hlo_module->entry_computation()->root_instruction();
|
||||
|
||||
auto f16_scalar = ShapeUtil::MakeShape(F16, {});
|
||||
@ -407,7 +409,7 @@ TEST(PatternMatcherTest, AllOf) {
|
||||
Match(root, AllOf<HloInstruction>(Broadcast(Op()), scalar_pattern)));
|
||||
}
|
||||
|
||||
TEST(PatternMatcherTest, AllOfNoCaptureIfNotMatch) {
|
||||
TEST_F(PatternMatcherTest, AllOfNoCaptureIfNotMatch) {
|
||||
using match::AllOf;
|
||||
using match::Broadcast;
|
||||
using match::Constant;
|
||||
@ -419,7 +421,7 @@ TEST(PatternMatcherTest, AllOfNoCaptureIfNotMatch) {
|
||||
ROOT v = f16[] constant(42)
|
||||
})";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
|
||||
ParseAndReturnUnverifiedModule(kModuleStr));
|
||||
ParseAndReturnVerifiedModule(kModuleStr));
|
||||
auto* root = hlo_module->entry_computation()->root_instruction();
|
||||
|
||||
const HloInstruction* constant = nullptr;
|
||||
@ -430,7 +432,7 @@ TEST(PatternMatcherTest, AllOfNoCaptureIfNotMatch) {
|
||||
EXPECT_NE(nullptr, constant);
|
||||
}
|
||||
|
||||
TEST(PatternMatcherTest, TestNoCapture) {
|
||||
TEST_F(PatternMatcherTest, TestNoCapture) {
|
||||
using match::Constant;
|
||||
|
||||
constexpr char kModuleStr[] = R"(
|
||||
@ -439,7 +441,7 @@ TEST(PatternMatcherTest, TestNoCapture) {
|
||||
ROOT v = f16[] constant(42)
|
||||
})";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
|
||||
ParseAndReturnUnverifiedModule(kModuleStr));
|
||||
ParseAndReturnVerifiedModule(kModuleStr));
|
||||
auto* root = hlo_module->entry_computation()->root_instruction();
|
||||
|
||||
const HloInstruction* constant = nullptr;
|
||||
@ -447,7 +449,7 @@ TEST(PatternMatcherTest, TestNoCapture) {
|
||||
EXPECT_EQ(nullptr, constant);
|
||||
}
|
||||
|
||||
TEST(PatternMatcherTest, TestCaptureMatchedSubPatternForAnyOf) {
|
||||
TEST_F(PatternMatcherTest, TestCaptureMatchedSubPatternForAnyOf) {
|
||||
using match::Add;
|
||||
using match::AddAnyOrder;
|
||||
using match::AnyOf;
|
||||
@ -461,7 +463,7 @@ TEST(PatternMatcherTest, TestCaptureMatchedSubPatternForAnyOf) {
|
||||
ROOT add = f16[] add(u, v)
|
||||
})";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
|
||||
ParseAndReturnUnverifiedModule(kModuleStr));
|
||||
ParseAndReturnVerifiedModule(kModuleStr));
|
||||
auto* root = hlo_module->entry_computation()->root_instruction();
|
||||
|
||||
const HloInstruction* addend0 = nullptr;
|
||||
@ -477,7 +479,7 @@ TEST(PatternMatcherTest, TestCaptureMatchedSubPatternForAnyOf) {
|
||||
EXPECT_EQ(nullptr, addend2);
|
||||
}
|
||||
|
||||
TEST(PatternMatcherTest, TestConcat) {
|
||||
TEST_F(PatternMatcherTest, TestConcat) {
|
||||
using match::Concatenate;
|
||||
using match::ConstantScalar;
|
||||
using match::Op;
|
||||
@ -497,7 +499,7 @@ TEST(PatternMatcherTest, TestConcat) {
|
||||
ROOT concat = u32[4] concatenate(r1, r2, r3, r4), dimensions={0}
|
||||
})";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
|
||||
ParseAndReturnUnverifiedModule(kModuleStr));
|
||||
ParseAndReturnVerifiedModule(kModuleStr));
|
||||
auto* root = hlo_module->entry_computation()->root_instruction();
|
||||
ASSERT_TRUE(Match(
|
||||
root,
|
||||
@ -557,7 +559,7 @@ string Explanation(const Elem& elem, const Pattern& pattern) {
|
||||
EXPECT_EQ(Explanation((elem), (pattern)), expected_explanation); \
|
||||
} while (0)
|
||||
|
||||
TEST(PatternMatcherTest, LayoutDescribeToAndExplain) {
|
||||
TEST_F(PatternMatcherTest, LayoutDescribeToAndExplain) {
|
||||
auto layout = LayoutUtil::MakeLayout({1, 2});
|
||||
auto layout2 = LayoutUtil::MakeLayout({2, 2});
|
||||
|
||||
@ -577,7 +579,7 @@ TEST(PatternMatcherTest, LayoutDescribeToAndExplain) {
|
||||
"Layout has format DENSE but expected SPARSE");
|
||||
}
|
||||
|
||||
TEST(PatternMatcherTest, CustomCallTargetMatcherDescribeAndExplain) {
|
||||
TEST_F(PatternMatcherTest, CustomCallTargetMatcherDescribeAndExplain) {
|
||||
constexpr char kModuleStr[] = R"(
|
||||
HloModule test_module
|
||||
|
||||
@ -587,7 +589,7 @@ TEST(PatternMatcherTest, CustomCallTargetMatcherDescribeAndExplain) {
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
|
||||
ParseAndReturnUnverifiedModule(kModuleStr));
|
||||
ParseAndReturnVerifiedModule(kModuleStr));
|
||||
|
||||
auto* root = hlo_module->entry_computation()->root_instruction();
|
||||
EXPECT_TRUE(Match(root, match::Op().WithCustomCallTarget("test_target")));
|
||||
@ -600,7 +602,7 @@ TEST(PatternMatcherTest, CustomCallTargetMatcherDescribeAndExplain) {
|
||||
"out = f32[] custom-call(), custom_call_target=\"test_target\"");
|
||||
}
|
||||
|
||||
TEST(PatternMatcherTest, ShapeDescribeToAndExplain) {
|
||||
TEST_F(PatternMatcherTest, ShapeDescribeToAndExplain) {
|
||||
auto shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 2}, {0, 1});
|
||||
auto layout = shape.layout();
|
||||
|
||||
@ -715,7 +717,7 @@ std::unique_ptr<HloInstruction> SetName(absl::string_view name,
|
||||
return instr;
|
||||
}
|
||||
|
||||
TEST(PatternMatcherTest, HloInstructionDescribeToAndExplain) {
|
||||
TEST_F(PatternMatcherTest, HloInstructionDescribeToAndExplain) {
|
||||
std::unique_ptr<HloInstruction> iota =
|
||||
SetName("i", HloInstruction::CreateIota(ShapeUtil::MakeShape(S32, {42}),
|
||||
/*iota_dimension=*/0));
|
||||
@ -810,7 +812,7 @@ TEST(PatternMatcherTest, HloInstructionDescribeToAndExplain) {
|
||||
"in c = s32[] constant(0)"));
|
||||
}
|
||||
|
||||
TEST(PatternMatcherTest, HloInstructionMatcherAnyOrderDescribeTo) {
|
||||
TEST_F(PatternMatcherTest, HloInstructionMatcherAnyOrderDescribeTo) {
|
||||
auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
|
||||
EXPECT_DESC_AND_EXPLANATION(
|
||||
SetName("a", HloInstruction::CreateBinary(
|
||||
@ -866,7 +868,7 @@ TEST(PatternMatcherTest, HloInstructionMatcherAnyOrderDescribeTo) {
|
||||
"in a = s32[] add(s32[] p, s32[] c)");
|
||||
}
|
||||
|
||||
TEST(PatternMatcherTest, AnyOfMatcherDescribeToAndExplain) {
|
||||
TEST_F(PatternMatcherTest, AnyOfMatcherDescribeToAndExplain) {
|
||||
EXPECT_DESC_AND_EXPLANATION(
|
||||
SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))),
|
||||
m::AnyOf<HloInstruction>(m::Op().WithName("foo"),
|
||||
@ -887,7 +889,7 @@ TEST(PatternMatcherTest, AnyOfMatcherDescribeToAndExplain) {
|
||||
" in c = s32[] constant(0)");
|
||||
}
|
||||
|
||||
TEST(PatternMatcherTest, Parameter) {
|
||||
TEST_F(PatternMatcherTest, Parameter) {
|
||||
auto param =
|
||||
HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {}), "p1");
|
||||
auto non_param =
|
||||
@ -911,7 +913,7 @@ TEST(PatternMatcherTest, Parameter) {
|
||||
"in p0 = f32[] parameter(0)");
|
||||
}
|
||||
|
||||
TEST(PatternMatcherTest, OneUseAndOneUser) {
|
||||
TEST_F(PatternMatcherTest, OneUseAndOneUser) {
|
||||
auto param =
|
||||
HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0");
|
||||
|
||||
@ -966,7 +968,7 @@ TEST(PatternMatcherTest, OneUseAndOneUser) {
|
||||
"in p0 = f32[] parameter(0)");
|
||||
}
|
||||
|
||||
TEST(HloMatchersTest, Comparison) {
|
||||
TEST_F(PatternMatcherTest, Comparison) {
|
||||
auto shape = ShapeUtil::MakeShape(F32, {1});
|
||||
auto p0 = HloInstruction::CreateParameter(0, shape, "param.0");
|
||||
auto p1 = HloInstruction::CreateParameter(1, shape, "param.1");
|
||||
|
@ -68,8 +68,8 @@ ENTRY entry_computation {
|
||||
ROOT dot = f32[2,2]{1,0} dot(x, transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnUnverifiedModule(hlo_string));
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
|
||||
FoldTranspose(module.get());
|
||||
|
||||
@ -90,8 +90,8 @@ ENTRY entry_computation {
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnUnverifiedModule(hlo_string));
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
|
||||
TransposeFolding transpose_folding(
|
||||
[](const HloInstruction& dot,
|
||||
@ -118,7 +118,7 @@ ENTRY entry_computation {
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnUnverifiedModule(hlo_string));
|
||||
|
||||
TransposeFolding transpose_folding(
|
||||
@ -146,8 +146,8 @@ ENTRY entry_computation {
|
||||
ROOT dot = f32[1,3]{1,0} dot(transpose, transpose.1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnUnverifiedModule(hlo_string));
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
|
||||
FoldTranspose(module.get());
|
||||
|
||||
@ -204,8 +204,8 @@ ENTRY entry_computation {
|
||||
ROOT call = f32[2,2]{1,0} call(y, x), to_apply=callee
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnUnverifiedModule(hlo_string));
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
FoldTranspose(module.get());
|
||||
|
||||
const HloComputation* callee = module->GetComputationWithName("callee");
|
||||
|
@ -15,16 +15,22 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/service/tuple_util.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_parser.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/tests/verified_hlo_module.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
namespace op = ::xla::testing::opcode_matchers;
|
||||
|
||||
StatusOr<std::unique_ptr<HloModule>> GetParsedModule(
|
||||
StatusOr<std::unique_ptr<VerifiedHloModule>> GetParsedModule(
|
||||
HloComputation** entry_computation, HloInstruction** param0,
|
||||
HloInstruction** param1) {
|
||||
const char* const hlo_string = R"(
|
||||
@ -36,8 +42,12 @@ ENTRY entry {
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnUnverifiedModule(hlo_string));
|
||||
auto module = absl::make_unique<VerifiedHloModule>(
|
||||
"TupleUtilTest", HloModuleConfig(), /*verifier_layout_sensitive=*/true,
|
||||
/*allow_mixed_precision_in_hlo_verifier=*/false,
|
||||
ShapeUtil::ByteSizeOfElements);
|
||||
TF_RETURN_IF_ERROR(ParseHloString(hlo_string, module.get()));
|
||||
TF_RETURN_IF_ERROR(module->Verify());
|
||||
|
||||
*entry_computation = module->entry_computation();
|
||||
*param0 = (*entry_computation)->parameter_instruction(0);
|
||||
@ -51,8 +61,7 @@ TEST(TupleUtilTest, ExtractPrefix) {
|
||||
HloComputation* entry_computation;
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
std::unique_ptr<HloModule> module,
|
||||
GetParsedModule(&entry_computation, ¶m0, ¶m1));
|
||||
auto module, GetParsedModule(&entry_computation, ¶m0, ¶m1));
|
||||
|
||||
HloInstruction* prefix = TupleUtil::ExtractPrefix(param0, 2);
|
||||
|
||||
@ -65,8 +74,7 @@ TEST(TupleUtilTest, AppendSuffix) {
|
||||
HloComputation* entry_computation;
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
std::unique_ptr<HloModule> module,
|
||||
GetParsedModule(&entry_computation, ¶m0, ¶m1));
|
||||
auto module, GetParsedModule(&entry_computation, ¶m0, ¶m1));
|
||||
|
||||
HloInstruction* with_suffix =
|
||||
TupleUtil::AppendSuffix(param0, {param1, param1});
|
||||
|
@ -16,8 +16,8 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_parser.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
|
||||
namespace xla {
|
||||
@ -25,8 +25,7 @@ namespace {
|
||||
|
||||
namespace op = xla::testing::opcode_matchers;
|
||||
using ::testing::_;
|
||||
|
||||
class WhileLoopConstantSinkingTest : public ::testing::Test {};
|
||||
using WhileLoopConstantSinkingTest = HloTestBase;
|
||||
|
||||
TEST_F(WhileLoopConstantSinkingTest, SinkOneConstant) {
|
||||
const char* const hlo_string = R"(
|
||||
@ -54,8 +53,8 @@ ENTRY entry {
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnUnverifiedModule(hlo_string));
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(bool changed,
|
||||
WhileLoopConstantSinking{}.Run(module.get()));
|
||||
@ -94,8 +93,8 @@ ENTRY entry {
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnUnverifiedModule(hlo_string));
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(bool changed,
|
||||
WhileLoopConstantSinking{}.Run(module.get()));
|
||||
@ -135,8 +134,8 @@ ENTRY entry {
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnUnverifiedModule(hlo_string));
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(bool changed,
|
||||
WhileLoopConstantSinking{}.Run(module.get()));
|
||||
@ -183,8 +182,8 @@ ENTRY entry {
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnUnverifiedModule(hlo_string));
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(bool changed,
|
||||
WhileLoopConstantSinking{}.Run(module.get()));
|
||||
@ -225,8 +224,8 @@ ENTRY entry {
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnUnverifiedModule(hlo_string));
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(bool changed,
|
||||
WhileLoopConstantSinking{}.Run(module.get()));
|
||||
@ -271,8 +270,8 @@ ENTRY entry {
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnUnverifiedModule(hlo_string));
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(bool changed,
|
||||
WhileLoopConstantSinking{}.Run(module.get()));
|
||||
@ -311,8 +310,8 @@ ENTRY entry {
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnUnverifiedModule(hlo_string));
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(bool changed,
|
||||
WhileLoopConstantSinking{}.Run(module.get()));
|
||||
@ -354,8 +353,8 @@ ENTRY entry {
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnUnverifiedModule(hlo_string));
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(bool changed,
|
||||
WhileLoopConstantSinking{}.Run(module.get()));
|
||||
@ -405,8 +404,8 @@ ENTRY entry {
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnUnverifiedModule(hlo_string));
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
TF_ASSERT_OK_AND_ASSIGN(bool changed,
|
||||
WhileLoopConstantSinking{}.Run(module.get()));
|
||||
ASSERT_TRUE(changed);
|
||||
|
@ -15,10 +15,13 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/service/while_util.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_parser.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
|
||||
namespace xla {
|
||||
@ -26,7 +29,9 @@ namespace {
|
||||
|
||||
namespace op = ::xla::testing::opcode_matchers;
|
||||
|
||||
StatusOr<std::unique_ptr<HloModule>> GetParsedModule(
|
||||
class WhileUtilTest : public HloTestBase {
|
||||
protected:
|
||||
StatusOr<std::unique_ptr<HloModule>> GetParsedModule(
|
||||
HloComputation** entry_computation, HloInstruction** param0,
|
||||
HloInstruction** param1, HloInstruction** param2) {
|
||||
const char* const hlo_string = R"(
|
||||
@ -50,7 +55,8 @@ ENTRY entry {
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
|
||||
// TODO(b/80488902): Use VerifiedHloModule here.
|
||||
TF_ASSIGN_OR_RETURN(auto module,
|
||||
ParseAndReturnUnverifiedModule(hlo_string));
|
||||
|
||||
*entry_computation = module->entry_computation();
|
||||
@ -59,14 +65,15 @@ ENTRY entry {
|
||||
*param2 = (*entry_computation)->parameter_instruction(2);
|
||||
|
||||
return std::move(module);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
TEST(WhileUtil, MakeZeroInstructionsLiveOp) {
|
||||
TEST_F(WhileUtilTest, MakeZeroInstructionsLiveOp) {
|
||||
HloInstruction *param0, *param1, *param2;
|
||||
HloComputation* entry_computation;
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
std::unique_ptr<HloModule> module,
|
||||
auto module,
|
||||
GetParsedModule(&entry_computation, ¶m0, ¶m1, ¶m2));
|
||||
|
||||
HloInstruction* while_instr = entry_computation->root_instruction();
|
||||
@ -92,12 +99,12 @@ TEST(WhileUtil, MakeZeroInstructionsLiveOp) {
|
||||
op::GetTupleElement(param_reconstructed, 1)));
|
||||
}
|
||||
|
||||
TEST(WhileUtilTest, MakeTwoInstructionsLive) {
|
||||
TEST_F(WhileUtilTest, MakeTwoInstructionsLive) {
|
||||
HloInstruction *param0, *param1, *param2;
|
||||
HloComputation* entry_computation;
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
std::unique_ptr<HloModule> module,
|
||||
auto module,
|
||||
GetParsedModule(&entry_computation, ¶m0, ¶m1, ¶m2));
|
||||
|
||||
HloInstruction* while_instr = entry_computation->root_instruction();
|
||||
@ -128,7 +135,7 @@ TEST(WhileUtilTest, MakeTwoInstructionsLive) {
|
||||
op::GetTupleElement(op::Parameter(0), 3)));
|
||||
}
|
||||
|
||||
TEST(WhileUtilTest, GetInvariantGTEsForWhileBody) {
|
||||
TEST_F(WhileUtilTest, GetInvariantGTEsForWhileBody) {
|
||||
const char* const hlo_string = R"(
|
||||
HloModule ModuleWithWhile
|
||||
|
||||
@ -151,8 +158,8 @@ ENTRY main {
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnUnverifiedModule(hlo_string));
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
|
||||
HloComputation* while_body = module->GetComputationWithName("body");
|
||||
|
||||
@ -166,7 +173,7 @@ ENTRY main {
|
||||
EXPECT_EQ((*gte_list.begin())->name(), "gte.0");
|
||||
}
|
||||
|
||||
TEST(WhileUtilTest, AlwaysRemovePreviousWhileBody) {
|
||||
TEST_F(WhileUtilTest, AlwaysRemovePreviousWhileBody) {
|
||||
const char* const hlo_string = R"(
|
||||
HloModule WhileWithSideEffects
|
||||
|
||||
@ -192,8 +199,8 @@ ENTRY main {
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnUnverifiedModule(hlo_string));
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
|
||||
HloComputation* main = module->GetComputationWithName("main");
|
||||
HloInstruction* while_instr = main->root_instruction();
|
||||
|
Loading…
x
Reference in New Issue
Block a user