Sparse layouts are not supported on any of the backends. For backwards compatibility the fields stay in the protobuf, but parsing them is a no-op. PiperOrigin-RevId: 287924498 Change-Id: I8b1c1ec52e3a423015837bc10deee832921ba66c
975 lines
40 KiB
C++
975 lines
40 KiB
C++
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
|
|
|
|
#include "absl/strings/str_cat.h"
|
|
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
|
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
|
#include "tensorflow/compiler/xla/test.h"
|
|
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
|
#include "tensorflow/core/platform/test.h"
|
|
|
|
namespace xla {
|
|
namespace {
|
|
|
|
namespace m = match;
|
|
using PatternMatcherTest = HloTestBase;
|
|
|
|
// Matches the root `add` of a small parsed module against a fully specified
// pattern (name, opcode, shape with a dense layout, and first operand), then
// checks a few shorthand patterns that must or must not match.
TEST_F(PatternMatcherTest, AddOp) {
  constexpr char kModuleStr[] = R"(HloModule two_plus_two_module
    ENTRY %two_plus_two_computation () -> f32[] {
      %two = f32[] constant(2)
      ROOT %two_plus_two = f32[] add(f32[] %two, f32[] %two)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  auto* root = hlo_module->entry_computation()->root_instruction();

  // Capture targets start out null so the checks below observe a real capture
  // rather than an indeterminate pointer value.
  const HloInstruction* matched_inst = nullptr;
  HloInstruction* matched_operand = nullptr;
  Shape* matched_shape = nullptr;
  Layout* matched_layout = nullptr;

  ASSERT_TRUE(Match(
      root,
      match::Op(&matched_inst)
          .WithName("two_plus_two")
          .WithOpcode(HloOpcode::kAdd)
          .WithShape(
              match::Shape(&matched_shape)
                  .WithLayout(match::Layout(&matched_layout).WithDenseFormat()))
          .WithOperand(
              0,
              match::Op(&matched_operand).WithOpcode(HloOpcode::kConstant))));
  ASSERT_NE(matched_inst, nullptr);
  EXPECT_EQ(matched_inst->name(), "two_plus_two");
  EXPECT_EQ(matched_inst->opcode(), HloOpcode::kAdd);

  EXPECT_TRUE(Match(root, match::Add(match::Constant(), match::Constant())));

  EXPECT_FALSE(Match(root, match::Op().WithName("bad_name")));

  // The root is an add, so a multiply pattern must fail; this test also
  // expects the failed match to leave its capture target untouched.
  matched_inst = nullptr;
  EXPECT_FALSE(Match(root,
                     match::Multiply(&matched_inst, match::Op(), match::Op())));
  EXPECT_EQ(matched_inst, nullptr);
}
|
|
|
|
// Exercises the shape predicates against a rank-0 F32 shape.
TEST_F(PatternMatcherTest, ScalarShape) {
  auto scalar_shape = ShapeUtil::MakeShape(F32, {});
  // Null-initialized so the equality check below never reads an indeterminate
  // pointer if the (non-fatal) capturing match fails.
  Shape* matched_shape = nullptr;
  EXPECT_TRUE(Match(&scalar_shape, match::Shape(&matched_shape).IsScalar()));
  EXPECT_EQ(matched_shape, &scalar_shape);
  EXPECT_TRUE(Match(&scalar_shape, match::Shape().IsArray()));
  EXPECT_TRUE(Match(&scalar_shape, match::Shape().IsDenseArray()));
  EXPECT_FALSE(Match(&scalar_shape, match::Shape().IsTuple()));
  EXPECT_TRUE(Match(&scalar_shape, match::Shape().WithElementType(F32)));
  EXPECT_TRUE(Match(&scalar_shape, match::Shape().WithRank(0)));
  // A scalar has no subshape at index {0}, so the whole pattern must fail even
  // though the element type matches.
  EXPECT_FALSE(Match(
      &scalar_shape,
      match::Shape().WithSubshape({0}, match::Shape()).WithElementType(F32)));
}
|
|
|
|
// Exercises the shape predicates and layout capture against a rank-3 F32
// array shape.
TEST_F(PatternMatcherTest, DenseArrayShape) {
  auto array_shape = ShapeUtil::MakeShape(F32, {2, 3, 4});
  // Null-initialized so the equality check below never reads an indeterminate
  // pointer if the (non-fatal) capturing match fails.
  Shape* matched_shape = nullptr;
  EXPECT_TRUE(Match(&array_shape, match::Shape(&matched_shape).IsArray()));
  EXPECT_EQ(matched_shape, &array_shape);
  EXPECT_TRUE(Match(&array_shape, match::Shape().IsDenseArray()));
  EXPECT_FALSE(Match(&array_shape, match::Shape().IsScalar()));
  EXPECT_FALSE(Match(&array_shape, match::Shape().IsTuple()));
  EXPECT_TRUE(Match(&array_shape, match::Shape().WithElementType(F32)));
  EXPECT_TRUE(Match(&array_shape, match::Shape().WithRank(3)));
  EXPECT_FALSE(
      Match(&array_shape, match::Shape().WithSubshape({0}, match::Shape())));

  // The captured layout must be the shape's own layout object.
  Layout* matched_layout = nullptr;
  EXPECT_TRUE(Match(&array_shape,
                    match::Shape().WithLayout(
                        match::Layout(&matched_layout).WithDenseFormat())));
  EXPECT_EQ(matched_layout, &array_shape.layout());
}
|
|
|
|
// Exercises tuple predicates and WithSubshape() against a two-element tuple
// shape (f32[1,2,3], s32[4,5]).
TEST_F(PatternMatcherTest, TupleShape) {
  auto tuple_shape = ShapeUtil::MakeTupleShape({
      ShapeUtil::MakeShape(F32, {1, 2, 3}),
      ShapeUtil::MakeShape(S32, {4, 5}),
  });
  EXPECT_TRUE(Match(&tuple_shape, match::Shape().IsTuple()));
  EXPECT_FALSE(Match(&tuple_shape, match::Shape().IsArray()));
  EXPECT_FALSE(Match(&tuple_shape, match::Shape().IsScalar()));

  // Starts out null so that ASSERT_NE below proves a capture actually happened.
  Shape* subshape = nullptr;
  ASSERT_TRUE(Match(
      &tuple_shape,
      match::Shape().WithSubshape(
          {0}, match::Shape(&subshape).WithElementType(F32).WithRank(3))));
  ASSERT_NE(subshape, nullptr);
  EXPECT_TRUE(
      ShapeUtil::Equal(*subshape, ShapeUtil::GetSubshape(tuple_shape, {0})));
  EXPECT_TRUE(Match(&tuple_shape,
                    match::Shape().WithSubshape(
                        {0}, match::Shape().EqualTo(
                                 &ShapeUtil::GetSubshape(tuple_shape, {0})))));
  EXPECT_FALSE(Match(&tuple_shape,
                     match::Shape().WithSubshape(
                         {0}, match::Shape().EqualTo(
                                  &ShapeUtil::GetSubshape(tuple_shape, {1})))));

  // Reset so the next ASSERT_NE checks the second capture instead of passing
  // on the pointer left over from the first one.
  subshape = nullptr;
  ASSERT_TRUE(Match(
      &tuple_shape,
      match::Shape().WithSubshape(
          {1}, match::Shape(&subshape).WithElementType(S32).WithRank(2))));
  ASSERT_NE(subshape, nullptr);
  EXPECT_TRUE(
      ShapeUtil::Equal(*subshape, ShapeUtil::GetSubshape(tuple_shape, {1})));
  EXPECT_TRUE(Match(&tuple_shape,
                    match::Shape().WithSubshape(
                        {1}, match::Shape().EqualTo(
                                 &ShapeUtil::GetSubshape(tuple_shape, {1})))));
  EXPECT_FALSE(Match(&tuple_shape,
                     match::Shape().WithSubshape(
                         {1}, match::Shape().EqualTo(
                                  &ShapeUtil::GetSubshape(tuple_shape, {0})))));

  // Indices that do not exist in the tuple never match.
  EXPECT_FALSE(
      Match(&tuple_shape, match::Shape().WithSubshape({2}, match::Shape())));
  EXPECT_FALSE(
      Match(&tuple_shape, match::Shape().WithSubshape({0, 0}, match::Shape())));
}
|
|
|
|
// Checks WithFusionKind() on a kLoop fusion: the same kind matches, a
// different kind does not, and the fusion's (non-fusion) operand never does.
TEST_F(PatternMatcherTest, FusionKind) {
  constexpr char kHloText[] = R"(
    HloModule test_module

    fused_computation {
      ROOT fp0 = f32[] parameter(0)
    }

    ENTRY while.v11 {
      p0 = f32[] parameter(0)
      ROOT fusion = f32[] fusion(p0), kind=kLoop, calls=fused_computation
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto test_module,
                          ParseAndReturnVerifiedModule(kHloText));

  using Kind = HloInstruction::FusionKind;
  auto* fusion = test_module->entry_computation()->root_instruction();

  EXPECT_TRUE(Match(fusion, m::Op().WithFusionKind(Kind::kLoop)));
  EXPECT_FALSE(Match(fusion, m::Op().WithFusionKind(Kind::kInput)));
  EXPECT_FALSE(
      Match(fusion->operand(0), m::Op().WithFusionKind(Kind::kLoop)));
}
|
|
|
|
// Checks that a get-tuple-element with index=1 is matched only by patterns
// asking for tuple index 1, via both WithTupleIndex() and GetTupleElement().
TEST_F(PatternMatcherTest, GetTupleElement) {
  constexpr char kHloText[] = R"(
    HloModule test_module

    ENTRY while.v11 {
      p0 = (f32[], f32[], f32[]) parameter(0)
      ROOT gte = f32[] get-tuple-element(p0), index=1
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto test_module,
                          ParseAndReturnVerifiedModule(kHloText));
  auto* gte = test_module->entry_computation()->root_instruction();

  // Generic op pattern constrained by tuple index.
  EXPECT_FALSE(Match(gte, m::Op().WithTupleIndex(0)));
  EXPECT_TRUE(Match(gte, m::Op().WithTupleIndex(1)));
  EXPECT_FALSE(Match(gte, m::Op().WithTupleIndex(2)));

  // Dedicated GetTupleElement pattern.
  EXPECT_FALSE(Match(gte, m::GetTupleElement(m::Op(), 0)));
  EXPECT_TRUE(Match(gte, m::GetTupleElement(m::Op(), 1)));
}
|
|
|
|
// Checks that AnyOf matches when either alternative matches, regardless of
// position, and fails when none does. The root is the f16 scalar constant 1.
TEST_F(PatternMatcherTest, AnyOf) {
  constexpr char kHloText[] = R"(
    HloModule test_module ENTRY test { ROOT constant = f16[] constant(1) })";
  TF_ASSERT_OK_AND_ASSIGN(auto test_module,
                          ParseAndReturnVerifiedModule(kHloText));
  auto* one = test_module->entry_computation()->root_instruction();

  // Builds "constant scalar `first` OR constant scalar `second`".
  auto either_scalar = [](int first, int second) {
    return m::AnyOf<HloInstruction>(m::ConstantScalar(first),
                                    m::ConstantScalar(second));
  };

  EXPECT_TRUE(Match(one, either_scalar(0, 1)));
  EXPECT_TRUE(Match(one, either_scalar(1, 0)));
  EXPECT_FALSE(Match(one, either_scalar(0, 2)));
}
|
|
|
|
// Checks ConstantScalar() and ConstantEffectiveScalar(), with and without an
// expected value and with and without a capture target, against constants of
// several shapes and element types.
TEST_F(PatternMatcherTest, ConstantScalar) {
  constexpr char kHloText[] = R"(
    HloModule test_module
    ENTRY test {
      a = s32[] constant(1)
      b = s32[1,1] constant({{2}})
      c = s32[1,2] constant({{2,2}})
      d = f32[] constant(1)
      e = f32[] constant(1.25)
      ROOT tuple = (s32[], s32[1,1], s32[1,2], f32[], f32[]) tuple(a,b,c,d,e)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto test_module,
                          ParseAndReturnVerifiedModule(kHloText));
  auto* tuple = test_module->entry_computation()->root_instruction();

  // Operands in declaration order: a, b, c, d, e.
  const HloInstruction* s32_one = tuple->operand(0);
  const HloInstruction* s32_1x1_two = tuple->operand(1);
  const HloInstruction* s32_1x2_twos = tuple->operand(2);
  const HloInstruction* f32_one = tuple->operand(3);
  const HloInstruction* f32_one_and_quarter = tuple->operand(4);

  // a: s32[] constant(1).
  EXPECT_TRUE(Match(s32_one, m::ConstantScalar()));
  EXPECT_TRUE(Match(s32_one, m::ConstantScalar(1)));
  EXPECT_TRUE(Match(s32_one, m::ConstantEffectiveScalar()));
  EXPECT_TRUE(Match(s32_one, m::ConstantEffectiveScalar(1)));
  EXPECT_FALSE(Match(s32_one, m::ConstantScalar(2)));
  EXPECT_FALSE(Match(s32_one, m::ConstantScalar(2.01)));
  EXPECT_FALSE(Match(s32_one, m::ConstantEffectiveScalar(2)));
  EXPECT_FALSE(Match(s32_one, m::ConstantEffectiveScalar(1.01)));

  // b: s32[1,1] constant({{2}}) matches only the "effective scalar" patterns.
  EXPECT_FALSE(Match(s32_1x1_two, m::ConstantScalar()));
  EXPECT_FALSE(Match(s32_1x1_two, m::ConstantScalar(2)));
  EXPECT_TRUE(Match(s32_1x1_two, m::ConstantEffectiveScalar()));
  EXPECT_TRUE(Match(s32_1x1_two, m::ConstantEffectiveScalar(2)));

  // c: s32[1,2] constant({{2,2}}) matches none of the patterns.
  EXPECT_FALSE(Match(s32_1x2_twos, m::ConstantScalar()));
  EXPECT_FALSE(Match(s32_1x2_twos, m::ConstantScalar(2)));
  EXPECT_FALSE(Match(s32_1x2_twos, m::ConstantEffectiveScalar()));
  EXPECT_FALSE(Match(s32_1x2_twos, m::ConstantEffectiveScalar(2)));

  // d: f32[] constant(1) matches both integer and floating-point 1.
  EXPECT_TRUE(Match(f32_one, m::ConstantScalar(1)));
  EXPECT_TRUE(Match(f32_one, m::ConstantEffectiveScalar(1)));
  EXPECT_TRUE(Match(f32_one, m::ConstantScalar(1.0)));
  EXPECT_TRUE(Match(f32_one, m::ConstantEffectiveScalar(1.0)));

  // e: f32[] constant(1.25) does not match the integer 1.
  EXPECT_TRUE(Match(f32_one_and_quarter, m::ConstantScalar(1.25f)));
  EXPECT_TRUE(Match(f32_one_and_quarter, m::ConstantScalar(1.25)));
  EXPECT_TRUE(Match(f32_one_and_quarter, m::ConstantEffectiveScalar(1.25)));
  EXPECT_FALSE(Match(f32_one_and_quarter, m::ConstantScalar(1)));
  EXPECT_FALSE(Match(f32_one_and_quarter, m::ConstantEffectiveScalar(1)));

  // Capturing overloads: each successful match stores the matched instruction.
  const HloInstruction* captured = nullptr;
  EXPECT_TRUE(Match(s32_one, m::ConstantScalar(&captured)));
  EXPECT_EQ(captured, s32_one);

  captured = nullptr;
  EXPECT_TRUE(Match(s32_one, m::ConstantScalar(&captured, 1)));
  EXPECT_EQ(captured, s32_one);

  captured = nullptr;
  EXPECT_TRUE(Match(s32_one, m::ConstantEffectiveScalar(&captured)));
  EXPECT_EQ(captured, s32_one);

  captured = nullptr;
  EXPECT_TRUE(Match(s32_one, m::ConstantEffectiveScalar(&captured, 1)));
  EXPECT_EQ(captured, s32_one);
}
|
|
|
|
// Checks that MultiplyAnyOrder() matches a multiply regardless of the order in
// which the operand patterns are given.
TEST_F(PatternMatcherTest, MultiplyAnyOrder) {
  using match::ConstantScalar;
  using match::MultiplyAnyOrder;

  constexpr char kModuleStr[] = R"(
    HloModule test_module
    ENTRY test {
      lhs = f16[] constant(42)
      rhs = f16[] constant(52)
      ROOT multiply = f16[] multiply(lhs, rhs)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  auto* root = hlo_module->entry_computation()->root_instruction();
  // Null-initialized like the capture targets in the sibling tests, so the
  // pointer never holds an indeterminate value.
  const HloInstruction* instr = nullptr;

  EXPECT_TRUE(Match(
      root, MultiplyAnyOrder(&instr, ConstantScalar(42), ConstantScalar(52))));
  EXPECT_TRUE(Match(
      root, MultiplyAnyOrder(&instr, ConstantScalar(52), ConstantScalar(42))));

  // Check that MultiplyAnyOrder exposes the same API as Op(), so we can call
  // e.g. IsNonConstant() on it.
  EXPECT_TRUE(Match(
      root, MultiplyAnyOrder(&instr, ConstantScalar(42), ConstantScalar(52))
                .IsNonConstant()));
  EXPECT_TRUE(
      Match(root, MultiplyAnyOrder(ConstantScalar(42), ConstantScalar(52))
                      .IsNonConstant()));
}
|
|
|
|
// Checks that AnyOf only captures through the first alternative that matches:
// the capture target of a later alternative must remain null.
TEST_F(PatternMatcherTest, AnyOfShortCircuit) {
  constexpr char kHloText[] = R"(
    HloModule test_module
    ENTRY test {
      lhs = f16[] constant(42)
      rhs = f16[] constant(52)
      ROOT multiply = f16[] multiply(lhs, rhs)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto test_module,
                          ParseAndReturnVerifiedModule(kHloText));
  auto* multiply = test_module->entry_computation()->root_instruction();

  // Multiply pattern listed first: it captures, the generic Op does not.
  {
    const HloInstruction* via_multiply = nullptr;
    const HloInstruction* via_any_op = nullptr;

    ASSERT_TRUE(Match(multiply,
                      m::AnyOf<HloInstruction>(
                          m::Multiply(&via_multiply, m::Op(), m::Op()),
                          m::Op(&via_any_op))));
    EXPECT_NE(nullptr, via_multiply);
    EXPECT_EQ(nullptr, via_any_op);
  }

  // Generic Op listed first: it captures, the multiply pattern does not.
  {
    const HloInstruction* via_multiply = nullptr;
    const HloInstruction* via_any_op = nullptr;

    ASSERT_TRUE(Match(multiply,
                      m::AnyOf<HloInstruction>(
                          m::Op(&via_any_op),
                          m::Multiply(&via_multiply, m::Op(), m::Op()))));
    EXPECT_NE(nullptr, via_any_op);
    EXPECT_EQ(nullptr, via_multiply);
  }
}
|
|
|
|
// Checks that AllOf matches only when every sub-pattern matches, independent
// of the order in which the sub-patterns are listed.
TEST_F(PatternMatcherTest, AllOf) {
  constexpr char kHloText[] = R"(
    HloModule test_module ENTRY test { ROOT constant = f16[] constant(1) })";
  TF_ASSERT_OK_AND_ASSIGN(auto test_module,
                          ParseAndReturnVerifiedModule(kHloText));
  auto* constant = test_module->entry_computation()->root_instruction();

  // Three patterns that each match the f16 scalar constant on their own.
  auto f16_scalar_shape = ShapeUtil::MakeShape(F16, {});
  auto is_scalar = m::Constant().WithShape(m::Shape().IsScalar());
  auto is_f16 = m::Constant().WithShapeEqualTo(&f16_scalar_shape);
  auto is_f16_compatible =
      m::Constant().WithShapeCompatibleTo(&f16_scalar_shape);
  ASSERT_TRUE(Match(constant, is_scalar));
  ASSERT_TRUE(Match(constant, is_f16));
  ASSERT_TRUE(Match(constant, is_f16_compatible));

  // All sub-patterns match, in either order.
  EXPECT_TRUE(Match(constant, m::AllOf<HloInstruction>(is_scalar, is_f16,
                                                       is_f16_compatible)));
  EXPECT_TRUE(Match(constant, m::AllOf<HloInstruction>(
                                  is_f16, is_f16_compatible, is_scalar)));

  // A non-matching Broadcast sub-pattern makes the whole AllOf fail.
  EXPECT_FALSE(Match(
      constant, m::AllOf<HloInstruction>(m::Broadcast(m::Op()), is_f16)));
  EXPECT_FALSE(Match(constant, m::AllOf<HloInstruction>(m::Broadcast(m::Op()),
                                                        is_f16_compatible)));
  EXPECT_FALSE(Match(
      constant, m::AllOf<HloInstruction>(m::Broadcast(m::Op()), is_scalar)));
}
|
|
|
|
// Checks that a failing AllOf does not write to capture targets, even when the
// capturing sub-pattern would have matched on its own.
TEST_F(PatternMatcherTest, AllOfNoCaptureIfNotMatch) {
  constexpr char kHloText[] = R"(
    HloModule test_module
    ENTRY test {
      ROOT v = f16[] constant(42)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto test_module,
                          ParseAndReturnVerifiedModule(kHloText));
  auto* v = test_module->entry_computation()->root_instruction();

  const HloInstruction* captured = nullptr;

  // Constant(&captured) matches `v`, Broadcast(Op()) does not.
  ASSERT_FALSE(Match(v, m::AllOf<HloInstruction>(m::Constant(&captured),
                                                 m::Broadcast(m::Op()))));
  EXPECT_EQ(nullptr, captured);

  // The same capturing pattern on its own does capture.
  ASSERT_TRUE(Match(v, m::Constant(&captured)));
  EXPECT_NE(nullptr, captured);
}
|
|
|
|
// Checks that a successful match leaves capture targets untouched when the
// match option disables capturing.
TEST_F(PatternMatcherTest, TestNoCapture) {
  constexpr char kHloText[] = R"(
    HloModule test_module
    ENTRY test {
      ROOT v = f16[] constant(42)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto test_module,
                          ParseAndReturnVerifiedModule(kHloText));
  auto* v = test_module->entry_computation()->root_instruction();

  const HloInstruction* captured = nullptr;
  ASSERT_TRUE(Match(v, m::Constant(&captured), {/*capture=*/false}));
  EXPECT_EQ(nullptr, captured);
}
|
|
|
|
// Checks which capture targets are written when an AnyOf shares capture
// targets between alternatives: only those of the alternative that matched.
TEST_F(PatternMatcherTest, TestCaptureMatchedSubPatternForAnyOf) {
  constexpr char kHloText[] = R"(
    HloModule test_module
    ENTRY test {
      u = f16[] parameter(0)
      v = f16[] parameter(1)
      ROOT add = f16[] add(u, v)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto test_module,
                          ParseAndReturnVerifiedModule(kHloText));
  auto* add = test_module->entry_computation()->root_instruction();

  const HloInstruction* first_addend = nullptr;
  const HloInstruction* second_addend = nullptr;
  const HloInstruction* third_addend = nullptr;

  // Alternatives, in order: a three-way add, a two-way add, any instruction.
  auto two_way_add = m::Add(m::Op(&first_addend), m::Op(&second_addend));
  auto up_to_three_way_add = m::AnyOf<HloInstruction>(
      m::AddAnyOrder(two_way_add, m::Op(&third_addend)), two_way_add,
      m::Op(&first_addend));

  // `add` is a two-way add of parameters, so the third addend is never set.
  ASSERT_TRUE(Match(add, up_to_three_way_add));
  EXPECT_NE(nullptr, first_addend);
  EXPECT_NE(nullptr, second_addend);
  EXPECT_EQ(nullptr, third_addend);
}
|
|
|
|
// Checks that Concatenate() requires the operand patterns to match the
// operands in order and in number.
TEST_F(PatternMatcherTest, TestConcat) {
  constexpr char kHloText[] = R"(
    HloModule test_module
    ENTRY test {
      c1 = u32[] constant(1)
      c2 = u32[] constant(2)
      c3 = u32[] constant(3)
      c4 = u32[] constant(4)
      r1 = u32[1] reshape(c1)
      r2 = u32[1] reshape(c2)
      r3 = u32[1] reshape(c3)
      r4 = u32[1] reshape(c4)
      ROOT concat = u32[4] concatenate(r1, r2, r3, r4), dimensions={0}
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto test_module,
                          ParseAndReturnVerifiedModule(kHloText));
  auto* concat = test_module->entry_computation()->root_instruction();

  // Pattern for "reshape of the constant scalar `value`".
  auto reshaped = [](int value) {
    return m::Reshape(m::ConstantScalar(value));
  };

  // Exact operand sequence.
  ASSERT_TRUE(Match(concat, m::Concatenate(reshaped(1), reshaped(2),
                                           reshaped(3), reshaped(4))));
  // First two operands swapped.
  ASSERT_FALSE(Match(concat, m::Concatenate(reshaped(2), reshaped(1),
                                            reshaped(3), reshaped(4))));
  // Too few operand patterns: leading and trailing subsequences.
  ASSERT_FALSE(
      Match(concat, m::Concatenate(reshaped(1), reshaped(2), reshaped(3))));
  ASSERT_FALSE(
      Match(concat, m::Concatenate(reshaped(2), reshaped(3), reshaped(4))));
}
|
|
|
|
// Returns the text that `pattern` writes when asked to describe itself, by
// handing DescribeTo() an in-memory stream.
template <typename Pattern>
string Description(const Pattern& pattern) {
  std::stringstream description;
  pattern.DescribeTo(&description);
  string rendered = description.str();
  return rendered;
}
|
|
|
|
template <typename Elem, typename Pattern>
|
|
string Explanation(Elem* elem, const Pattern& pattern) {
|
|
std::stringstream ss;
|
|
MatchOption options{/*.capture=*/true, /*.explain_os=*/&ss};
|
|
Match(elem, pattern, options);
|
|
return ss.str();
|
|
}
|
|
template <typename Elem, typename Pattern>
|
|
string Explanation(const std::unique_ptr<Elem>& elem, const Pattern& pattern) {
|
|
return Explanation(elem.get(), pattern);
|
|
}
|
|
template <typename Elem, typename Pattern>
|
|
string Explanation(const Elem& elem, const Pattern& pattern) {
|
|
return Explanation(&elem, pattern);
|
|
}
|
|
|
|
// Helper macro for checking a pattern's description and the explanation printed
// when attempting to match (and presumably failing) on a given object.
//
// We use a macro rather than a function because we want good line numbers in
// errors.  We use this rather than writing a helper that returns a pair of
// (description, explanation) and doing something like
//
//   EXPECT_THAT(DescAndExplanation(...), ::testing::Pair(..., ...));
//
// because EXPECT_EQ prints a unified diff if multiline string comparison fails,
// while EXPECT_THAT does not.  This unified diff makes the errors much easier
// to read.
//
// Every argument is parenthesized in the expansion so that any expression can
// be passed safely. Note that `pattern` is evaluated twice (once for the
// description and once for the explanation); `elem` is evaluated once.
#define EXPECT_DESC_AND_EXPLANATION(elem, pattern, expected_desc,      \
                                    expected_explanation)              \
  do {                                                                 \
    EXPECT_EQ(Description((pattern)), (expected_desc));                \
    EXPECT_EQ(Explanation((elem), (pattern)), (expected_explanation)); \
  } while (0)
|
|
|
|
// Checks the description and mismatch explanation of layout patterns for a
// null layout and for a layout that differs from the expected one.
TEST_F(PatternMatcherTest, LayoutDescribeToAndExplain) {
  auto expected_layout = LayoutUtil::MakeLayout({1, 2});
  auto other_layout = LayoutUtil::MakeLayout({2, 2});

  // Null input.
  EXPECT_DESC_AND_EXPLANATION(static_cast<const Layout*>(nullptr), m::Layout(),
                              "a layout", "Layout is null");

  // Layout that is not equal to the expected one.
  EXPECT_DESC_AND_EXPLANATION(other_layout,
                              m::Layout().EqualTo(&expected_layout),
                              "a layout equal to {1,2}",
                              "Layout {2,2} is not equal to expected {1,2}");
}
|
|
|
|
// Checks WithCustomCallTarget(): matching by target name, plus the description
// and explanation produced for a target that does not match.
TEST_F(PatternMatcherTest, CustomCallTargetMatcherDescribeAndExplain) {
  constexpr char kHloText[] = R"(
    HloModule test_module

    ENTRY test {
      ROOT out = f32[] custom-call(), custom_call_target="test_target"
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto test_module,
                          ParseAndReturnVerifiedModule(kHloText));

  auto* custom_call = test_module->entry_computation()->root_instruction();
  EXPECT_TRUE(Match(custom_call, m::Op().WithCustomCallTarget("test_target")));
  EXPECT_FALSE(
      Match(custom_call, m::Op().WithCustomCallTarget("other_target")));

  EXPECT_DESC_AND_EXPLANATION(
      custom_call, m::Op().WithCustomCallTarget("other_target"),
      "an HloInstruction custom call with target 'other_target'",
      "HloInstruction is not a custom call with a target 'other_target'\nin "
      "out = f32[] custom-call(), custom_call_target=\"test_target\"");
}
|
|
|
|
// Checks the description and mismatch explanation of every shape pattern
// constraint. Expected strings are compared exactly, so their whitespace and
// line breaks are significant.
TEST_F(PatternMatcherTest, ShapeDescribeToAndExplain) {
  // Reference shape f32[1,2]{0,1} and a copy of its layout.
  auto ref_shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 2}, {0, 1});
  auto ref_layout = ref_shape.layout();

  // Inputs that the patterns below are expected to reject.
  auto shape_with_layout_10 =
      ShapeUtil::MakeShapeWithLayout(F32, {1, 2}, {1, 0});
  auto square_shape = ShapeUtil::MakeShape(F32, {2, 2});
  auto nil_shape = ShapeUtil::MakeNil();
  auto square_tuple = ShapeUtil::MakeTupleShape({square_shape});
  auto nested_tuple =
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeTupleShape({ref_shape})});

  // Null input.
  EXPECT_DESC_AND_EXPLANATION(static_cast<const Shape*>(nullptr), m::Shape(),
                              "a shape", "Shape is null");

  // Equality and compatibility with the reference shape.
  EXPECT_DESC_AND_EXPLANATION(shape_with_layout_10,
                              m::Shape().EqualTo(&ref_shape),
                              "a shape equal to f32[1,2]{0,1}",
                              "Shape not equal to f32[1,2]{0,1}\n"
                              "in f32[1,2]{1,0}");
  EXPECT_DESC_AND_EXPLANATION(square_shape,
                              m::Shape().CompatibleTo(&ref_shape),
                              "a shape compatible with f32[1,2]",
                              "Shape not compatible with f32[1,2]\n"
                              "in f32[2,2]{1,0}");

  // Element type and kind predicates.
  EXPECT_DESC_AND_EXPLANATION(ref_shape, m::Shape().WithElementType(F16),
                              "a shape with element type F16",
                              "Shape does not have element type F16\n"
                              "in f32[1,2]{0,1}");
  EXPECT_DESC_AND_EXPLANATION(ref_shape, m::Shape().IsScalar(),
                              "a shape that represents a scalar",
                              "Shape is not a scalar\n"
                              "in f32[1,2]{0,1}");
  EXPECT_DESC_AND_EXPLANATION(nil_shape, m::Shape().IsArray(),
                              "a shape that represents an array",
                              "Shape is not an array\n"
                              "in ()");
  EXPECT_DESC_AND_EXPLANATION(ref_shape, m::Shape().IsTuple(),
                              "a shape that represents a tuple",
                              "Shape is not a tuple\n"
                              "in f32[1,2]{0,1}");
  EXPECT_DESC_AND_EXPLANATION(ref_shape, m::Shape().IsEffectiveScalar(),
                              "a shape that is an effective scalar",
                              "Shape is not an effective scalar\n"
                              "in f32[1,2]{0,1}");

  // Rank constraints, alone and combined with another constraint.
  EXPECT_DESC_AND_EXPLANATION(ref_shape, m::Shape().WithRank(42),
                              "a shape that has 42 dimensions",
                              "Shape does not have rank 42\n"
                              "in f32[1,2]{0,1}");
  EXPECT_DESC_AND_EXPLANATION(ref_shape, m::Shape().WithRank(0),
                              "a shape that is a scalar",
                              "Shape is not a scalar\n"
                              "in f32[1,2]{0,1}");
  EXPECT_DESC_AND_EXPLANATION(ref_shape, m::Shape().WithRank(1).IsArray(),
                              "a shape:\n"
                              " * that has 1 dimension AND\n"
                              " * that represents an array",
                              "Shape does not have rank 1\n"
                              "in f32[1,2]{0,1}");
  EXPECT_DESC_AND_EXPLANATION(nil_shape, m::Shape().IsArray().WithRank(1),
                              "a shape:\n"
                              " * that represents an array AND\n"
                              " * that has 1 dimension",
                              "Shape is not an array\n"
                              "in ()");

  // Layout constraint.
  EXPECT_DESC_AND_EXPLANATION(shape_with_layout_10,
                              m::Shape().WithLayoutEqualTo(&ref_layout),
                              "a shape with\n a layout equal to {0,1}",
                              "Layout {1,0} is not equal to expected {0,1}\n"
                              "in f32[1,2]{1,0}");

  // Subshape constraints: missing index, then mismatching subshape.
  EXPECT_DESC_AND_EXPLANATION(ref_shape,
                              m::Shape().WithSubshapeEqualTo({10}, &ref_shape),
                              "a shape with subshape at index {10} which is\n"
                              " a shape equal to f32[1,2]{0,1}",
                              "No subshape at {10}\n"
                              "in f32[1,2]{0,1}");
  EXPECT_DESC_AND_EXPLANATION(square_tuple,
                              m::Shape().WithSubshapeEqualTo({0}, &ref_shape),
                              "a shape with subshape at index {0} which is\n"
                              " a shape equal to f32[1,2]{0,1}",
                              "Shape not equal to f32[1,2]{0,1}\n"
                              "in f32[2,2]{1,0}\n"
                              "in subshape at {0}\n"
                              "in (f32[2,2])");
  EXPECT_DESC_AND_EXPLANATION(
      ref_shape, m::Shape().WithSubshapeCompatibleTo({10}, &ref_shape),
      "a shape with subshape at index {10} which is\n"
      " a shape compatible with f32[1,2]",
      "No subshape at {10}\n"
      "in f32[1,2]{0,1}");
  EXPECT_DESC_AND_EXPLANATION(
      square_tuple, m::Shape().WithSubshapeCompatibleTo({0}, &ref_shape),
      "a shape with subshape at index {0} which is\n"
      " a shape compatible with f32[1,2]",
      "Shape not compatible with f32[1,2]\n"
      "in f32[2,2]{1,0}\n"
      "in subshape at {0}\n"
      "in (f32[2,2])");
  EXPECT_DESC_AND_EXPLANATION(
      nested_tuple, m::Shape().WithSubshape({0, 0}, m::Shape().IsScalar()),
      "a shape with subshape at index {0,0} which is\n"
      " a shape that represents a scalar",
      "Shape is not a scalar\n"
      "in f32[1,2]{0,1}\n"
      "in subshape at {0,0}\n"
      "in ((f32[1,2]))");
}
|
|
|
|
// Gives `instr` the name `name` (via SetAndSanitizeName) and hands ownership
// of the instruction back to the caller, so the call can wrap a factory call.
std::unique_ptr<HloInstruction> SetName(absl::string_view name,
                                        std::unique_ptr<HloInstruction> instr) {
  const string requested_name(name);
  instr->SetAndSanitizeName(requested_name);
  return instr;
}
|
|
|
|
// Checks the description and mismatch explanation of every HloInstruction
// pattern constraint. Expected strings are compared exactly, so their
// whitespace and line breaks are significant.
TEST_F(PatternMatcherTest, HloInstructionDescribeToAndExplain) {
  // "i": an s32[42] iota; "c": the s32 scalar constant 0.
  std::unique_ptr<HloInstruction> iota_instr =
      SetName("i", HloInstruction::CreateIota(ShapeUtil::MakeShape(S32, {42}),
                                              /*iota_dimension=*/0));
  std::unique_ptr<HloInstruction> zero =
      SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(0)));

  // Null input.
  EXPECT_DESC_AND_EXPLANATION(static_cast<const HloInstruction*>(nullptr),
                              m::Op(), "an HloInstruction",
                              "HloInstruction* is null");

  // Name, opcode and operand-count constraints.
  EXPECT_DESC_AND_EXPLANATION(iota_instr, m::Op().WithName("foo"),
                              "an HloInstruction named \"foo\"",
                              "HloInstruction not named \"foo\"\n"
                              "in i = s32[42]{0} iota(), iota_dimension=0");
  EXPECT_DESC_AND_EXPLANATION(iota_instr, m::Op().WithOpcode(HloOpcode::kAdd),
                              "an HloInstruction with opcode add",
                              "HloInstruction doesn't have opcode add\n"
                              "in i = s32[42]{0} iota(), iota_dimension=0");
  EXPECT_DESC_AND_EXPLANATION(
      zero, m::Op().IsNonConstant(),
      "an HloInstruction with any opcode other than constant",
      "HloInstruction has opcode constant, expected anything else\n"
      "in c = s32[] constant(0)");
  EXPECT_DESC_AND_EXPLANATION(iota_instr, m::Op().WithNumOperands(42),
                              "an HloInstruction with 42 operands",
                              "HloInstruction doesn't have 42 operands\n"
                              "in i = s32[42]{0} iota(), iota_dimension=0");

  // Output shape constraint.
  EXPECT_DESC_AND_EXPLANATION(iota_instr,
                              m::Op().WithShape(m::Shape().IsTuple()),
                              "an HloInstruction outputting\n"
                              " a shape that represents a tuple",
                              "Shape is not a tuple\n"
                              "in s32[42]{0}\n"
                              "in output shape\n"
                              "in i = s32[42]{0} iota(), iota_dimension=0");

  // Operand constraints: out-of-range index, then mismatching operand.
  EXPECT_DESC_AND_EXPLANATION(
      iota_instr, m::Op().WithOperand(2, m::Op().WithOpcode(HloOpcode::kAdd)),
      "an HloInstruction with operand 2 which is:\n"
      " an HloInstruction with opcode add",
      "desired operand index 2 is out of bounds\n"
      "in i = s32[42]{0} iota(), iota_dimension=0");

  {
    // "a" = add(c, c); kept in its own scope so it does not outlive this check.
    auto add_of_zeros = SetName(
        "a", HloInstruction::CreateBinary(ShapeUtil::MakeShape(S32, {}),
                                          HloOpcode::kAdd, zero.get(),
                                          zero.get()));
    EXPECT_DESC_AND_EXPLANATION(
        add_of_zeros, m::Op().WithOperand(1, m::Op().IsNonConstant()),
        "an HloInstruction with operand 1 which is:\n"
        " an HloInstruction with any opcode other than constant",
        "HloInstruction has opcode constant, expected anything else\n"
        "in c = s32[] constant(0)\n"
        "in operand 1\n"
        "in a = s32[] add(s32[] c, s32[] c)");
  }

  // Fusion kind and tuple index constraints on a non-fusion, non-GTE op.
  EXPECT_DESC_AND_EXPLANATION(
      iota_instr, m::Op().WithFusionKind(HloInstruction::FusionKind::kLoop),
      "an HloInstruction with fusion kind kLoop",
      "HloInstruction does not have fusion kind kLoop; it's not a fusion\n"
      "in i = s32[42]{0} iota(), iota_dimension=0");
  EXPECT_DESC_AND_EXPLANATION(
      iota_instr, m::Op().WithTupleIndex(42),
      "an HloInstruction which is a GTE with index 42",
      "HloInstruction is not a GTE with index 42; it's not a GTE at all\n"
      "in i = s32[42]{0} iota(), iota_dimension=0");

  // Constant (effective) scalar constraints.
  EXPECT_DESC_AND_EXPLANATION(iota_instr, m::Op().IsConstantScalar(),
                              "an HloInstruction which is a constant scalar",
                              "HloInstruction is not a constant\n"
                              "in i = s32[42]{0} iota(), iota_dimension=0");
  {
    auto vector_constant = SetName(
        "c", HloInstruction::CreateConstant(LiteralUtil::CreateR1<int>({1, 2})));
    EXPECT_DESC_AND_EXPLANATION(
        vector_constant, m::Op().IsConstantEffectiveScalar(),
        "an HloInstruction which is a constant effective scalar",
        "HloInstruction is not an effective scalar\n"
        "in c = s32[2]{0} constant({1, 2})");
  }
  {
    auto ten =
        SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(10)));
    EXPECT_DESC_AND_EXPLANATION(
        ten, m::Op().IsConstantScalar(42),
        "an HloInstruction which is a constant scalar with value 42",
        "HloInstruction's constant value 10 did not match expected value 42\n"
        "in c = s32[] constant(10)");
  }
  {
    auto two_and_quarter = SetName(
        "c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.25)));
    EXPECT_DESC_AND_EXPLANATION(
        two_and_quarter, m::Op().IsConstantEffectiveScalar(1.25),
        "an HloInstruction which is a constant effective scalar with value 1.25",
        "HloInstruction's constant value 2.25 did not match expected value 1.25\n"
        "in c = f64[] constant(2.25)");
  }

  // Identity constraint; the expected text embeds the instruction addresses.
  EXPECT_DESC_AND_EXPLANATION(
      zero, m::Op().Is(iota_instr.get()),
      absl::StrCat("an HloInstruction which is 0x", absl::Hex(iota_instr.get()),
                   " (i = s32[42]{0} iota(), iota_dimension=0)"),
      absl::StrCat("HloInstruction 0x", absl::Hex(zero.get()), " is not 0x",
                   absl::Hex(iota_instr.get()),
                   " (i = s32[42]{0} iota(), iota_dimension=0)\n"
                   "in c = s32[] constant(0)"));
}
|
|
|
|
// Checks the description and the two kinds of mismatch explanation produced by
// AddAnyOrder(): one matcher matching no operand, and one operand matching no
// matcher. Expected strings are compared exactly.
TEST_F(PatternMatcherTest, HloInstructionMatcherAnyOrderDescribeTo) {
  auto scalar_s32 = ShapeUtil::MakeShape(S32, {});

  // a = add(b, c) with two constants: "bar" matches neither operand.
  {
    auto lhs_b =
        SetName("b", HloInstruction::CreateConstant(LiteralUtil::CreateR0(0)));
    auto rhs_c =
        SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(0)));
    auto add_b_c =
        SetName("a", HloInstruction::CreateBinary(scalar_s32, HloOpcode::kAdd,
                                                  lhs_b.get(), rhs_c.get()));
    EXPECT_DESC_AND_EXPLANATION(
        add_b_c,
        m::AddAnyOrder(m::Op().WithName("b"), m::Op().WithName("bar")),
        "an HloInstruction:\n"
        " * with opcode add AND\n"
        " * with two operands in either order:\n"
        " - an HloInstruction named \"b\"\n"
        " - an HloInstruction named \"bar\"",
        "HloInstruction's operands (ignoring order) did not match second "
        "matcher. Specifically,\n"
        " - an HloInstruction named \"bar\"\n"
        "does not match LHS:\n"
        " - HloInstruction not named \"bar\"\n"
        " in b = s32[] constant(0)\n"
        "does not match RHS:\n"
        " - HloInstruction not named \"bar\"\n"
        " in c = s32[] constant(0)\n"
        "in a = s32[] add(s32[] b, s32[] c)");
  }

  // a = add(p, c) with a parameter LHS: the LHS matches neither matcher.
  {
    auto lhs_p = HloInstruction::CreateParameter(0, scalar_s32, "p");
    auto rhs_c =
        SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(0)));
    auto add_p_c =
        SetName("a", HloInstruction::CreateBinary(scalar_s32, HloOpcode::kAdd,
                                                  lhs_p.get(), rhs_c.get()));
    EXPECT_DESC_AND_EXPLANATION(
        add_p_c,
        m::AddAnyOrder(m::Op().IsConstantScalar(), m::Op().IsConstant()),
        "an HloInstruction:\n"
        " * with opcode add AND\n"
        " * with two operands in either order:\n"
        " - an HloInstruction which is a constant scalar\n"
        " - an HloInstruction with opcode constant",
        "HloInstruction's LHS operand did not match either of the two matchers. "
        "Specifically,\n"
        " - an HloInstruction which is a constant scalar\n"
        "does not match LHS:\n"
        " - HloInstruction is not a constant\n"
        " in p = s32[] parameter(0)\n"
        "and\n"
        " - an HloInstruction with opcode constant\n"
        "does not match LHS:\n"
        " - HloInstruction doesn't have opcode constant\n"
        " in p = s32[] parameter(0)\n"
        "in a = s32[] add(s32[] p, s32[] c)");
  }
}
|
|
|
|
// Checks the description of an AnyOf pattern and the per-alternative
// explanation printed when no alternative matches. Expected strings are
// compared exactly.
TEST_F(PatternMatcherTest, AnyOfMatcherDescribeToAndExplain) {
  auto zero =
      SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(0)));
  auto foo_or_bar = m::AnyOf<HloInstruction>(m::Op().WithName("foo"),
                                             m::Op().WithName("bar"));

  EXPECT_DESC_AND_EXPLANATION(zero, foo_or_bar,
                              "any of:\n"
                              " - an HloInstruction named \"foo\" OR\n"
                              " - an HloInstruction named \"bar\"",
                              "None of the following matchers succeeded:\n"
                              "Matcher #1\n"
                              " - an HloInstruction named \"foo\"\n"
                              "failed with\n"
                              " - HloInstruction not named \"foo\"\n"
                              " in c = s32[] constant(0)\n"
                              "Matcher #2\n"
                              " - an HloInstruction named \"bar\"\n"
                              "failed with\n"
                              " - HloInstruction not named \"bar\"\n"
                              " in c = s32[] constant(0)");
}
|
|
|
|
// Checks Parameter() with and without a parameter number, against a parameter
// and a non-parameter instruction, plus its description and explanations.
TEST_F(PatternMatcherTest, Parameter) {
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});
  auto p1 = HloInstruction::CreateParameter(1, f32_scalar, "p1");
  auto zero =
      SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(0)));

  // Parameter number 1 matches Parameter() and Parameter(1) only.
  EXPECT_FALSE(Match(p1.get(), m::Parameter(0)));
  EXPECT_TRUE(Match(p1.get(), m::Parameter()));
  EXPECT_TRUE(Match(p1.get(), m::Parameter(1)));

  // A constant is never a parameter.
  EXPECT_FALSE(Match(zero.get(), m::Parameter()));
  EXPECT_FALSE(Match(zero.get(), m::Parameter(1)));

  EXPECT_DESC_AND_EXPLANATION(zero, m::Parameter(1),
                              "an HloInstruction:\n"
                              " * with opcode parameter AND\n"
                              " * which is parameter 1",
                              "HloInstruction doesn't have opcode parameter\n"
                              "in c = s32[] constant(0)");

  // Right opcode, wrong parameter number.
  auto p0 = HloInstruction::CreateParameter(0, f32_scalar, "p0");
  EXPECT_EQ(Explanation(p0, m::Parameter(1)),
            "HloInstruction is not parameter 1\n"
            "in p0 = f32[] parameter(0)");
}
|
|
|
|
// Checks WithOneUse() and WithOneUser() on a parameter as instructions that
// consume it are created. Statement order matters throughout: each
// expectation reflects the users that exist at that point in the test.
TEST_F(PatternMatcherTest, OneUseAndOneUser) {
  const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});
  const Shape vector_f32 = ShapeUtil::MakeShape(F32, {1});

  auto p0 = HloInstruction::CreateParameter(0, scalar_f32, "p0");
  auto* const operand = p0.get();

  // Phase 1: nothing consumes p0 yet, so both patterns must fail and report
  // zero users.
  const char* const kNoUserExplanation =
      "HloInstruction has 0 users, but expected exactly one.\n"
      "in p0 = f32[] parameter(0)";

  EXPECT_FALSE(Match(operand, m::Op().WithOneUse()));
  EXPECT_DESC_AND_EXPLANATION(p0, m::Op().WithOneUse(),
                              "an HloInstruction which has exactly one use",
                              kNoUserExplanation);

  EXPECT_FALSE(Match(operand, m::Op().WithOneUser()));
  EXPECT_DESC_AND_EXPLANATION(
      p0, m::Op().WithOneUser(),
      "an HloInstruction which has exactly one user (but possibly is used "
      "multiple times by that instruction)",
      kNoUserExplanation);

  // Phase 2: reshapes of p0 that live only inside this scope.
  {
    // A single reshape: exactly one user and exactly one use.
    auto first_reshape =
        SetName("r", HloInstruction::CreateReshape(vector_f32, operand));
    EXPECT_TRUE(Match(operand, m::Op().WithOneUse()));
    EXPECT_TRUE(Match(operand, m::Op().WithOneUser()));

    // A second reshape: two distinct users, so both patterns fail.
    auto second_reshape =
        SetName("r1", HloInstruction::CreateReshape(vector_f32, operand));
    EXPECT_FALSE(Match(operand, m::Op().WithOneUse()));
    EXPECT_FALSE(Match(operand, m::Op().WithOneUser()));

    // Both patterns are expected to produce the same text, listing every user.
    const char* kTwoUsersExplanation =
        "HloInstruction has 2 users, but expected exactly one.\n"
        "All users:\n"
        " - r = f32[1]{0} reshape(f32[] p0)\n"
        " - r1 = f32[1]{0} reshape(f32[] p0)\n"
        "in p0 = f32[] parameter(0)";
    EXPECT_EQ(Explanation(operand, m::Op().WithOneUse()),
              kTwoUsersExplanation);
    EXPECT_EQ(Explanation(operand, m::Op().WithOneUser()),
              kTwoUsersExplanation);
  }

  // Phase 3: a single add that reads p0 through both operands. The
  // expectations below treat it as p0's only user (the reshapes above are out
  // of scope by now) while counting two uses.
  auto double_use_add =
      SetName("add", HloInstruction::CreateBinary(scalar_f32, HloOpcode::kAdd,
                                                  operand, operand));
  EXPECT_TRUE(Match(operand, m::Op().WithOneUser()));
  EXPECT_FALSE(Match(operand, m::Op().WithOneUse()));

  const char* const kUsedTwiceExplanation =
      "HloInstruction is used 2 times by its user, but is expected to be "
      "used just once: add = f32[] add(f32[] p0, f32[] p0)\n"
      "in p0 = f32[] parameter(0)";
  EXPECT_EQ(Explanation(operand, m::Op().WithOneUse()), kUsedTwiceExplanation);
}
|
|
|
|
// Exercises the compare matchers (m::Compare, m::Eq, m::Ne, m::Le and
// m::EqAnyOrder) on compare instructions with different directions and
// operands.
TEST_F(PatternMatcherTest, Comparison) {
  const auto vec_shape = ShapeUtil::MakeShape(F32, {1});
  auto lhs_param = HloInstruction::CreateParameter(0, vec_shape, "param.0");
  auto rhs_param = HloInstruction::CreateParameter(1, vec_shape, "param.1");

  // eq_cmp and ne_cmp compare the two parameters directly; le_cmp compares
  // parameter 0 against the sum of both parameters.
  auto eq_cmp = HloInstruction::CreateCompare(
      vec_shape, lhs_param.get(), rhs_param.get(), ComparisonDirection::kEq);
  auto ne_cmp = HloInstruction::CreateCompare(
      vec_shape, lhs_param.get(), rhs_param.get(), ComparisonDirection::kNe);
  auto sum = HloInstruction::CreateBinary(vec_shape, HloOpcode::kAdd,
                                          lhs_param.get(), rhs_param.get());
  auto le_cmp = HloInstruction::CreateCompare(
      vec_shape, lhs_param.get(), sum.get(), ComparisonDirection::kLe);

  // Patterns expected to match.
  EXPECT_TRUE(Match(eq_cmp.get(), m::Compare()));
  EXPECT_TRUE(Match(eq_cmp.get(), m::Eq()));
  EXPECT_TRUE(Match(eq_cmp.get(), m::Eq(m::Parameter(0), m::Parameter(1))));
  EXPECT_TRUE(
      Match(eq_cmp.get(), m::EqAnyOrder(m::Parameter(1), m::Parameter(0))));

  EXPECT_TRUE(Match(ne_cmp.get(), m::Compare()));
  EXPECT_TRUE(Match(ne_cmp.get(), m::Ne()));

  EXPECT_TRUE(Match(le_cmp.get(),
                    m::Compare(m::Parameter(0),
                               m::Add(m::Parameter(0), m::Parameter(1)))));
  EXPECT_TRUE(
      Match(le_cmp.get(),
            m::Le(m::Parameter(0), m::Add(m::Parameter(0), m::Parameter(1)))));

  // Patterns expected to be rejected: wrong opcode, wrong direction, or
  // operands given in the wrong order to the order-sensitive m::Eq.
  EXPECT_FALSE(Match(eq_cmp.get(), m::Add()));
  EXPECT_FALSE(Match(eq_cmp.get(), m::Ne()));
  EXPECT_FALSE(
      Match(le_cmp.get(),
            m::Eq(m::Parameter(0), m::Add(m::Parameter(0), m::Parameter(1)))));
  EXPECT_FALSE(Match(eq_cmp.get(), m::Eq(m::Parameter(1), m::Parameter(0))));

  // Matching an EQ compare against m::Ne() is expected to explain the
  // direction mismatch.
  const char* const kExpectedDescription =
      "an HloInstruction:\n"
      " * with opcode compare AND\n"
      " * which has comparison direction NE AND\n"
      " * which has exactly one user (but possibly is used "
      "multiple times by that instruction)";
  const char* const kExpectedExplanation =
      "HloInstruction is not comparison NE\n"
      "in compare = f32[1]{0} compare(f32[1]{0} param.0, f32[1]{0} param.1), "
      "direction=EQ";
  EXPECT_DESC_AND_EXPLANATION(eq_cmp, m::Ne().WithOneUser(),
                              kExpectedDescription, kExpectedExplanation);
}
|
|
|
|
} // namespace
|
|
} // namespace xla
|