[XLA] Extend patter matcher with a rule for matching a custom call target
PiperOrigin-RevId: 254093268
This commit is contained in:
parent
c2890a06ac
commit
c8f1b3826f
@ -1247,6 +1247,32 @@ class HloInstructionPatternOpcodeImpl {
|
||||
bool invert_;
|
||||
};
|
||||
|
||||
// An HloInstructionPattern implementation that matches only if the instruction
|
||||
// has a given custom call target.
|
||||
class HloInstructionCustomCallTargetImpl {
|
||||
public:
|
||||
explicit HloInstructionCustomCallTargetImpl(
|
||||
absl::string_view custom_call_target)
|
||||
: custom_call_target_(custom_call_target) {}
|
||||
|
||||
bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
|
||||
if (inst->opcode() != HloOpcode::kCustomCall ||
|
||||
inst->custom_call_target() != custom_call_target_) {
|
||||
EXPLAIN << "HloInstruction is not a custom call with a target '"
|
||||
<< custom_call_target_ << "'";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void DescribeTo(std::ostream* os, int64 indent = 0) const {
|
||||
*os << "custom call with target '" << custom_call_target_ << "'";
|
||||
}
|
||||
|
||||
private:
|
||||
std::string custom_call_target_;
|
||||
};
|
||||
|
||||
// An HloInstructionPattern implementation that matches only if the instruction
|
||||
// has the given number of operands.
|
||||
class HloInstructionPatternNumOperandsImpl {
|
||||
@ -1840,6 +1866,13 @@ class HloInstructionPattern {
|
||||
return AppendImpl(HloInstructionPatternOpcodeImpl(opcode, false));
|
||||
}
|
||||
|
||||
// Modifies the pattern to match only the custom call with a given target.
|
||||
auto WithCustomCallTarget(absl::string_view custom_call_target) const
|
||||
-> decltype(this->AppendImpl(
|
||||
HloInstructionCustomCallTargetImpl(custom_call_target))) {
|
||||
return AppendImpl(HloInstructionCustomCallTargetImpl(custom_call_target));
|
||||
}
|
||||
|
||||
auto WithNumOperands(int64 num_operands) const -> decltype(
|
||||
this->AppendImpl(HloInstructionPatternNumOperandsImpl(num_operands))) {
|
||||
return AppendImpl(HloInstructionPatternNumOperandsImpl(num_operands));
|
||||
|
@ -565,6 +565,28 @@ TEST(PatternMatcherTest, LayoutDescribeToAndExplain) {
|
||||
"Layout has format DENSE but expected SPARSE");
|
||||
}
|
||||
|
||||
TEST(PatternMatcherTest, CustomCallTargetMatcherDescribeAndExplain) {
|
||||
constexpr char kModuleStr[] = R"(
|
||||
HloModule test_module
|
||||
|
||||
ENTRY test {
|
||||
ROOT out = f32[] custom-call(), custom_call_target="test_target"
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr));
|
||||
|
||||
auto* root = hlo_module->entry_computation()->root_instruction();
|
||||
EXPECT_TRUE(Match(root, match::Op().WithCustomCallTarget("test_target")));
|
||||
EXPECT_FALSE(Match(root, match::Op().WithCustomCallTarget("other_target")));
|
||||
|
||||
EXPECT_DESC_AND_EXPLANATION(
|
||||
root, match::Op().WithCustomCallTarget("other_target"),
|
||||
"an HloInstruction custom call with target 'other_target'",
|
||||
"HloInstruction is not a custom call with a target 'other_target'\nin "
|
||||
"out = f32[] custom-call(), custom_call_target=\"test_target\"");
|
||||
}
|
||||
|
||||
TEST(PatternMatcherTest, ShapeDescribeToAndExplain) {
|
||||
auto shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 2}, {0, 1});
|
||||
auto layout = shape.layout();
|
||||
|
Loading…
Reference in New Issue
Block a user