Use VerifiedHloModule in tests that already have valid HLO.

PiperOrigin-RevId: 275825992
Change-Id: I4b2d6e5d565f763285bd1e9b16976a6a6db0354f
This commit is contained in:
Adrian Kuegel 2019-10-21 05:48:32 -07:00 committed by TensorFlower Gardener
parent 2bfc99464c
commit c5bb3d5baf
11 changed files with 159 additions and 149 deletions

View File

@ -532,10 +532,10 @@ tf_cc_test(
srcs = ["pattern_matcher_test.cc"],
deps = [
":hlo",
":hlo_parser",
":pattern_matcher",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
"@com_google_absl//absl/strings",
@ -588,7 +588,6 @@ tf_cc_test(
srcs = ["hlo_reachability_test.cc"],
deps = [
":hlo",
":hlo_parser",
":hlo_reachability",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
@ -1424,7 +1423,6 @@ tf_cc_test(
":hlo_dce",
":hlo_memory_scheduler",
":hlo_ordering",
":hlo_parser",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
@ -2672,7 +2670,6 @@ tf_cc_test(
srcs = ["hlo_replication_analysis_test.cc"],
deps = [
":hlo",
":hlo_parser",
":hlo_replication_analysis",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
@ -3913,10 +3910,14 @@ tf_cc_test(
srcs = ["tuple_util_test.cc"],
deps = [
":hlo_matchers",
":hlo_module_config",
":hlo_parser",
":tuple_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/tests:verified_hlo_module",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"@com_google_absl//absl/memory",
],
)
@ -3946,6 +3947,7 @@ tf_cc_test(
":while_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"@com_google_absl//absl/algorithm:container",
],
@ -4004,7 +4006,6 @@ tf_cc_test(
srcs = ["while_loop_constant_sinking_test.cc"],
deps = [
":hlo_matchers",
":hlo_parser",
":while_loop_constant_sinking",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/tests:hlo_test_base",

View File

@ -213,11 +213,10 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(text));
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
ASSERT_FALSE(module->has_schedule());
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> module_copy,
auto module_copy,
HloModule::CreateFromProto(module->ToProto(), module->config()));
ASSERT_FALSE(module_copy->has_schedule());
}
@ -235,11 +234,10 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(text));
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
ASSERT_TRUE(module->has_schedule());
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> module_copy,
auto module_copy,
HloModule::CreateFromProto(module->ToProto(), module->config()));
ASSERT_TRUE(module_copy->has_schedule());
TF_ASSERT_OK(module_copy->schedule().Verify());
@ -272,8 +270,7 @@ ENTRY ReduceR3ToR2.v3 {
ROOT reduce = f32[8,16]{1,0} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(text));
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
// Perform various transformations on the graph:
//
@ -306,7 +303,7 @@ ENTRY ReduceR3ToR2.v3 {
// Serialize and deserialize and verify that the instruction and computations
// unique ids are the same.
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> module_copy,
auto module_copy,
HloModule::CreateFromProto(module->ToProto(), module->config()));
// The module IDs should *not* be the same because module ids must be globally
@ -366,8 +363,7 @@ TEST_F(HloModuleTest, VerifyReplaceComputationsWithSortOp) {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(text));
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
// Create a replacement computation
HloComputation* new_comp;

View File

@ -18,7 +18,6 @@ limitations under the License.
#include <set>
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
@ -208,7 +207,7 @@ TEST_F(HloReachabilityTest, ChannelReachability) {
}
TEST_F(HloReachabilityTest, ReplaceInstructions) {
auto module = ParseAndReturnUnverifiedModule(R"(
auto module = ParseAndReturnVerifiedModule(R"(
HloModule test
ENTRY entry {

View File

@ -22,7 +22,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
@ -57,8 +56,8 @@ ENTRY entry {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(module_str));
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(module_str));
auto param = module->entry_computation()->parameter_instruction(0);
param->set_parameter_replicated_at_leaf_buffers(
absl::Span<const bool>{false, true});
@ -106,8 +105,8 @@ ENTRY entry {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(module_str));
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(module_str));
auto param = module->entry_computation()->parameter_instruction(0);
param->set_parameter_replicated_at_leaf_buffers(
absl::Span<const bool>{true, false});
@ -158,8 +157,8 @@ ENTRY SimpleWhileLoop {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(module_str));
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(module_str));
auto param = module->entry_computation()->parameter_instruction(0);
param->set_parameter_replicated_at_leaf_buffers(
absl::Span<const bool>{true, true});
@ -207,8 +206,8 @@ ENTRY WhileLoopParameterAliasingNonReplicatedOutput {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(module_str));
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(module_str));
auto param = module->entry_computation()->parameter_instruction(0);
param->set_parameter_replicated_at_leaf_buffers(
absl::Span<const bool>{true, true});
@ -253,8 +252,8 @@ ENTRY WhileLoopDifferentCondition {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(module_str));
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(module_str));
auto param = module->entry_computation()->parameter_instruction(0);
param->set_parameter_replicated_at_leaf_buffers(
absl::Span<const bool>{true, true});
@ -302,8 +301,8 @@ ENTRY entry {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(module_str));
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(module_str));
auto param = module->entry_computation()->parameter_instruction(0);
param->set_parameter_replicated_at_leaf_buffers(
absl::Span<const bool>{true, true, true, true, false, true, true});
@ -366,8 +365,8 @@ ENTRY entry {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(module_str));
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(module_str));
auto param = module->entry_computation()->parameter_instruction(0);
param->set_parameter_replicated_at_leaf_buffers(
absl::Span<const bool>{true, true, true, true, true, true});
@ -404,8 +403,8 @@ ENTRY entry {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(module_str));
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(module_str));
auto param = module->entry_computation()->parameter_instruction(0);
param->set_parameter_replicated_at_leaf_buffers(
absl::Span<const bool>{true, false, true, true, true});
@ -430,8 +429,8 @@ ENTRY entry {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(module_str));
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(module_str));
auto param = module->entry_computation()->parameter_instruction(0);
param->set_parameter_replicated_at_leaf_buffers(
absl::Span<const bool>{true, true, true, true, false});

