[XLA] Extend patter matcher with a rule for matching a custom call target

PiperOrigin-RevId: 254093268
This commit is contained in:
George Karpenkov 2019-06-19 16:20:04 -07:00 committed by TensorFlower Gardener
parent c2890a06ac
commit c8f1b3826f
2 changed files with 55 additions and 0 deletions

View File

@ -1247,6 +1247,32 @@ class HloInstructionPatternOpcodeImpl {
bool invert_;
};
// An HloInstructionPattern implementation that matches only if the instruction
// has a given custom call target.
class HloInstructionCustomCallTargetImpl {
public:
explicit HloInstructionCustomCallTargetImpl(
absl::string_view custom_call_target)
: custom_call_target_(custom_call_target) {}
bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
if (inst->opcode() != HloOpcode::kCustomCall ||
inst->custom_call_target() != custom_call_target_) {
EXPLAIN << "HloInstruction is not a custom call with a target '"
<< custom_call_target_ << "'";
return false;
}
return true;
}
void DescribeTo(std::ostream* os, int64 indent = 0) const {
*os << "custom call with target '" << custom_call_target_ << "'";
}
private:
std::string custom_call_target_;
};
// An HloInstructionPattern implementation that matches only if the instruction
// has the given number of operands.
class HloInstructionPatternNumOperandsImpl {
@ -1840,6 +1866,13 @@ class HloInstructionPattern {
return AppendImpl(HloInstructionPatternOpcodeImpl(opcode, false));
}
// Modifies the pattern to match only the custom call with a given target.
auto WithCustomCallTarget(absl::string_view custom_call_target) const
-> decltype(this->AppendImpl(
HloInstructionCustomCallTargetImpl(custom_call_target))) {
return AppendImpl(HloInstructionCustomCallTargetImpl(custom_call_target));
}
auto WithNumOperands(int64 num_operands) const -> decltype(
this->AppendImpl(HloInstructionPatternNumOperandsImpl(num_operands))) {
return AppendImpl(HloInstructionPatternNumOperandsImpl(num_operands));

View File

@ -565,6 +565,28 @@ TEST(PatternMatcherTest, LayoutDescribeToAndExplain) {
"Layout has format DENSE but expected SPARSE");
}
TEST(PatternMatcherTest, CustomCallTargetMatcherDescribeAndExplain) {
constexpr char kModuleStr[] = R"(
HloModule test_module
ENTRY test {
ROOT out = f32[] custom-call(), custom_call_target="test_target"
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr));
auto* root = hlo_module->entry_computation()->root_instruction();
EXPECT_TRUE(Match(root, match::Op().WithCustomCallTarget("test_target")));
EXPECT_FALSE(Match(root, match::Op().WithCustomCallTarget("other_target")));
EXPECT_DESC_AND_EXPLANATION(
root, match::Op().WithCustomCallTarget("other_target"),
"an HloInstruction custom call with target 'other_target'",
"HloInstruction is not a custom call with a target 'other_target'\nin "
"out = f32[] custom-call(), custom_call_target=\"test_target\"");
}
TEST(PatternMatcherTest, ShapeDescribeToAndExplain) {
auto shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 2}, {0, 1});
auto layout = shape.layout();