[XLA] Extend patter matcher with a rule for matching a custom call target
PiperOrigin-RevId: 254093268
This commit is contained in:
parent
c2890a06ac
commit
c8f1b3826f
@ -1247,6 +1247,32 @@ class HloInstructionPatternOpcodeImpl {
|
|||||||
bool invert_;
|
bool invert_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// An HloInstructionPattern implementation that matches only if the instruction
|
||||||
|
// has a given custom call target.
|
||||||
|
class HloInstructionCustomCallTargetImpl {
|
||||||
|
public:
|
||||||
|
explicit HloInstructionCustomCallTargetImpl(
|
||||||
|
absl::string_view custom_call_target)
|
||||||
|
: custom_call_target_(custom_call_target) {}
|
||||||
|
|
||||||
|
bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
|
||||||
|
if (inst->opcode() != HloOpcode::kCustomCall ||
|
||||||
|
inst->custom_call_target() != custom_call_target_) {
|
||||||
|
EXPLAIN << "HloInstruction is not a custom call with a target '"
|
||||||
|
<< custom_call_target_ << "'";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void DescribeTo(std::ostream* os, int64 indent = 0) const {
|
||||||
|
*os << "custom call with target '" << custom_call_target_ << "'";
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::string custom_call_target_;
|
||||||
|
};
|
||||||
|
|
||||||
// An HloInstructionPattern implementation that matches only if the instruction
|
// An HloInstructionPattern implementation that matches only if the instruction
|
||||||
// has the given number of operands.
|
// has the given number of operands.
|
||||||
class HloInstructionPatternNumOperandsImpl {
|
class HloInstructionPatternNumOperandsImpl {
|
||||||
@ -1840,6 +1866,13 @@ class HloInstructionPattern {
|
|||||||
return AppendImpl(HloInstructionPatternOpcodeImpl(opcode, false));
|
return AppendImpl(HloInstructionPatternOpcodeImpl(opcode, false));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Modifies the pattern to match only the custom call with a given target.
|
||||||
|
auto WithCustomCallTarget(absl::string_view custom_call_target) const
|
||||||
|
-> decltype(this->AppendImpl(
|
||||||
|
HloInstructionCustomCallTargetImpl(custom_call_target))) {
|
||||||
|
return AppendImpl(HloInstructionCustomCallTargetImpl(custom_call_target));
|
||||||
|
}
|
||||||
|
|
||||||
auto WithNumOperands(int64 num_operands) const -> decltype(
|
auto WithNumOperands(int64 num_operands) const -> decltype(
|
||||||
this->AppendImpl(HloInstructionPatternNumOperandsImpl(num_operands))) {
|
this->AppendImpl(HloInstructionPatternNumOperandsImpl(num_operands))) {
|
||||||
return AppendImpl(HloInstructionPatternNumOperandsImpl(num_operands));
|
return AppendImpl(HloInstructionPatternNumOperandsImpl(num_operands));
|
||||||
|
@ -565,6 +565,28 @@ TEST(PatternMatcherTest, LayoutDescribeToAndExplain) {
|
|||||||
"Layout has format DENSE but expected SPARSE");
|
"Layout has format DENSE but expected SPARSE");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(PatternMatcherTest, CustomCallTargetMatcherDescribeAndExplain) {
|
||||||
|
constexpr char kModuleStr[] = R"(
|
||||||
|
HloModule test_module
|
||||||
|
|
||||||
|
ENTRY test {
|
||||||
|
ROOT out = f32[] custom-call(), custom_call_target="test_target"
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr));
|
||||||
|
|
||||||
|
auto* root = hlo_module->entry_computation()->root_instruction();
|
||||||
|
EXPECT_TRUE(Match(root, match::Op().WithCustomCallTarget("test_target")));
|
||||||
|
EXPECT_FALSE(Match(root, match::Op().WithCustomCallTarget("other_target")));
|
||||||
|
|
||||||
|
EXPECT_DESC_AND_EXPLANATION(
|
||||||
|
root, match::Op().WithCustomCallTarget("other_target"),
|
||||||
|
"an HloInstruction custom call with target 'other_target'",
|
||||||
|
"HloInstruction is not a custom call with a target 'other_target'\nin "
|
||||||
|
"out = f32[] custom-call(), custom_call_target=\"test_target\"");
|
||||||
|
}
|
||||||
|
|
||||||
TEST(PatternMatcherTest, ShapeDescribeToAndExplain) {
|
TEST(PatternMatcherTest, ShapeDescribeToAndExplain) {
|
||||||
auto shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 2}, {0, 1});
|
auto shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 2}, {0, 1});
|
||||||
auto layout = shape.layout();
|
auto layout = shape.layout();
|
||||||
|
Loading…
Reference in New Issue
Block a user