View File

@ -25,7 +25,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
@ -53,7 +52,7 @@ ENTRY main {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(module_str));
ParseAndReturnVerifiedModule(module_str));
TF_ASSERT_OK_AND_ASSIGN(
HloSchedule schedule,
ScheduleModule(module.get(), [](const BufferValue& buffer) {
@ -87,7 +86,7 @@ ENTRY main {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(module_str));
ParseAndReturnVerifiedModule(module_str));
TF_ASSERT_OK_AND_ASSIGN(
HloSchedule schedule,
ScheduleModule(module.get(), [](const BufferValue& buffer) {
@ -136,7 +135,7 @@ ENTRY main {
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(module_str));
ParseAndReturnVerifiedModule(module_str));
TF_ASSERT_OK_AND_ASSIGN(
HloSchedule schedule,
ScheduleModule(module.get(), [](const BufferValue& buffer) {
@ -180,7 +179,7 @@ ENTRY main {
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(module_str));
ParseAndReturnVerifiedModule(module_str));
TF_ASSERT_OK_AND_ASSIGN(
HloSchedule schedule,
ScheduleModule(module.get(), [](const BufferValue& buffer) {
@ -241,7 +240,7 @@ ENTRY %WhileLoop () -> s32[] {
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(module_str));
ParseAndReturnVerifiedModule(module_str));
TF_ASSERT_OK_AND_ASSIGN(
HloSchedule schedule,
ScheduleModule(module.get(), [](const BufferValue& buffer) {
@ -310,7 +309,7 @@ ENTRY %WhileLoop () -> s32[] {
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(module_str));
ParseAndReturnVerifiedModule(module_str));
TF_ASSERT_OK_AND_ASSIGN(
HloSchedule schedule,
ScheduleModule(module.get(), [](const BufferValue& buffer) {

View File

@ -47,7 +47,7 @@ class InstructionFusionForTesting : public InstructionFusion {
};
TEST_F(InstructionFusionTest, FuseInstructions) {
auto module = ParseAndReturnUnverifiedModule(R"(
auto module = ParseAndReturnVerifiedModule(R"(
HloModule test_module
ENTRY entry_computation {
p0 = f32[4,3]{1,0} parameter(0)
@ -67,7 +67,7 @@ TEST_F(InstructionFusionTest, FuseInstructions) {
}
TEST_F(InstructionFusionTest, FuseIntoFusionInstruction) {
auto module = ParseAndReturnUnverifiedModule(R"(
auto module = ParseAndReturnVerifiedModule(R"(
HloModule test_module
fused_computation {
p1 = f32[4,3] parameter(0)
@ -90,7 +90,7 @@ TEST_F(InstructionFusionTest, FuseIntoFusionInstruction) {
}
TEST_F(InstructionFusionTest, FuseInstructionsIntoMultiOutput) {
auto module = ParseAndReturnUnverifiedModule(R"(
auto module = ParseAndReturnVerifiedModule(R"(
HloModule test_module
ENTRY entry_computation {
p0 = f32[4,3]{1,0} parameter(0)
@ -196,7 +196,7 @@ static int Count(const HloModule& module, HloOpcode op) {
}
TEST_F(InstructionFusionTest, FuseCheapNonDuplicatableOps) {
auto module = ParseAndReturnUnverifiedModule(R"(
auto module = ParseAndReturnVerifiedModule(R"(
HloModule test_module
ENTRY OutputFusion {
p0 = f32[4,3]{1,0} parameter(0)
@ -433,7 +433,7 @@ TEST_F(InstructionFusionTest, AllowBinarySameValueOperandsDuplication) {
}
TEST_F(InstructionFusionTest, FuseDiamondGraphsNoDuplication) {
auto module = ParseAndReturnUnverifiedModule(R"(
auto module = ParseAndReturnVerifiedModule(R"(
HloModule test_module
ENTRY Test {
p0 = f32[100] parameter(0)
@ -458,7 +458,7 @@ TEST_F(InstructionFusionTest, FuseDiamondGraphsNoDuplication) {
}
TEST_F(InstructionFusionTest, FuseDiamondGraphsAllowDuplication) {
auto module = ParseAndReturnUnverifiedModule(R"(
auto module = ParseAndReturnVerifiedModule(R"(
HloModule test_module
ENTRY Test {
p0 = f32[100] parameter(0)
@ -484,7 +484,7 @@ TEST_F(InstructionFusionTest, FuseDiamondGraphsAllowDuplication) {
TEST_F(InstructionFusionTest,
WideningConvertsAreAlwaysDuplicableIntoConsumers) {
auto module = ParseAndReturnUnverifiedModule(R"(
auto module = ParseAndReturnVerifiedModule(R"(
HloModule test_module
ENTRY Test {
p0 = f16[100] parameter(0)

View File

@ -14,19 +14,21 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/core/platform/test.h"
namespace xla {
namespace {
namespace m = match;
using PatternMatcherTest = HloTestBase;
TEST(PatternMatcherTest, AddOp) {
TEST_F(PatternMatcherTest, AddOp) {
constexpr char kModuleStr[] = R"(HloModule two_plus_two_module
ENTRY %two_plus_two_computation () -> f32[] {
%two = f32[] constant(2)
@ -34,7 +36,7 @@ TEST(PatternMatcherTest, AddOp) {
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
ParseAndReturnUnverifiedModule(kModuleStr));
ParseAndReturnVerifiedModule(kModuleStr));
const HloInstruction* matched_inst;
HloInstruction* matched_operand;
@ -66,7 +68,7 @@ TEST(PatternMatcherTest, AddOp) {
match::Multiply(&matched_inst, match::Op(), match::Op())));
}
TEST(PatternMatcherTest, ScalarShape) {
TEST_F(PatternMatcherTest, ScalarShape) {
auto scalar_shape = ShapeUtil::MakeShape(F32, {});
Shape* matched_shape;
EXPECT_TRUE(Match(&scalar_shape, match::Shape(&matched_shape).IsScalar()));
@ -81,7 +83,7 @@ TEST(PatternMatcherTest, ScalarShape) {
match::Shape().WithSubshape({0}, match::Shape()).WithElementType(F32)));
}
TEST(PatternMatcherTest, DenseArrayShape) {
TEST_F(PatternMatcherTest, DenseArrayShape) {
auto array_shape = ShapeUtil::MakeShape(F32, {2, 3, 4});
Shape* matched_shape;
EXPECT_TRUE(Match(&array_shape, match::Shape(&matched_shape).IsArray()));
@ -104,7 +106,7 @@ TEST(PatternMatcherTest, DenseArrayShape) {
EXPECT_EQ(matched_layout, &array_shape.layout());
}
TEST(PatternMatcherTest, SparseArrayShape) {
TEST_F(PatternMatcherTest, SparseArrayShape) {
auto array_shape = ShapeUtil::MakeShapeWithSparseLayout(F32, {2, 3, 4}, 10);
Shape* matched_shape;
EXPECT_TRUE(Match(&array_shape, match::Shape(&matched_shape).IsArray()));
@ -127,7 +129,7 @@ TEST(PatternMatcherTest, SparseArrayShape) {
EXPECT_EQ(matched_layout, &array_shape.layout());
}
TEST(PatternMatcherTest, TupleShape) {
TEST_F(PatternMatcherTest, TupleShape) {
auto tuple_shape = ShapeUtil::MakeTupleShape({
ShapeUtil::MakeShape(F32, {1, 2, 3}),
ShapeUtil::MakeShape(S32, {4, 5}),
@ -175,7 +177,7 @@ TEST(PatternMatcherTest, TupleShape) {
Match(&tuple_shape, match::Shape().WithSubshape({0, 0}, match::Shape())));
}
TEST(PatternMatcherTest, FusionKind) {
TEST_F(PatternMatcherTest, FusionKind) {
constexpr char kModuleStr[] = R"(
HloModule test_module
@ -188,7 +190,7 @@ TEST(PatternMatcherTest, FusionKind) {
ROOT fusion = f32[] fusion(p0), kind=kLoop, calls=fused_computation
})";
TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
ParseAndReturnUnverifiedModule(kModuleStr));
ParseAndReturnVerifiedModule(kModuleStr));
auto* root = hlo_module->entry_computation()->root_instruction();
EXPECT_TRUE(Match(
@ -199,7 +201,7 @@ TEST(PatternMatcherTest, FusionKind) {
HloInstruction::FusionKind::kLoop)));
}
TEST(PatternMatcherTest, GetTupleElement) {
TEST_F(PatternMatcherTest, GetTupleElement) {
constexpr char kModuleStr[] = R"(
HloModule test_module
@ -208,7 +210,7 @@ TEST(PatternMatcherTest, GetTupleElement) {
ROOT gte = f32[] get-tuple-element(p0), index=1
})";
TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
ParseAndReturnUnverifiedModule(kModuleStr));
ParseAndReturnVerifiedModule(kModuleStr));
auto* root = hlo_module->entry_computation()->root_instruction();
EXPECT_FALSE(Match(root, match::Op().WithTupleIndex(0)));
@ -218,11 +220,11 @@ TEST(PatternMatcherTest, GetTupleElement) {
EXPECT_TRUE(Match(root, match::GetTupleElement(match::Op(), 1)));
}
TEST(PatternMatcherTest, AnyOf) {
TEST_F(PatternMatcherTest, AnyOf) {
constexpr char kModuleStr[] = R"(
HloModule test_module ENTRY test { ROOT constant = f16[] constant(1) })";
TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
ParseAndReturnUnverifiedModule(kModuleStr));
ParseAndReturnVerifiedModule(kModuleStr));
auto* root = hlo_module->entry_computation()->root_instruction();
EXPECT_TRUE(
@ -236,7 +238,7 @@ TEST(PatternMatcherTest, AnyOf) {
match::ConstantScalar(2))));
}
TEST(PatternMatcherTest, ConstantScalar) {
TEST_F(PatternMatcherTest, ConstantScalar) {
using match::ConstantEffectiveScalar;
using match::ConstantScalar;
using match::Op;
@ -253,7 +255,7 @@ TEST(PatternMatcherTest, ConstantScalar) {
ROOT tuple = (s32[], s32[1,1], s32[1,2], f32[], f32[]) tuple(a,b,c,d,e)
})";
TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
ParseAndReturnUnverifiedModule(kModuleStr));
ParseAndReturnVerifiedModule(kModuleStr));
auto* root = hlo_module->entry_computation()->root_instruction();
const HloInstruction* a = root->operand(0);
@ -308,7 +310,7 @@ TEST(PatternMatcherTest, ConstantScalar) {
EXPECT_EQ(instr, a);
}
TEST(PatternMatcherTest, MultiplyAnyOrder) {
TEST_F(PatternMatcherTest, MultiplyAnyOrder) {
using match::ConstantScalar;
using match::MultiplyAnyOrder;
@ -320,7 +322,7 @@ TEST(PatternMatcherTest, MultiplyAnyOrder) {
ROOT multiply = f16[] multiply(lhs, rhs)
})";
TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
ParseAndReturnUnverifiedModule(kModuleStr));
ParseAndReturnVerifiedModule(kModuleStr));
auto* root = hlo_module->entry_computation()->root_instruction();
const HloInstruction* instr;
@ -339,7 +341,7 @@ TEST(PatternMatcherTest, MultiplyAnyOrder) {
.IsNonConstant()));
}
TEST(PatternMatcherTest, AnyOfShortCircuit) {
TEST_F(PatternMatcherTest, AnyOfShortCircuit) {
using match::AnyOf;
using match::Multiply;
using match::Op;
@ -352,7 +354,7 @@ TEST(PatternMatcherTest, AnyOfShortCircuit) {
ROOT multiply = f16[] multiply(lhs, rhs)
})";
TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
ParseAndReturnUnverifiedModule(kModuleStr));
ParseAndReturnVerifiedModule(kModuleStr));
auto* root = hlo_module->entry_computation()->root_instruction();
{
@ -375,7 +377,7 @@ TEST(PatternMatcherTest, AnyOfShortCircuit) {
}
}
TEST(PatternMatcherTest, AllOf) {
TEST_F(PatternMatcherTest, AllOf) {
using match::AllOf;
using match::Broadcast;
using match::Constant;
@ -384,7 +386,7 @@ TEST(PatternMatcherTest, AllOf) {
constexpr char kModuleStr[] = R"(
HloModule test_module ENTRY test { ROOT constant = f16[] constant(1) })";
TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
ParseAndReturnUnverifiedModule(kModuleStr));
ParseAndReturnVerifiedModule(kModuleStr));
auto* root = hlo_module->entry_computation()->root_instruction();
auto f16_scalar = ShapeUtil::MakeShape(F16, {});
@ -407,7 +409,7 @@ TEST(PatternMatcherTest, AllOf) {
Match(root, AllOf<HloInstruction>(Broadcast(Op()), scalar_pattern)));
}
TEST(PatternMatcherTest, AllOfNoCaptureIfNotMatch) {
TEST_F(PatternMatcherTest, AllOfNoCaptureIfNotMatch) {
using match::AllOf;
using match::Broadcast;
using match::Constant;
@ -419,7 +421,7 @@ TEST(PatternMatcherTest, AllOfNoCaptureIfNotMatch) {
ROOT v = f16[] constant(42)
})";
TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
ParseAndReturnUnverifiedModule(kModuleStr));
ParseAndReturnVerifiedModule(kModuleStr));
auto* root = hlo_module->entry_computation()->root_instruction();
const HloInstruction* constant = nullptr;
@ -430,7 +432,7 @@ TEST(PatternMatcherTest, AllOfNoCaptureIfNotMatch) {
EXPECT_NE(nullptr, constant);
}
TEST(PatternMatcherTest, TestNoCapture) {
TEST_F(PatternMatcherTest, TestNoCapture) {
using match::Constant;
constexpr char kModuleStr[] = R"(
@ -439,7 +441,7 @@ TEST(PatternMatcherTest, TestNoCapture) {
ROOT v = f16[] constant(42)
})";
TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
ParseAndReturnUnverifiedModule(kModuleStr));
ParseAndReturnVerifiedModule(kModuleStr));
auto* root = hlo_module->entry_computation()->root_instruction();
const HloInstruction* constant = nullptr;
@ -447,7 +449,7 @@ TEST(PatternMatcherTest, TestNoCapture) {
EXPECT_EQ(nullptr, constant);
}
TEST(PatternMatcherTest, TestCaptureMatchedSubPatternForAnyOf) {
TEST_F(PatternMatcherTest, TestCaptureMatchedSubPatternForAnyOf) {
using match::Add;
using match::AddAnyOrder;
using match::AnyOf;
@ -461,7 +463,7 @@ TEST(PatternMatcherTest, TestCaptureMatchedSubPatternForAnyOf) {
ROOT add = f16[] add(u, v)
})";
TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
ParseAndReturnUnverifiedModule(kModuleStr));
ParseAndReturnVerifiedModule(kModuleStr));
auto* root = hlo_module->entry_computation()->root_instruction();
const HloInstruction* addend0 = nullptr;
@ -477,7 +479,7 @@ TEST(PatternMatcherTest, TestCaptureMatchedSubPatternForAnyOf) {
EXPECT_EQ(nullptr, addend2);
}
TEST(PatternMatcherTest, TestConcat) {
TEST_F(PatternMatcherTest, TestConcat) {
using match::Concatenate;
using match::ConstantScalar;
using match::Op;
@ -497,7 +499,7 @@ TEST(PatternMatcherTest, TestConcat) {
ROOT concat = u32[4] concatenate(r1, r2, r3, r4), dimensions={0}
})";
TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
ParseAndReturnUnverifiedModule(kModuleStr));
ParseAndReturnVerifiedModule(kModuleStr));
auto* root = hlo_module->entry_computation()->root_instruction();
ASSERT_TRUE(Match(
root,
@ -557,7 +559,7 @@ string Explanation(const Elem& elem, const Pattern& pattern) {
EXPECT_EQ(Explanation((elem), (pattern)), expected_explanation); \
} while (0)
TEST(PatternMatcherTest, LayoutDescribeToAndExplain) {
TEST_F(PatternMatcherTest, LayoutDescribeToAndExplain) {
auto layout = LayoutUtil::MakeLayout({1, 2});
auto layout2 = LayoutUtil::MakeLayout({2, 2});
@ -577,7 +579,7 @@ TEST(PatternMatcherTest, LayoutDescribeToAndExplain) {
"Layout has format DENSE but expected SPARSE");
}
TEST(PatternMatcherTest, CustomCallTargetMatcherDescribeAndExplain) {
TEST_F(PatternMatcherTest, CustomCallTargetMatcherDescribeAndExplain) {
constexpr char kModuleStr[] = R"(
HloModule test_module
@ -587,7 +589,7 @@ TEST(PatternMatcherTest, CustomCallTargetMatcherDescribeAndExplain) {
)";
TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
ParseAndReturnUnverifiedModule(kModuleStr));
ParseAndReturnVerifiedModule(kModuleStr));
auto* root = hlo_module->entry_computation()->root_instruction();
EXPECT_TRUE(Match(root, match::Op().WithCustomCallTarget("test_target")));
@ -600,7 +602,7 @@ TEST(PatternMatcherTest, CustomCallTargetMatcherDescribeAndExplain) {
"out = f32[] custom-call(), custom_call_target=\"test_target\"");
}
TEST(PatternMatcherTest, ShapeDescribeToAndExplain) {
TEST_F(PatternMatcherTest, ShapeDescribeToAndExplain) {
auto shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 2}, {0, 1});
auto layout = shape.layout();
@ -715,7 +717,7 @@ std::unique_ptr<HloInstruction> SetName(absl::string_view name,
return instr;
}
TEST(PatternMatcherTest, HloInstructionDescribeToAndExplain) {
TEST_F(PatternMatcherTest, HloInstructionDescribeToAndExplain) {
std::unique_ptr<HloInstruction> iota =
SetName("i", HloInstruction::CreateIota(ShapeUtil::MakeShape(S32, {42}),
/*iota_dimension=*/0));
@ -810,7 +812,7 @@ TEST(PatternMatcherTest, HloInstructionDescribeToAndExplain) {
"in c = s32[] constant(0)"));
}
TEST(PatternMatcherTest, HloInstructionMatcherAnyOrderDescribeTo) {
TEST_F(PatternMatcherTest, HloInstructionMatcherAnyOrderDescribeTo) {
auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
EXPECT_DESC_AND_EXPLANATION(
SetName("a", HloInstruction::CreateBinary(
@ -866,7 +868,7 @@ TEST(PatternMatcherTest, HloInstructionMatcherAnyOrderDescribeTo) {
"in a = s32[] add(s32[] p, s32[] c)");
}
TEST(PatternMatcherTest, AnyOfMatcherDescribeToAndExplain) {
TEST_F(PatternMatcherTest, AnyOfMatcherDescribeToAndExplain) {
EXPECT_DESC_AND_EXPLANATION(
SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))),
m::AnyOf<HloInstruction>(m::Op().WithName("foo"),
@ -887,7 +889,7 @@ TEST(PatternMatcherTest, AnyOfMatcherDescribeToAndExplain) {
" in c = s32[] constant(0)");
}
TEST(PatternMatcherTest, Parameter) {
TEST_F(PatternMatcherTest, Parameter) {
auto param =
HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {}), "p1");
auto non_param =
@ -911,7 +913,7 @@ TEST(PatternMatcherTest, Parameter) {
"in p0 = f32[] parameter(0)");
}
TEST(PatternMatcherTest, OneUseAndOneUser) {
TEST_F(PatternMatcherTest, OneUseAndOneUser) {
auto param =
HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0");
@ -966,7 +968,7 @@ TEST(PatternMatcherTest, OneUseAndOneUser) {
"in p0 = f32[] parameter(0)");
}
TEST(HloMatchersTest, Comparison) {
TEST_F(PatternMatcherTest, Comparison) {
auto shape = ShapeUtil::MakeShape(F32, {1});
auto p0 = HloInstruction::CreateParameter(0, shape, "param.0");
auto p1 = HloInstruction::CreateParameter(1, shape, "param.1");

View File

@ -68,8 +68,8 @@ ENTRY entry_computation {
ROOT dot = f32[2,2]{1,0} dot(x, transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
FoldTranspose(module.get());
@ -90,8 +90,8 @@ ENTRY entry_computation {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
TransposeFolding transpose_folding(
[](const HloInstruction& dot,
@ -118,7 +118,7 @@ ENTRY entry_computation {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnUnverifiedModule(hlo_string));
TransposeFolding transpose_folding(
@ -146,8 +146,8 @@ ENTRY entry_computation {
ROOT dot = f32[1,3]{1,0} dot(transpose, transpose.1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
FoldTranspose(module.get());
@ -204,8 +204,8 @@ ENTRY entry_computation {
ROOT call = f32[2,2]{1,0} call(y, x), to_apply=callee
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
FoldTranspose(module.get());
const HloComputation* callee = module->GetComputationWithName("callee");

View File

@ -15,16 +15,22 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/tuple_util.h"
#include <memory>
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/verified_hlo_module.h"
namespace xla {
namespace {
namespace op = ::xla::testing::opcode_matchers;
StatusOr<std::unique_ptr<HloModule>> GetParsedModule(
StatusOr<std::unique_ptr<VerifiedHloModule>> GetParsedModule(
HloComputation** entry_computation, HloInstruction** param0,
HloInstruction** param1) {
const char* const hlo_string = R"(
@ -36,8 +42,12 @@ ENTRY entry {
}
)";
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(hlo_string));
auto module = absl::make_unique<VerifiedHloModule>(
"TupleUtilTest", HloModuleConfig(), /*verifier_layout_sensitive=*/true,
/*allow_mixed_precision_in_hlo_verifier=*/false,
ShapeUtil::ByteSizeOfElements);
TF_RETURN_IF_ERROR(ParseHloString(hlo_string, module.get()));
TF_RETURN_IF_ERROR(module->Verify());
*entry_computation = module->entry_computation();
*param0 = (*entry_computation)->parameter_instruction(0);
@ -51,8 +61,7 @@ TEST(TupleUtilTest, ExtractPrefix) {
HloComputation* entry_computation;
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> module,
GetParsedModule(&entry_computation, &param0, &param1));
auto module, GetParsedModule(&entry_computation, &param0, &param1));
HloInstruction* prefix = TupleUtil::ExtractPrefix(param0, 2);
@ -65,8 +74,7 @@ TEST(TupleUtilTest, AppendSuffix) {
HloComputation* entry_computation;
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> module,
GetParsedModule(&entry_computation, &param0, &param1));
auto module, GetParsedModule(&entry_computation, &param0, &param1));
HloInstruction* with_suffix =
TupleUtil::AppendSuffix(param0, {param1, param1});

View File

@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace xla {
@ -25,8 +25,7 @@ namespace {
namespace op = xla::testing::opcode_matchers;
using ::testing::_;
class WhileLoopConstantSinkingTest : public ::testing::Test {};
using WhileLoopConstantSinkingTest = HloTestBase;
TEST_F(WhileLoopConstantSinkingTest, SinkOneConstant) {
const char* const hlo_string = R"(
@ -54,8 +53,8 @@ ENTRY entry {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(bool changed,
WhileLoopConstantSinking{}.Run(module.get()));
@ -94,8 +93,8 @@ ENTRY entry {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(bool changed,
WhileLoopConstantSinking{}.Run(module.get()));
@ -135,8 +134,8 @@ ENTRY entry {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(bool changed,
WhileLoopConstantSinking{}.Run(module.get()));
@ -183,8 +182,8 @@ ENTRY entry {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(bool changed,
WhileLoopConstantSinking{}.Run(module.get()));
@ -225,8 +224,8 @@ ENTRY entry {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(bool changed,
WhileLoopConstantSinking{}.Run(module.get()));
@ -271,8 +270,8 @@ ENTRY entry {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(bool changed,
WhileLoopConstantSinking{}.Run(module.get()));
@ -311,8 +310,8 @@ ENTRY entry {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(bool changed,
WhileLoopConstantSinking{}.Run(module.get()));
@ -354,8 +353,8 @@ ENTRY entry {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(bool changed,
WhileLoopConstantSinking{}.Run(module.get()));
@ -405,8 +404,8 @@ ENTRY entry {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(bool changed,
WhileLoopConstantSinking{}.Run(module.get()));
ASSERT_TRUE(changed);

View File

@ -15,10 +15,13 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/while_util.h"
#include <memory>
#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/util.h"
namespace xla {
@ -26,10 +29,12 @@ namespace {
namespace op = ::xla::testing::opcode_matchers;
StatusOr<std::unique_ptr<HloModule>> GetParsedModule(
HloComputation** entry_computation, HloInstruction** param0,
HloInstruction** param1, HloInstruction** param2) {
const char* const hlo_string = R"(
class WhileUtilTest : public HloTestBase {
protected:
StatusOr<std::unique_ptr<HloModule>> GetParsedModule(
HloComputation** entry_computation, HloInstruction** param0,
HloInstruction** param1, HloInstruction** param2) {
const char* const hlo_string = R"(
HloModule ModuleWithWhile
while_body {
@ -50,23 +55,25 @@ ENTRY entry {
}
)";
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(hlo_string));
// TODO(b/80488902): Use VerifiedHloModule here.
TF_ASSIGN_OR_RETURN(auto module,
ParseAndReturnUnverifiedModule(hlo_string));
*entry_computation = module->entry_computation();
*param0 = (*entry_computation)->parameter_instruction(0);
*param1 = (*entry_computation)->parameter_instruction(1);
*param2 = (*entry_computation)->parameter_instruction(2);
*entry_computation = module->entry_computation();
*param0 = (*entry_computation)->parameter_instruction(0);
*param1 = (*entry_computation)->parameter_instruction(1);
*param2 = (*entry_computation)->parameter_instruction(2);
return std::move(module);
}
return std::move(module);
}
};
TEST(WhileUtil, MakeZeroInstructionsLiveOp) {
TEST_F(WhileUtilTest, MakeZeroInstructionsLiveOp) {
HloInstruction *param0, *param1, *param2;
HloComputation* entry_computation;
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> module,
auto module,
GetParsedModule(&entry_computation, &param0, &param1, &param2));
HloInstruction* while_instr = entry_computation->root_instruction();
@ -92,12 +99,12 @@ TEST(WhileUtil, MakeZeroInstructionsLiveOp) {
op::GetTupleElement(param_reconstructed, 1)));
}
TEST(WhileUtilTest, MakeTwoInstructionsLive) {
TEST_F(WhileUtilTest, MakeTwoInstructionsLive) {
HloInstruction *param0, *param1, *param2;
HloComputation* entry_computation;
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> module,
auto module,
GetParsedModule(&entry_computation, &param0, &param1, &param2));
HloInstruction* while_instr = entry_computation->root_instruction();
@ -128,7 +135,7 @@ TEST(WhileUtilTest, MakeTwoInstructionsLive) {
op::GetTupleElement(op::Parameter(0), 3)));
}
TEST(WhileUtilTest, GetInvariantGTEsForWhileBody) {
TEST_F(WhileUtilTest, GetInvariantGTEsForWhileBody) {
const char* const hlo_string = R"(
HloModule ModuleWithWhile
@ -151,8 +158,8 @@ ENTRY main {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
HloComputation* while_body = module->GetComputationWithName("body");
@ -166,7 +173,7 @@ ENTRY main {
EXPECT_EQ((*gte_list.begin())->name(), "gte.0");
}
TEST(WhileUtilTest, AlwaysRemovePreviousWhileBody) {
TEST_F(WhileUtilTest, AlwaysRemovePreviousWhileBody) {
const char* const hlo_string = R"(
HloModule WhileWithSideEffects
@ -192,8 +199,8 @@ ENTRY main {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
HloComputation* main = module->GetComputationWithName("main");
HloInstruction* while_instr = main->root_instruction();