[XLA] Make pattern_matchers work with gMock.
This lets us unify the HLO pattern matchers and the HLO gmock matchers (in a later patch). Unifying these two APIs is useful because then we don't have to learn two APIs, and we don't have to implement features twice. This change: - Adds and tests the DescribeTo and MatchAndExplain APIs (this is the major change) - Uses these new gmock matchers in a few tests as a proof of concept. - Rewrites the is-constant-scalar API to use a true matcher rather than a std::function predicate matcher. This is necessary to get a user-friendly DescribeTo message rather than "I don't know what this std::function does." - Adds EffectiveScalarConstant helpers along with the old ScalarConstant helpers and then uses these within while_loop_simplifier. - Adds some missing simple op matchers: Tuple, Convolution, Pad, etc. - Adds a Parameter(n) matcher. - Adds Op().Is(), which matches a particular HloInstruction*, which is used in while_loop_simplifier. - Updates documentation to reflect new functions (both added here and added in earlier patches). - Tightens up the documentation. It was getting pretty long, and I made it longer. - Changes implementation of FooAnyOrder so that it returns an Op rather than an AnyOf. This lets you do AddAnyOrder(...).IsScalar(), whereas before this was a compile error. - Changes the implementation of FooAnyOrder so it uses a custom matcher rather than an AnyOf, in service of better DescribeTo messages. - Implements "and" folding, i.e. AllOf<AllOf<A, B...>, X, Y, ...> => AllOf<A, B, ..., X, Y, ...> in the service of better DescribeTo messages. PiperOrigin-RevId: 223451504
This commit is contained in:
parent
af4417be82
commit
19a1dd5268
@ -408,9 +408,36 @@ tf_cc_test(
|
||||
":hlo",
|
||||
":pattern_matcher",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla/service:hlo_parser",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:test",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "pattern_matcher_gmock",
|
||||
testonly = 1,
|
||||
hdrs = ["pattern_matcher_gmock.h"],
|
||||
deps = [
|
||||
":pattern_matcher",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/core:test",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "pattern_matcher_gmock_test",
|
||||
srcs = ["pattern_matcher_gmock_test.cc"],
|
||||
deps = [
|
||||
":hlo",
|
||||
":pattern_matcher",
|
||||
":pattern_matcher_gmock",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:test",
|
||||
],
|
||||
)
|
||||
|
||||
@ -2631,6 +2658,8 @@ tf_cc_test(
|
||||
":hlo",
|
||||
":hlo_matchers",
|
||||
":layout_assignment",
|
||||
":pattern_matcher",
|
||||
":pattern_matcher_gmock",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_layout",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
@ -2775,6 +2804,8 @@ tf_cc_test(
|
||||
":hlo_matchers",
|
||||
":hlo_parser",
|
||||
":hlo_pass",
|
||||
":pattern_matcher",
|
||||
":pattern_matcher_gmock",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
@ -3535,6 +3566,8 @@ tf_cc_test(
|
||||
":hlo_casting_utils",
|
||||
":hlo_matchers",
|
||||
":hlo_parser",
|
||||
":pattern_matcher",
|
||||
":pattern_matcher_gmock",
|
||||
"//tensorflow/compiler/xla:test_helpers",
|
||||
"//tensorflow/compiler/xla:window_util",
|
||||
"//tensorflow/core:lib",
|
||||
|
@ -22,21 +22,22 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_parser.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
|
||||
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
|
||||
#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
|
||||
namespace op = xla::testing::opcode_matchers;
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
namespace m = xla::match;
|
||||
|
||||
using HloConstantFoldingTest = HloTestBase;
|
||||
|
||||
TEST_F(HloConstantFoldingTest, ConvertF32ToS64) {
|
||||
@ -49,13 +50,14 @@ TEST_F(HloConstantFoldingTest, ConvertF32ToS64) {
|
||||
auto module = CreateNewVerifiedModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
|
||||
EXPECT_THAT(computation->root_instruction(), op::Convert(input));
|
||||
EXPECT_THAT(computation->root_instruction(),
|
||||
GmockMatch(m::Convert().WithOperand(0, m::Op().Is(input))));
|
||||
|
||||
HloConstantFolding const_folder;
|
||||
TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
|
||||
EXPECT_TRUE(result);
|
||||
|
||||
EXPECT_THAT(computation->root_instruction(), op::Constant());
|
||||
EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Constant()));
|
||||
EXPECT_EQ(computation->root_instruction()->literal().GetFirstElement<int64>(),
|
||||
42);
|
||||
}
|
||||
@ -70,13 +72,14 @@ TEST_F(HloConstantFoldingTest, ConvertS64ToF32) {
|
||||
auto module = CreateNewVerifiedModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
|
||||
EXPECT_THAT(computation->root_instruction(), op::Convert(input));
|
||||
EXPECT_THAT(computation->root_instruction(),
|
||||
GmockMatch(m::Convert().WithOperand(0, m::Op().Is(input))));
|
||||
|
||||
HloConstantFolding const_folder;
|
||||
TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
|
||||
EXPECT_TRUE(result);
|
||||
|
||||
EXPECT_THAT(computation->root_instruction(), op::Constant());
|
||||
EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Constant()));
|
||||
EXPECT_EQ(computation->root_instruction()->literal().GetFirstElement<float>(),
|
||||
42.0f);
|
||||
}
|
||||
@ -91,13 +94,14 @@ TEST_F(HloConstantFoldingTest, ConvertF32ArrayToS64Array) {
|
||||
auto module = CreateNewVerifiedModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
|
||||
EXPECT_THAT(computation->root_instruction(), op::Convert(input));
|
||||
EXPECT_THAT(computation->root_instruction(),
|
||||
GmockMatch(m::Convert().WithOperand(0, m::Op().Is(input))));
|
||||
|
||||
HloConstantFolding const_folder;
|
||||
TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
|
||||
EXPECT_TRUE(result);
|
||||
|
||||
EXPECT_THAT(computation->root_instruction(), op::Constant());
|
||||
EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Constant()));
|
||||
EXPECT_EQ(computation->root_instruction()->literal().Get<int64>({0}), 42);
|
||||
EXPECT_EQ(computation->root_instruction()->literal().Get<int64>({1}), 19);
|
||||
}
|
||||
@ -138,7 +142,7 @@ TEST_F(HloConstantFoldingTest, Concatenate) {
|
||||
EXPECT_TRUE(result);
|
||||
|
||||
HloInstruction* root = computation->root_instruction();
|
||||
EXPECT_THAT(root, op::Constant());
|
||||
EXPECT_THAT(root, GmockMatch(m::Constant()));
|
||||
EXPECT_TRUE(ShapeUtil::Equal(root->shape(), shape));
|
||||
}
|
||||
}
|
||||
@ -165,7 +169,7 @@ TEST_F(HloConstantFoldingTest, Slice) {
|
||||
EXPECT_TRUE(result);
|
||||
|
||||
HloInstruction* root = computation->root_instruction();
|
||||
EXPECT_THAT(root, op::Constant());
|
||||
EXPECT_THAT(root, GmockMatch(m::Constant()));
|
||||
EXPECT_TRUE(ShapeUtil::Equal(root->shape(), shape));
|
||||
}
|
||||
|
||||
@ -190,7 +194,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) {
|
||||
EXPECT_TRUE(result);
|
||||
|
||||
HloInstruction* root = computation->root_instruction();
|
||||
EXPECT_THAT(root, op::Constant());
|
||||
EXPECT_THAT(root, GmockMatch(m::Constant()));
|
||||
EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), shape));
|
||||
|
||||
using NativeT = typename primitive_util::PrimitiveTypeToNative<F32>::type;
|
||||
@ -240,7 +244,8 @@ TEST_F(HloConstantFoldingTest, ConstantFoldReduceNoLayout) {
|
||||
TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(m.get()));
|
||||
EXPECT_FALSE(result);
|
||||
|
||||
EXPECT_THAT(m->entry_computation()->root_instruction(), op::Reduce());
|
||||
EXPECT_THAT(m->entry_computation()->root_instruction(),
|
||||
GmockMatch(m::Reduce()));
|
||||
}
|
||||
|
||||
const char* const kConstantFoldLargePad = R"(
|
||||
@ -260,7 +265,7 @@ TEST_F(HloConstantFoldingTest, DoesNotFoldLargePad) {
|
||||
EXPECT_FALSE(result);
|
||||
|
||||
EXPECT_THAT(module->entry_computation()->root_instruction(),
|
||||
op::Pad(op::Constant(), op::Constant()));
|
||||
GmockMatch(m::Pad(m::Constant(), m::Constant())));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -21,7 +21,8 @@ limitations under the License.
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
|
||||
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
|
||||
#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
|
||||
#include "tensorflow/compiler/xla/window_util.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
@ -29,7 +30,7 @@ limitations under the License.
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
namespace op = ::xla::testing::opcode_matchers;
|
||||
namespace m = ::xla::match;
|
||||
using absl::string_view;
|
||||
|
||||
struct TestData {
|
||||
@ -1893,7 +1894,8 @@ ENTRY ReduceR3ToR2 {
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(original));
|
||||
ASSERT_NE(module->entry_computation(), nullptr);
|
||||
EXPECT_THAT(module->entry_computation()->root_instruction(), op::Reduce());
|
||||
EXPECT_THAT(module->entry_computation()->root_instruction(),
|
||||
GmockMatch(m::Reduce()));
|
||||
}
|
||||
|
||||
TEST_F(HloParserTest, ParseSharding) {
|
||||
@ -1953,7 +1955,7 @@ TEST(HloParserSingleOpTest, SingleOp) {
|
||||
const HloComputation* computation = module->entry_computation();
|
||||
ASSERT_NE(computation, nullptr);
|
||||
EXPECT_THAT(computation->root_instruction(),
|
||||
op::Multiply(op::Parameter(0), op::Parameter(1)));
|
||||
GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(1))));
|
||||
}
|
||||
|
||||
TEST(HloParserSingleOpTest, SingleOpNoShapeProducesError) {
|
||||
@ -1981,7 +1983,7 @@ TEST(HloParserSingleOpTest, SingleOpNoNames) {
|
||||
const HloComputation* computation = module->entry_computation();
|
||||
ASSERT_NE(computation, nullptr);
|
||||
EXPECT_THAT(computation->root_instruction(),
|
||||
op::Multiply(op::Parameter(0), op::Parameter(1)));
|
||||
GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(1))));
|
||||
}
|
||||
|
||||
TEST(HloParserSingleOpTest, CanonicalOp) {
|
||||
@ -1990,7 +1992,7 @@ TEST(HloParserSingleOpTest, CanonicalOp) {
|
||||
const HloComputation* computation = module->entry_computation();
|
||||
ASSERT_NE(computation, nullptr);
|
||||
EXPECT_THAT(computation->root_instruction(),
|
||||
op::Multiply(op::Parameter(0), op::Parameter(1)));
|
||||
GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(1))));
|
||||
EXPECT_EQ(
|
||||
computation->root_instruction()->ToString(HloPrintOptions::Canonical()),
|
||||
text);
|
||||
@ -2044,7 +2046,11 @@ TEST(HloParserSingleOpTest, SingleOpWithNested) {
|
||||
const HloComputation* computation = module->entry_computation();
|
||||
ASSERT_NE(computation, nullptr);
|
||||
EXPECT_THAT(computation->root_instruction(),
|
||||
op::Fusion(op::Parameter(0), op::Parameter(1)));
|
||||
GmockMatch(m::Op()
|
||||
.WithOpcode(HloOpcode::kFusion)
|
||||
.WithNumOperands(2)
|
||||
.WithOperand(0, m::Parameter(0))
|
||||
.WithOperand(1, m::Parameter(1))));
|
||||
}
|
||||
|
||||
TEST(HloParserSingleOpTest, SingleOpWithNested_DoesNotExist) {
|
||||
@ -2088,7 +2094,7 @@ TEST(HloParserSingleOpTest, ConvolutionTrivialFeatureGroupCount) {
|
||||
const HloComputation* computation = module->entry_computation();
|
||||
ASSERT_NE(computation, nullptr);
|
||||
EXPECT_THAT(computation->root_instruction(),
|
||||
op::Convolution(op::Parameter(0), op::Parameter(1)));
|
||||
GmockMatch(m::Convolution(m::Parameter(0), m::Parameter(1))));
|
||||
auto* convolution =
|
||||
Cast<HloConvolutionInstruction>(computation->root_instruction());
|
||||
EXPECT_EQ(convolution->feature_group_count(), 1);
|
||||
@ -2152,8 +2158,10 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
|
||||
module->schedule().is_computation_scheduled(module->entry_computation()));
|
||||
EXPECT_THAT(
|
||||
module->schedule().sequence(module->entry_computation()).instructions(),
|
||||
::testing::ElementsAre(op::Parameter(), op::Broadcast(), op::Parameter(),
|
||||
op::Multiply(), op::Parameter(), op::Add()));
|
||||
::testing::ElementsAre(
|
||||
GmockMatch(m::Parameter()), GmockMatch(m::Broadcast()),
|
||||
GmockMatch(m::Parameter()), GmockMatch(m::Multiply()),
|
||||
GmockMatch(m::Parameter()), GmockMatch(m::Add())));
|
||||
}
|
||||
|
||||
TEST_F(HloParserTest, IsScheduledIsTrueDifferentOrder) {
|
||||
@ -2179,8 +2187,10 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
|
||||
module->schedule().is_computation_scheduled(module->entry_computation()));
|
||||
EXPECT_THAT(
|
||||
module->schedule().sequence(module->entry_computation()).instructions(),
|
||||
::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Parameter(),
|
||||
op::Broadcast(), op::Multiply(), op::Add()));
|
||||
::testing::ElementsAre(
|
||||
GmockMatch(m::Parameter()), GmockMatch(m::Parameter()),
|
||||
GmockMatch(m::Parameter()), GmockMatch(m::Broadcast()),
|
||||
GmockMatch(m::Multiply()), GmockMatch(m::Add())));
|
||||
}
|
||||
|
||||
TEST_F(HloParserTest, CustomCallWrongNumberofOperandConstraints) {
|
||||
|
@ -31,6 +31,8 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_parser.h"
|
||||
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
|
||||
#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
|
||||
#include "tensorflow/compiler/xla/shape_layout.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
@ -42,11 +44,10 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
|
||||
namespace op = xla::testing::opcode_matchers;
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
namespace m = xla::match;
|
||||
using ::testing::ElementsAre;
|
||||
|
||||
class LayoutAssignmentTest : public HloTestBase {
|
||||
@ -342,7 +343,8 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) {
|
||||
|
||||
// Verify the structure of the HLO graph.
|
||||
EXPECT_THAT(root,
|
||||
op::Tuple(op::Tuple(constant), op::Tuple(op::Copy(constant))));
|
||||
GmockMatch(m::Tuple(m::Tuple(m::Op().Is(constant)),
|
||||
m::Tuple(m::Copy(m::Op().Is(constant))))));
|
||||
}
|
||||
|
||||
TEST_F(LayoutAssignmentTest, ElementwiseAndReshape) {
|
||||
@ -946,9 +948,11 @@ TEST_F(LayoutAssignmentTest, CopySliceOperandToAvoidImplicitLayoutChange) {
|
||||
HloInstruction* root =
|
||||
compiled_module->entry_computation()->root_instruction();
|
||||
Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {1, 0});
|
||||
EXPECT_THAT(root, op::Add(op::Parameter(),
|
||||
op::Slice(AllOf(op::Copy(op::Parameter(1)),
|
||||
op::ShapeWithLayout(shape_copy)))));
|
||||
EXPECT_THAT(
|
||||
root,
|
||||
GmockMatch(m::Add(
|
||||
m::Parameter(),
|
||||
m::Slice(m::Copy(m::Parameter(1)).WithShapeEqualTo(&shape_copy)))));
|
||||
}
|
||||
|
||||
TEST_F(LayoutAssignmentTest, CopyDSliceOperandToAvoidImplicitLayoutChange) {
|
||||
@ -976,10 +980,11 @@ TEST_F(LayoutAssignmentTest, CopyDSliceOperandToAvoidImplicitLayoutChange) {
|
||||
compiled_module->entry_computation()->root_instruction();
|
||||
Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {1, 0});
|
||||
EXPECT_THAT(root,
|
||||
op::Add(op::Parameter(),
|
||||
op::DynamicSlice(AllOf(op::Copy(op::Parameter(1)),
|
||||
op::ShapeWithLayout(shape_copy)),
|
||||
op::Parameter(2))));
|
||||
GmockMatch(m::Add(
|
||||
m::Parameter(),
|
||||
m::DynamicSlice(
|
||||
m::Copy(m::Parameter(1)).WithShapeEqualTo(&shape_copy),
|
||||
m::Parameter(2)))));
|
||||
}
|
||||
|
||||
TEST_F(LayoutAssignmentTest, CopyConcatOperandToAvoidImplicitLayoutChange) {
|
||||
@ -1007,11 +1012,12 @@ TEST_F(LayoutAssignmentTest, CopyConcatOperandToAvoidImplicitLayoutChange) {
|
||||
HloInstruction* root =
|
||||
compiled_module->entry_computation()->root_instruction();
|
||||
Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {3, 5}, {1, 0});
|
||||
EXPECT_THAT(root,
|
||||
op::Add(op::Parameter(),
|
||||
op::Concatenate(AllOf(op::Copy(op::Parameter(1)),
|
||||
op::ShapeWithLayout(shape_copy)),
|
||||
op::Parameter(2))));
|
||||
EXPECT_THAT(
|
||||
root,
|
||||
GmockMatch(m::Add(
|
||||
m::Parameter(),
|
||||
m::Concatenate(m::Copy(m::Parameter(1)).WithShapeEqualTo(&shape_copy),
|
||||
m::Parameter(2)))));
|
||||
}
|
||||
|
||||
TEST_F(LayoutAssignmentTest,
|
||||
@ -1038,7 +1044,8 @@ TEST_F(LayoutAssignmentTest,
|
||||
.ConsumeValueOrDie();
|
||||
HloInstruction* root =
|
||||
compiled_module->entry_computation()->root_instruction();
|
||||
EXPECT_THAT(root, op::Convolution(op::Parameter(0), op::Parameter(1)));
|
||||
EXPECT_THAT(root,
|
||||
GmockMatch(m::Convolution(m::Parameter(0), m::Parameter(1))));
|
||||
}
|
||||
|
||||
TEST_F(LayoutAssignmentTest, PropagatingLayoutFromResultToOperand) {
|
||||
@ -1062,8 +1069,9 @@ TEST_F(LayoutAssignmentTest, PropagatingLayoutFromResultToOperand) {
|
||||
HloInstruction* root =
|
||||
compiled_module->entry_computation()->root_instruction();
|
||||
Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {0, 1});
|
||||
EXPECT_THAT(root, op::Slice(AllOf(op::Copy(op::Parameter(0)),
|
||||
op::ShapeWithLayout(shape_copy))));
|
||||
EXPECT_THAT(root,
|
||||
GmockMatch(m::Slice(
|
||||
m::Copy(m::Parameter(0)).WithShapeEqualTo(&shape_copy))));
|
||||
}
|
||||
|
||||
TEST_F(LayoutAssignmentTest, TupleCopyOnLayoutMismatch) {
|
||||
@ -1149,7 +1157,7 @@ ENTRY %CustomCallWithNotLayoutConstrained (p: f32[42,2,3]) -> f32[1,2,3,4] {
|
||||
AssignLayouts(m.get(), &computation_layout);
|
||||
|
||||
HloInstruction* root = m->entry_computation()->root_instruction();
|
||||
ASSERT_THAT(root, op::CustomCall(op::Parameter()));
|
||||
ASSERT_THAT(root, GmockMatch(m::CustomCall(m::Parameter())));
|
||||
ExpectLayoutIs(root->shape(), {3, 2, 0, 1});
|
||||
ExpectLayoutIs(root->operand(0)->shape(), {0, 2, 1});
|
||||
}
|
||||
@ -1165,7 +1173,7 @@ ENTRY %CustomCallWithNotLayoutConstrained (p: f32[42,2,3]) -> f32[1,2,3,4] {
|
||||
AssignLayouts(m.get(), &computation_layout);
|
||||
|
||||
HloInstruction* root = m->entry_computation()->root_instruction();
|
||||
ASSERT_THAT(root, op::CustomCall(op::Parameter()));
|
||||
ASSERT_THAT(root, GmockMatch(m::CustomCall(m::Parameter())));
|
||||
ExpectLayoutIs(root->shape(), {0, 2, 3, 1});
|
||||
ExpectLayoutIs(root->operand(0)->shape(), {0, 1, 2});
|
||||
}
|
||||
@ -1196,7 +1204,7 @@ ENTRY %CustomCallWithLayoutConstraints (p0: f32[4,4], p1: f32[2,3]) -> f32[1,2,3
|
||||
// The custom call should be partially encapsulated in kCopy instructions
|
||||
// because of the layout mismatches.
|
||||
ASSERT_THAT(m->entry_computation()->root_instruction(),
|
||||
op::Copy(op::CustomCall(op::Copy(), op::Parameter())));
|
||||
GmockMatch(m::Copy(m::CustomCall(m::Copy(), m::Parameter()))));
|
||||
|
||||
const HloInstruction* custom_call =
|
||||
m->entry_computation()->root_instruction()->operand(0);
|
||||
@ -1222,7 +1230,7 @@ ENTRY %CustomCallLayoutConstrainedZeroOperands () -> f32[1,2,3,4] {
|
||||
AssignLayouts(m.get(), &computation_layout);
|
||||
|
||||
ASSERT_THAT(m->entry_computation()->root_instruction(),
|
||||
op::Copy(op::CustomCall()));
|
||||
GmockMatch(m::Copy(m::CustomCall())));
|
||||
|
||||
const HloInstruction* custom_call =
|
||||
m->entry_computation()->root_instruction()->operand(0);
|
||||
@ -1256,7 +1264,7 @@ ENTRY %CustomCallLayoutConstrainedTupleOperand (p0: f32[4,4], p1: f32[2,3]) -> f
|
||||
ExpectLayoutIs(root->shape(), {2, 1, 0, 3});
|
||||
|
||||
ASSERT_THAT(m->entry_computation()->root_instruction(),
|
||||
op::Copy(op::CustomCall(op::Tuple())));
|
||||
GmockMatch(m::Copy(m::CustomCall(m::Tuple()))));
|
||||
|
||||
const HloInstruction* custom_call =
|
||||
m->entry_computation()->root_instruction()->operand(0);
|
||||
|
File diff suppressed because it is too large
Load Diff
92
tensorflow/compiler/xla/service/pattern_matcher_gmock.h
Normal file
92
tensorflow/compiler/xla/service/pattern_matcher_gmock.h
Normal file
@ -0,0 +1,92 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_GMOCK_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_GMOCK_H_
|
||||
|
||||
#include <ostream>
|
||||
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
namespace pattern_matcher_gmock_detail {
|
||||
template <typename Pattern>
|
||||
class GmockMatcher {
|
||||
public:
|
||||
explicit GmockMatcher(Pattern p) : pattern_(std::move(p)) {}
|
||||
|
||||
// In service of better error messages, list out the overloads explicitly
|
||||
// rather than just using a template. gMock's polymorphism plus
|
||||
// pattern_matcher yields some pretty gnarly stuff.
|
||||
bool MatchAndExplain(const Layout& l,
|
||||
::testing::MatchResultListener* listener) const {
|
||||
return MatchAndExplainImpl(&l, listener);
|
||||
}
|
||||
bool MatchAndExplain(const Layout* l,
|
||||
::testing::MatchResultListener* listener) const {
|
||||
return MatchAndExplainImpl(l, listener);
|
||||
}
|
||||
|
||||
bool MatchAndExplain(const Shape& s,
|
||||
::testing::MatchResultListener* listener) const {
|
||||
return MatchAndExplainImpl(&s, listener);
|
||||
}
|
||||
bool MatchAndExplain(const Shape* s,
|
||||
::testing::MatchResultListener* listener) const {
|
||||
return MatchAndExplainImpl(s, listener);
|
||||
}
|
||||
|
||||
bool MatchAndExplain(const HloInstruction& instr,
|
||||
::testing::MatchResultListener* listener) const {
|
||||
return MatchAndExplainImpl(&instr, listener);
|
||||
}
|
||||
bool MatchAndExplain(const HloInstruction* instr,
|
||||
::testing::MatchResultListener* listener) const {
|
||||
return MatchAndExplainImpl(instr, listener);
|
||||
}
|
||||
|
||||
void DescribeTo(std::ostream* os) const { pattern_.DescribeTo(os); }
|
||||
|
||||
void DescribeNegationTo(std::ostream* os) const {
|
||||
*os << "is NOT: ";
|
||||
DescribeTo(os);
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
bool MatchAndExplainImpl(const T* t,
|
||||
::testing::MatchResultListener* listener) const {
|
||||
MatchOption options{/*.capture=*/true, /*.explain_os=*/listener->stream()};
|
||||
return Match(t, pattern_, options);
|
||||
}
|
||||
|
||||
Pattern pattern_;
|
||||
};
|
||||
} // namespace pattern_matcher_gmock_detail
|
||||
|
||||
template <typename Pattern>
|
||||
::testing::PolymorphicMatcher<
|
||||
pattern_matcher_gmock_detail::GmockMatcher<Pattern>>
|
||||
GmockMatch(Pattern&& p) {
|
||||
return ::testing::MakePolymorphicMatcher(
|
||||
pattern_matcher_gmock_detail::GmockMatcher<Pattern>(
|
||||
std::forward<Pattern>(p)));
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_GMOCK_H_
|
@ -0,0 +1,76 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
|
||||
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
namespace m = ::xla::match;
|
||||
using ::testing::Eq;
|
||||
using ::testing::Not;
|
||||
|
||||
template <typename MatchedTy>
|
||||
string Describe(const ::testing::Matcher<MatchedTy>& m) {
|
||||
std::stringstream ss;
|
||||
m.DescribeTo(&ss);
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
template <typename MatchedTy>
|
||||
string Explain(
|
||||
const MatchedTy& val,
|
||||
const ::testing::Matcher<typename std::remove_cv<MatchedTy>::type>& m) {
|
||||
::testing::StringMatchResultListener listener;
|
||||
EXPECT_THAT(val, ::testing::Not(m)); // For the error message.
|
||||
EXPECT_FALSE(m.MatchAndExplain(val, &listener));
|
||||
return listener.str();
|
||||
}
|
||||
|
||||
// This file tests the GmockMatch function. The actual explanation and
|
||||
// description returned by matchers is tested in pattern_matchers_test.
|
||||
TEST(PatternMatcherGmock, MatchShape) {
|
||||
Shape s = ShapeUtil::MakeShape(F32, {10, 100});
|
||||
// You can pass const Shape& or a const Shape*.
|
||||
EXPECT_THAT(s, GmockMatch(m::Shape()));
|
||||
EXPECT_THAT(&s, Not(GmockMatch(m::Shape().WithElementType(F16))));
|
||||
EXPECT_THAT(Describe<Shape>(GmockMatch(m::Shape().IsArray())),
|
||||
"a shape that represents an array");
|
||||
}
|
||||
|
||||
TEST(PatternMatcherGmock, MatchLayout) {
|
||||
Layout l = LayoutUtil::MakeLayout({0, 1});
|
||||
EXPECT_THAT(l, GmockMatch(m::Layout()));
|
||||
EXPECT_THAT(&l, Not(GmockMatch(m::Layout().WithSparseFormat())));
|
||||
EXPECT_THAT(Describe<Layout>(GmockMatch(m::Layout().WithSparseFormat())),
|
||||
"a layout with format SPARSE");
|
||||
}
|
||||
|
||||
TEST(PatternMatchGmock, MatchInstruction) {
|
||||
auto instr =
|
||||
HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {42}), "p");
|
||||
EXPECT_THAT(instr.get(), GmockMatch(m::Parameter()));
|
||||
EXPECT_THAT(*instr, GmockMatch(m::Parameter(0)));
|
||||
EXPECT_THAT(*instr, Not(GmockMatch(m::Parameter(1))));
|
||||
EXPECT_THAT(Describe<HloInstruction*>(GmockMatch(m::Parameter())),
|
||||
"an HloInstruction with opcode parameter");
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
} // namespace xla
|
@ -14,14 +14,18 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_parser.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
namespace m = match;
|
||||
|
||||
TEST(PatternMatcherTest, AddOp) {
|
||||
constexpr char kModuleStr[] = R"(HloModule two_plus_two_module
|
||||
ENTRY %two_plus_two_computation () -> f32[] {
|
||||
@ -229,23 +233,74 @@ TEST(PatternMatcherTest, AnyOf) {
|
||||
}
|
||||
|
||||
TEST(PatternMatcherTest, ConstantScalar) {
|
||||
using match::ConstantEffectiveScalar;
|
||||
using match::ConstantScalar;
|
||||
using match::Op;
|
||||
using match::Tuple;
|
||||
|
||||
constexpr char kModuleStr[] = R"(
|
||||
HloModule test_module ENTRY test { ROOT constant = f16[] constant(42) })";
|
||||
HloModule test_module
|
||||
ENTRY test {
|
||||
a = s32[] constant(1)
|
||||
b = s32[1,1] constant(s32[1,1]{{2}})
|
||||
c = s32[1,2] constant(s32[1,2]{{2,2}})
|
||||
d = f32[] constant(1)
|
||||
e = f32[] constant(1.25)
|
||||
ROOT tuple = (s32[], s32[1,1], s32[1,2], f32[], f32[]) tuple(a,b,c,d,e)
|
||||
})";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr));
|
||||
auto* root = hlo_module->entry_computation()->root_instruction();
|
||||
|
||||
EXPECT_TRUE(Match(root, match::ConstantScalar(42)));
|
||||
EXPECT_FALSE(Match(root, match::ConstantScalar(41)));
|
||||
EXPECT_FALSE(Match(root, match::ConstantScalar(0)));
|
||||
}
|
||||
const HloInstruction* a = root->operand(0);
|
||||
const HloInstruction* b = root->operand(1);
|
||||
const HloInstruction* c = root->operand(2);
|
||||
const HloInstruction* d = root->operand(3);
|
||||
const HloInstruction* e = root->operand(4);
|
||||
EXPECT_TRUE(Match(a, ConstantScalar()));
|
||||
EXPECT_TRUE(Match(a, ConstantScalar(1)));
|
||||
EXPECT_TRUE(Match(a, ConstantEffectiveScalar()));
|
||||
EXPECT_TRUE(Match(a, ConstantEffectiveScalar(1)));
|
||||
EXPECT_FALSE(Match(a, ConstantScalar(2)));
|
||||
EXPECT_FALSE(Match(a, ConstantScalar(2.01)));
|
||||
EXPECT_FALSE(Match(a, ConstantEffectiveScalar(2)));
|
||||
EXPECT_FALSE(Match(a, ConstantEffectiveScalar(1.01)));
|
||||
|
||||
TEST(PatternMatcherTest, NoMatchConstantScalar) {
|
||||
constexpr char kModuleStr[] = R"(
|
||||
HloModule test_module ENTRY test { ROOT v = f16[] parameter(0) })";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr));
|
||||
auto* root = hlo_module->entry_computation()->root_instruction();
|
||||
EXPECT_FALSE(Match(b, ConstantScalar()));
|
||||
EXPECT_FALSE(Match(b, ConstantScalar(2)));
|
||||
EXPECT_TRUE(Match(b, ConstantEffectiveScalar()));
|
||||
EXPECT_TRUE(Match(b, ConstantEffectiveScalar(2)));
|
||||
|
||||
EXPECT_FALSE(Match(root, match::ConstantScalar(42)));
|
||||
EXPECT_FALSE(Match(c, ConstantScalar()));
|
||||
EXPECT_FALSE(Match(c, ConstantScalar(2)));
|
||||
EXPECT_FALSE(Match(c, ConstantEffectiveScalar()));
|
||||
EXPECT_FALSE(Match(c, ConstantEffectiveScalar(2)));
|
||||
|
||||
EXPECT_TRUE(Match(d, ConstantScalar(1)));
|
||||
EXPECT_TRUE(Match(d, ConstantEffectiveScalar(1)));
|
||||
EXPECT_TRUE(Match(d, ConstantScalar(1.0)));
|
||||
EXPECT_TRUE(Match(d, ConstantEffectiveScalar(1.0)));
|
||||
|
||||
EXPECT_TRUE(Match(e, ConstantScalar(1.25f)));
|
||||
EXPECT_TRUE(Match(e, ConstantScalar(1.25)));
|
||||
EXPECT_TRUE(Match(e, ConstantEffectiveScalar(1.25)));
|
||||
EXPECT_FALSE(Match(e, ConstantScalar(1)));
|
||||
EXPECT_FALSE(Match(e, ConstantEffectiveScalar(1)));
|
||||
|
||||
const HloInstruction* instr = nullptr;
|
||||
EXPECT_TRUE(Match(a, ConstantScalar(&instr)));
|
||||
EXPECT_EQ(instr, a);
|
||||
|
||||
instr = nullptr;
|
||||
EXPECT_TRUE(Match(a, ConstantScalar(&instr, 1)));
|
||||
EXPECT_EQ(instr, a);
|
||||
|
||||
instr = nullptr;
|
||||
EXPECT_TRUE(Match(a, ConstantEffectiveScalar(&instr)));
|
||||
EXPECT_EQ(instr, a);
|
||||
|
||||
instr = nullptr;
|
||||
EXPECT_TRUE(Match(a, ConstantEffectiveScalar(&instr, 1)));
|
||||
EXPECT_EQ(instr, a);
|
||||
}
|
||||
|
||||
TEST(PatternMatcherTest, MultiplyAnyOrder) {
|
||||
@ -267,6 +322,15 @@ TEST(PatternMatcherTest, MultiplyAnyOrder) {
|
||||
root, MultiplyAnyOrder(&instr, ConstantScalar(42), ConstantScalar(52))));
|
||||
EXPECT_TRUE(Match(
|
||||
root, MultiplyAnyOrder(&instr, ConstantScalar(52), ConstantScalar(42))));
|
||||
|
||||
// Check that MultiplyAnyOrder exposes the same API as Op(), so we can call
|
||||
// e.g. IsNonConstant() on it.
|
||||
EXPECT_TRUE(Match(
|
||||
root, MultiplyAnyOrder(&instr, ConstantScalar(42), ConstantScalar(52))
|
||||
.IsNonConstant()));
|
||||
EXPECT_TRUE(
|
||||
Match(root, MultiplyAnyOrder(ConstantScalar(42), ConstantScalar(52))
|
||||
.IsNonConstant()));
|
||||
}
|
||||
|
||||
TEST(PatternMatcherTest, AnyOfShortCircuit) {
|
||||
@ -315,14 +379,22 @@ TEST(PatternMatcherTest, AllOf) {
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr));
|
||||
auto* root = hlo_module->entry_computation()->root_instruction();
|
||||
|
||||
auto f16_scalar = ShapeUtil::MakeShape(F16, {});
|
||||
auto f16_pattern = Constant().WithShapeEqualTo(&f16_scalar);
|
||||
auto f16_compatible_pattern = Constant().WithShapeCompatibleTo(&f16_scalar);
|
||||
auto scalar_pattern = Constant().WithShape(match::Shape().IsScalar());
|
||||
auto f16_pattern = Constant().WithShape(match::Shape().WithElementType(F16));
|
||||
ASSERT_TRUE(Match(root, scalar_pattern));
|
||||
ASSERT_TRUE(Match(root, f16_pattern));
|
||||
EXPECT_TRUE(Match(root, AllOf<HloInstruction>(scalar_pattern, f16_pattern)));
|
||||
EXPECT_TRUE(Match(root, AllOf<HloInstruction>(f16_pattern, scalar_pattern)));
|
||||
ASSERT_TRUE(Match(root, f16_compatible_pattern));
|
||||
EXPECT_TRUE(Match(root, AllOf<HloInstruction>(scalar_pattern, f16_pattern,
|
||||
f16_compatible_pattern)));
|
||||
EXPECT_TRUE(
|
||||
Match(root, AllOf<HloInstruction>(f16_pattern, f16_compatible_pattern,
|
||||
scalar_pattern)));
|
||||
EXPECT_FALSE(
|
||||
Match(root, AllOf<HloInstruction>(Broadcast(Op()), f16_pattern)));
|
||||
EXPECT_FALSE(Match(
|
||||
root, AllOf<HloInstruction>(Broadcast(Op()), f16_compatible_pattern)));
|
||||
EXPECT_FALSE(
|
||||
Match(root, AllOf<HloInstruction>(Broadcast(Op()), scalar_pattern)));
|
||||
}
|
||||
@ -431,5 +503,377 @@ TEST(PatternMatcherTest, TestConcat) {
|
||||
Reshape(ConstantScalar(4)))));
|
||||
}
|
||||
|
||||
template <typename Pattern>
|
||||
string Description(const Pattern& pattern) {
|
||||
std::stringstream ss;
|
||||
pattern.DescribeTo(&ss);
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
template <typename Elem, typename Pattern>
|
||||
string Explanation(Elem* elem, const Pattern& pattern) {
|
||||
std::stringstream ss;
|
||||
MatchOption options{/*.capture=*/true, /*.explain_os=*/&ss};
|
||||
Match(elem, pattern, options);
|
||||
return ss.str();
|
||||
}
|
||||
template <typename Elem, typename Pattern>
|
||||
string Explanation(const std::unique_ptr<Elem>& elem, const Pattern& pattern) {
|
||||
return Explanation(elem.get(), pattern);
|
||||
}
|
||||
template <typename Elem, typename Pattern>
|
||||
string Explanation(const Elem& elem, const Pattern& pattern) {
|
||||
return Explanation(&elem, pattern);
|
||||
}
|
||||
|
||||
// Helper macro for checking a pattern's description and the explanation printed
|
||||
// when attempting to match (and presumably failing) on a given object.
|
||||
//
|
||||
// We use a macro rather than a function because we want good line numbers in
|
||||
// errors. We use this rather than writing a helper that returns a pair of
|
||||
// (description, explanation) and doing something like
|
||||
//
|
||||
// EXPECT_THAT(DescAndExplanation(...), ::testing::Pair(..., ...));
|
||||
//
|
||||
// because EXPECT_EQ prints a unified diff if multiline string comparison fails,
|
||||
// while EXPECT_THAT does not. This unified diff makes the errors much easier
|
||||
// to read.
|
||||
#define EXPECT_DESC_AND_EXPLANATION(elem, pattern, expected_desc, \
|
||||
expected_explanation) \
|
||||
do { \
|
||||
EXPECT_EQ(Description(pattern), (expected_desc)); \
|
||||
EXPECT_EQ(Explanation((elem), (pattern)), expected_explanation); \
|
||||
} while (0)
|
||||
|
||||
TEST(PatternMatcherTest, LayoutDescribeToAndExplain) {
|
||||
auto layout = LayoutUtil::MakeLayout({1, 2});
|
||||
auto layout2 = LayoutUtil::MakeLayout({2, 2});
|
||||
|
||||
EXPECT_DESC_AND_EXPLANATION(static_cast<const Layout*>(nullptr), m::Layout(),
|
||||
"a layout", "Layout is null");
|
||||
EXPECT_DESC_AND_EXPLANATION(layout2, m::Layout().EqualTo(&layout),
|
||||
"a layout equal to {1,2}",
|
||||
"Layout {2,2} is not equal to expected {1,2}");
|
||||
EXPECT_DESC_AND_EXPLANATION(layout2, m::Layout().WithSparseFormat(),
|
||||
"a layout with format SPARSE",
|
||||
"Layout has format DENSE but expected SPARSE");
|
||||
EXPECT_DESC_AND_EXPLANATION(layout,
|
||||
m::Layout().EqualTo(&layout).WithSparseFormat(),
|
||||
"a layout:\n"
|
||||
" * equal to {1,2} AND\n"
|
||||
" * with format SPARSE",
|
||||
"Layout has format DENSE but expected SPARSE");
|
||||
}
|
||||
|
||||
TEST(PatternMatcherTest, ShapeDescribeToAndExplain) {
|
||||
auto shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 2}, {0, 1});
|
||||
auto layout = shape.layout();
|
||||
|
||||
EXPECT_DESC_AND_EXPLANATION(static_cast<const Shape*>(nullptr), m::Shape(),
|
||||
"a shape", "Shape is null");
|
||||
EXPECT_DESC_AND_EXPLANATION(
|
||||
ShapeUtil::MakeShapeWithLayout(F32, {1, 2}, {1, 0}),
|
||||
m::Shape().EqualTo(&shape), "a shape equal to f32[1,2]{0,1}",
|
||||
"Shape not equal to f32[1,2]{0,1}\n"
|
||||
"in f32[1,2]{1,0}");
|
||||
EXPECT_DESC_AND_EXPLANATION(ShapeUtil::MakeShape(F32, {2, 2}),
|
||||
m::Shape().CompatibleTo(&shape),
|
||||
"a shape compatible with f32[1,2]",
|
||||
"Shape not compatible with f32[1,2]\n"
|
||||
"in f32[2,2]{1,0}");
|
||||
EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().WithElementType(F16),
|
||||
"a shape with element type F16",
|
||||
"Shape does not have element type F16\n"
|
||||
"in f32[1,2]{0,1}");
|
||||
EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().IsScalar(),
|
||||
"a shape that represents a scalar",
|
||||
"Shape is not a scalar\n"
|
||||
"in f32[1,2]{0,1}");
|
||||
EXPECT_DESC_AND_EXPLANATION(ShapeUtil::MakeNil(), m::Shape().IsArray(),
|
||||
"a shape that represents an array",
|
||||
"Shape is not an array\n"
|
||||
"in ()");
|
||||
EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().IsTuple(),
|
||||
"a shape that represents a tuple",
|
||||
"Shape is not a tuple\n"
|
||||
"in f32[1,2]{0,1}");
|
||||
EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().IsEffectiveScalar(),
|
||||
"a shape that is an effective scalar",
|
||||
"Shape is not an effective scalar\n"
|
||||
"in f32[1,2]{0,1}");
|
||||
EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().WithRank(42),
|
||||
"a shape that has 42 dimensions",
|
||||
"Shape does not have rank 42\n"
|
||||
"in f32[1,2]{0,1}");
|
||||
EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().WithRank(0),
|
||||
"a shape that is a scalar",
|
||||
"Shape is not a scalar\n"
|
||||
"in f32[1,2]{0,1}");
|
||||
EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().WithRank(1).IsArray(),
|
||||
"a shape:\n"
|
||||
" * that has 1 dimension AND\n"
|
||||
" * that represents an array",
|
||||
"Shape does not have rank 1\n"
|
||||
"in f32[1,2]{0,1}");
|
||||
EXPECT_DESC_AND_EXPLANATION(ShapeUtil::MakeNil(),
|
||||
m::Shape().IsArray().WithRank(1),
|
||||
"a shape:\n"
|
||||
" * that represents an array AND\n"
|
||||
" * that has 1 dimension",
|
||||
"Shape is not an array\n"
|
||||
"in ()");
|
||||
EXPECT_DESC_AND_EXPLANATION(
|
||||
ShapeUtil::MakeShapeWithLayout(F32, {1, 2}, {1, 0}),
|
||||
m::Shape().WithLayoutEqualTo(&layout),
|
||||
"a shape with\n a layout equal to {0,1}",
|
||||
"Layout {1,0} is not equal to expected {0,1}\n"
|
||||
"in f32[1,2]{1,0}");
|
||||
EXPECT_DESC_AND_EXPLANATION(
|
||||
shape, m::Shape().WithLayout(m::Layout().WithSparseFormat()),
|
||||
"a shape with\n a layout with format SPARSE",
|
||||
"Layout has format DENSE but expected SPARSE\n"
|
||||
"in f32[1,2]{0,1}");
|
||||
EXPECT_DESC_AND_EXPLANATION(shape,
|
||||
m::Shape().WithSubshapeEqualTo({10}, &shape),
|
||||
"a shape with subshape at index {10} which is\n"
|
||||
" a shape equal to f32[1,2]{0,1}",
|
||||
"No subshape at {10}\n"
|
||||
"in f32[1,2]{0,1}");
|
||||
EXPECT_DESC_AND_EXPLANATION(
|
||||
ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {2, 2})}),
|
||||
m::Shape().WithSubshapeEqualTo({0}, &shape),
|
||||
"a shape with subshape at index {0} which is\n"
|
||||
" a shape equal to f32[1,2]{0,1}",
|
||||
"Shape not equal to f32[1,2]{0,1}\n"
|
||||
"in f32[2,2]{1,0}\n"
|
||||
"in subshape at {0}\n"
|
||||
"in (f32[2,2])");
|
||||
EXPECT_DESC_AND_EXPLANATION(shape,
|
||||
m::Shape().WithSubshapeCompatibleTo({10}, &shape),
|
||||
"a shape with subshape at index {10} which is\n"
|
||||
" a shape compatible with f32[1,2]",
|
||||
"No subshape at {10}\n"
|
||||
"in f32[1,2]{0,1}");
|
||||
EXPECT_DESC_AND_EXPLANATION(
|
||||
ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {2, 2})}),
|
||||
m::Shape().WithSubshapeCompatibleTo({0}, &shape),
|
||||
"a shape with subshape at index {0} which is\n"
|
||||
" a shape compatible with f32[1,2]",
|
||||
"Shape not compatible with f32[1,2]\n"
|
||||
"in f32[2,2]{1,0}\n"
|
||||
"in subshape at {0}\n"
|
||||
"in (f32[2,2])");
|
||||
EXPECT_DESC_AND_EXPLANATION(
|
||||
ShapeUtil::MakeTupleShape({ShapeUtil::MakeTupleShape({shape})}),
|
||||
m::Shape().WithSubshape({0, 0}, m::Shape().IsScalar()),
|
||||
"a shape with subshape at index {0,0} which is\n"
|
||||
" a shape that represents a scalar",
|
||||
"Shape is not a scalar\n"
|
||||
"in f32[1,2]{0,1}\n"
|
||||
"in subshape at {0,0}\n"
|
||||
"in ((f32[1,2]))");
|
||||
}
|
||||
|
||||
std::unique_ptr<HloInstruction> SetName(absl::string_view name,
|
||||
std::unique_ptr<HloInstruction> instr) {
|
||||
instr->SetAndSanitizeName(string(name));
|
||||
return instr;
|
||||
}
|
||||
|
||||
TEST(PatternMatcherTest, HloInstructionDescribeToAndExplain) {
|
||||
std::unique_ptr<HloInstruction> iota =
|
||||
SetName("i", HloInstruction::CreateIota(ShapeUtil::MakeShape(S32, {42}),
|
||||
/*iota_dimension=*/0));
|
||||
std::unique_ptr<HloInstruction> constant =
|
||||
SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(0)));
|
||||
|
||||
EXPECT_DESC_AND_EXPLANATION(static_cast<const HloInstruction*>(nullptr),
|
||||
m::Op(), "an HloInstruction",
|
||||
"HloInstruction* is null");
|
||||
EXPECT_DESC_AND_EXPLANATION(iota, m::Op().WithName("foo"),
|
||||
"an HloInstruction named \"foo\"",
|
||||
"HloInstruction not named \"foo\"\n"
|
||||
"in i = s32[42]{0} iota(), iota_dimension=0");
|
||||
EXPECT_DESC_AND_EXPLANATION(iota, m::Op().WithOpcode(HloOpcode::kAdd),
|
||||
"an HloInstruction with opcode add",
|
||||
"HloInstruction doesn't have opcode add\n"
|
||||
"in i = s32[42]{0} iota(), iota_dimension=0");
|
||||
EXPECT_DESC_AND_EXPLANATION(
|
||||
constant, m::Op().IsNonConstant(),
|
||||
"an HloInstruction with any opcode other than constant",
|
||||
"HloInstruction has opcode constant, expected anything else\n"
|
||||
"in c = s32[] constant(0)");
|
||||
EXPECT_DESC_AND_EXPLANATION(iota, m::Op().WithNumOperands(42),
|
||||
"an HloInstruction with 42 operands",
|
||||
"HloInstruction doesn't have 42 operands\n"
|
||||
"in i = s32[42]{0} iota(), iota_dimension=0");
|
||||
EXPECT_DESC_AND_EXPLANATION(iota, m::Op().WithShape(m::Shape().IsTuple()),
|
||||
"an HloInstruction outputting\n"
|
||||
" a shape that represents a tuple",
|
||||
"Shape is not a tuple\n"
|
||||
"in s32[42]{0}\n"
|
||||
"in output shape\n"
|
||||
"in i = s32[42]{0} iota(), iota_dimension=0");
|
||||
EXPECT_DESC_AND_EXPLANATION(
|
||||
iota, m::Op().WithOperand(2, m::Op().WithOpcode(HloOpcode::kAdd)),
|
||||
"an HloInstruction with operand 2 which is:\n"
|
||||
" an HloInstruction with opcode add",
|
||||
"desired operand index 2 is out of bounds\n"
|
||||
"in i = s32[42]{0} iota(), iota_dimension=0");
|
||||
|
||||
EXPECT_DESC_AND_EXPLANATION(
|
||||
SetName("a", HloInstruction::CreateBinary(ShapeUtil::MakeShape(S32, {}),
|
||||
HloOpcode::kAdd, constant.get(),
|
||||
constant.get())),
|
||||
m::Op().WithOperand(1, m::Op().IsNonConstant()),
|
||||
"an HloInstruction with operand 1 which is:\n"
|
||||
" an HloInstruction with any opcode other than constant",
|
||||
"HloInstruction has opcode constant, expected anything else\n"
|
||||
"in c = s32[] constant(0)\n"
|
||||
"in operand 1\n"
|
||||
"in a = s32[] add(s32[] c, s32[] c)");
|
||||
EXPECT_DESC_AND_EXPLANATION(
|
||||
iota, m::Op().WithFusionKind(HloInstruction::FusionKind::kLoop),
|
||||
"an HloInstruction with fusion kind kLoop",
|
||||
"HloInstruction does not have fusion kind kLoop; it's not a fusion\n"
|
||||
"in i = s32[42]{0} iota(), iota_dimension=0");
|
||||
EXPECT_DESC_AND_EXPLANATION(
|
||||
iota, m::Op().WithTupleIndex(42),
|
||||
"an HloInstruction which is a GTE with index 42",
|
||||
"HloInstruction is not a GTE with index 42; it's not a GTE at all\n"
|
||||
"in i = s32[42]{0} iota(), iota_dimension=0");
|
||||
EXPECT_DESC_AND_EXPLANATION(iota, m::Op().IsConstantScalar(),
|
||||
"an HloInstruction which is a constant scalar",
|
||||
"HloInstruction is not a constant\n"
|
||||
"in i = s32[42]{0} iota(), iota_dimension=0");
|
||||
EXPECT_DESC_AND_EXPLANATION(
|
||||
SetName("c", HloInstruction::CreateConstant(
|
||||
LiteralUtil::CreateR1<int>({1, 2}))),
|
||||
m::Op().IsConstantEffectiveScalar(),
|
||||
"an HloInstruction which is a constant effective scalar",
|
||||
"HloInstruction is not an effective scalar\n"
|
||||
"in c = s32[2]{0} constant({1, 2})");
|
||||
EXPECT_DESC_AND_EXPLANATION(
|
||||
SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(10))),
|
||||
m::Op().IsConstantScalar(42),
|
||||
"an HloInstruction which is a constant scalar with value 42",
|
||||
"HloInstruction's constant value 10 did not match expected value 42\n"
|
||||
"in c = s32[] constant(10)");
|
||||
EXPECT_DESC_AND_EXPLANATION(
|
||||
SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.25))),
|
||||
m::Op().IsConstantEffectiveScalar(1.25),
|
||||
"an HloInstruction which is a constant effective scalar with value 1.25",
|
||||
"HloInstruction's constant value 2.25 did not match expected value 1.25\n"
|
||||
"in c = f64[] constant(2.25)");
|
||||
EXPECT_DESC_AND_EXPLANATION(
|
||||
constant, m::Op().Is(iota.get()),
|
||||
absl::StrCat("an HloInstruction which is 0x", absl::Hex(iota.get()), " (",
|
||||
iota->ToShortString(), ")"),
|
||||
absl::StrCat("HloInstruction 0x", absl::Hex(constant.get()), " is not 0x",
|
||||
absl::Hex(iota.get()), " (", iota->ToShortString(), ")\n",
|
||||
"in c = s32[] constant(0)"));
|
||||
}
|
||||
|
||||
TEST(PatternMatcherTest, HloInstructionMatcherAnyOrderDescribeTo) {
|
||||
auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
|
||||
EXPECT_DESC_AND_EXPLANATION(
|
||||
SetName("a", HloInstruction::CreateBinary(
|
||||
scalar_s32, HloOpcode::kAdd,
|
||||
SetName("b", HloInstruction::CreateConstant(
|
||||
LiteralUtil::CreateR0(0)))
|
||||
.get(),
|
||||
SetName("c", HloInstruction::CreateConstant(
|
||||
LiteralUtil::CreateR0(0)))
|
||||
.get())),
|
||||
m::AddAnyOrder(m::Op().WithName("b"), m::Op().WithName("bar")),
|
||||
"an HloInstruction:\n"
|
||||
" * with opcode add AND\n"
|
||||
" * with two operands in either order:\n"
|
||||
" - an HloInstruction named \"b\"\n"
|
||||
" - an HloInstruction named \"bar\"",
|
||||
"HloInstruction's operands (ignoring order) did not match second "
|
||||
"matcher. Specifically,\n"
|
||||
" - an HloInstruction named \"bar\"\n"
|
||||
"does not match LHS:\n"
|
||||
" - HloInstruction not named \"bar\"\n"
|
||||
" in b = s32[] constant(0)\n"
|
||||
"does not match RHS:\n"
|
||||
" - HloInstruction not named \"bar\"\n"
|
||||
" in c = s32[] constant(0)\n"
|
||||
"in a = s32[] add(s32[] b, s32[] c)");
|
||||
|
||||
EXPECT_DESC_AND_EXPLANATION(
|
||||
SetName("a",
|
||||
HloInstruction::CreateBinary(
|
||||
scalar_s32, HloOpcode::kAdd,
|
||||
HloInstruction::CreateParameter(0, scalar_s32, "p").get(),
|
||||
SetName("c", HloInstruction::CreateConstant(
|
||||
LiteralUtil::CreateR0(0)))
|
||||
.get())),
|
||||
m::AddAnyOrder(m::Op().IsConstantScalar(), m::Op().IsConstant()),
|
||||
"an HloInstruction:\n"
|
||||
" * with opcode add AND\n"
|
||||
" * with two operands in either order:\n"
|
||||
" - an HloInstruction which is a constant scalar\n"
|
||||
" - an HloInstruction with opcode constant",
|
||||
"HloInstruction's LHS operand did not match either of the two matchers. "
|
||||
"Specifically,\n"
|
||||
" - an HloInstruction which is a constant scalar\n"
|
||||
"does not match LHS:\n"
|
||||
" - HloInstruction is not a constant\n"
|
||||
" in p = s32[] parameter(0)\n"
|
||||
"and\n"
|
||||
" - an HloInstruction with opcode constant\n"
|
||||
"does not match LHS:\n"
|
||||
" - HloInstruction doesn't have opcode constant\n"
|
||||
" in p = s32[] parameter(0)\n"
|
||||
"in a = s32[] add(s32[] p, s32[] c)");
|
||||
}
|
||||
|
||||
TEST(PatternMatcherTest, AnyOfMatcherDescribeToAndExplain) {
|
||||
EXPECT_DESC_AND_EXPLANATION(
|
||||
SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))),
|
||||
m::AnyOf<HloInstruction>(m::Op().WithName("foo"),
|
||||
m::Op().WithName("bar")),
|
||||
"any of:\n"
|
||||
" - an HloInstruction named \"foo\" OR\n"
|
||||
" - an HloInstruction named \"bar\"",
|
||||
"None of the following matchers succeeded:\n"
|
||||
"Matcher #1\n"
|
||||
" - an HloInstruction named \"foo\"\n"
|
||||
"failed with\n"
|
||||
" - HloInstruction not named \"foo\"\n"
|
||||
" in c = s32[] constant(0)\n"
|
||||
"Matcher #2\n"
|
||||
" - an HloInstruction named \"bar\"\n"
|
||||
"failed with\n"
|
||||
" - HloInstruction not named \"bar\"\n"
|
||||
" in c = s32[] constant(0)");
|
||||
}
|
||||
|
||||
TEST(PatternMatcherTest, Parameter) {
|
||||
auto param =
|
||||
HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {}), "p1");
|
||||
auto non_param =
|
||||
SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(0)));
|
||||
EXPECT_FALSE(Match(param.get(), m::Parameter(0)));
|
||||
EXPECT_TRUE(Match(param.get(), m::Parameter()));
|
||||
EXPECT_TRUE(Match(param.get(), m::Parameter(1)));
|
||||
EXPECT_FALSE(Match(non_param.get(), m::Parameter()));
|
||||
EXPECT_FALSE(Match(non_param.get(), m::Parameter(1)));
|
||||
|
||||
EXPECT_DESC_AND_EXPLANATION(non_param, m::Parameter(1),
|
||||
"an HloInstruction:\n"
|
||||
" * with opcode parameter AND\n"
|
||||
" * which is parameter 1",
|
||||
"HloInstruction doesn't have opcode parameter\n"
|
||||
"in c = s32[] constant(0)");
|
||||
EXPECT_EQ(Explanation(HloInstruction::CreateParameter(
|
||||
0, ShapeUtil::MakeShape(F32, {}), "p0"),
|
||||
m::Parameter(1)),
|
||||
"HloInstruction is not parameter 1\n"
|
||||
"in p0 = f32[] parameter(0)");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
@ -526,16 +526,14 @@ static StatusOr<bool> TryPropagateConstant(HloInstruction* while_op) {
|
||||
// performance by forcing us to copy constants.
|
||||
absl::flat_hash_map<int, const HloInstruction*> index_to_constant;
|
||||
for (int i = 0; i < root_operands.size(); i++) {
|
||||
HloInstruction* instr = root_operands[i];
|
||||
if (instr->opcode() == HloOpcode::kGetTupleElement &&
|
||||
instr->tuple_index() == i && instr->operand(0) == while_body_param &&
|
||||
ShapeUtil::IsScalar(instr->shape())) {
|
||||
auto tuple_element = while_init->operand(i);
|
||||
if (tuple_element->IsConstant()) {
|
||||
VLOG(3) << "Found loop invariant tuple element " << i << " "
|
||||
<< tuple_element->ToString();
|
||||
index_to_constant[i] = tuple_element;
|
||||
}
|
||||
const HloInstruction* init_tuple_elem = nullptr;
|
||||
if (Match(root_operands[i],
|
||||
m::GetTupleElement(m::Op().Is(while_body_param), i)
|
||||
.WithShape(m::Shape().IsScalar())) &&
|
||||
Match(while_init->operand(i), m::Constant(&init_tuple_elem))) {
|
||||
VLOG(3) << "Found loop invariant tuple element " << i << " "
|
||||
<< init_tuple_elem->ToString();
|
||||
index_to_constant[i] = init_tuple_elem;
|
||||
}
|
||||
}
|
||||
|
||||
@ -793,16 +791,11 @@ static StatusOr<HloInstruction*> TryMergeInductionVariables(
|
||||
// Maps the tuple index of each induction variable to its constant increment.
|
||||
absl::flat_hash_map<int64, const HloConstantInstruction*> induction_vars;
|
||||
for (int64 i = 0; i < while_body_root->operand_count(); ++i) {
|
||||
const auto& elem_shape = while_body_root->operand(i)->shape();
|
||||
if (!ShapeUtil::IsEffectiveScalar(elem_shape) ||
|
||||
elem_shape.element_type() != elem_ty) {
|
||||
continue;
|
||||
}
|
||||
|
||||
HloInstruction* constant;
|
||||
if (!Match(while_body_root->mutable_operand(i),
|
||||
m::AddAnyOrder(m::GetTupleElement(m::Parameter(), i),
|
||||
m::Constant(&constant)))) {
|
||||
m::ConstantScalar(&constant))
|
||||
.WithShape(m::Shape().WithElementType(elem_ty)))) {
|
||||
continue;
|
||||
}
|
||||
if (!trip_counter && constant->literal().IsAll(1) &&
|
||||
|
Loading…
x
Reference in New Issue
Block a user