From c8f1b3826f0811ff93e7f2f831b7d44dd01db236 Mon Sep 17 00:00:00 2001 From: George Karpenkov Date: Wed, 19 Jun 2019 16:20:04 -0700 Subject: [PATCH] [XLA] Extend patter matcher with a rule for matching a custom call target PiperOrigin-RevId: 254093268 --- .../compiler/xla/service/pattern_matcher.h | 33 +++++++++++++++++++ .../xla/service/pattern_matcher_test.cc | 22 +++++++++++++ 2 files changed, 55 insertions(+) diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h index 958365da886..ca037d3ff96 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher.h +++ b/tensorflow/compiler/xla/service/pattern_matcher.h @@ -1247,6 +1247,32 @@ class HloInstructionPatternOpcodeImpl { bool invert_; }; +// An HloInstructionPattern implementation that matches only if the instruction +// has a given custom call target. +class HloInstructionCustomCallTargetImpl { + public: + explicit HloInstructionCustomCallTargetImpl( + absl::string_view custom_call_target) + : custom_call_target_(custom_call_target) {} + + bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { + if (inst->opcode() != HloOpcode::kCustomCall || + inst->custom_call_target() != custom_call_target_) { + EXPLAIN << "HloInstruction is not a custom call with a target '" + << custom_call_target_ << "'"; + return false; + } + return true; + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "custom call with target '" << custom_call_target_ << "'"; + } + + private: + std::string custom_call_target_; +}; + // An HloInstructionPattern implementation that matches only if the instruction // has the given number of operands. class HloInstructionPatternNumOperandsImpl { @@ -1840,6 +1866,13 @@ class HloInstructionPattern { return AppendImpl(HloInstructionPatternOpcodeImpl(opcode, false)); } + // Modifies the pattern to match only the custom call with a given target. + auto WithCustomCallTarget(absl::string_view custom_call_target) const + -> decltype(this->AppendImpl( + HloInstructionCustomCallTargetImpl(custom_call_target))) { + return AppendImpl(HloInstructionCustomCallTargetImpl(custom_call_target)); + } + auto WithNumOperands(int64 num_operands) const -> decltype( this->AppendImpl(HloInstructionPatternNumOperandsImpl(num_operands))) { return AppendImpl(HloInstructionPatternNumOperandsImpl(num_operands)); diff --git a/tensorflow/compiler/xla/service/pattern_matcher_test.cc b/tensorflow/compiler/xla/service/pattern_matcher_test.cc index cbe8c4a2410..803389d4d34 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher_test.cc +++ b/tensorflow/compiler/xla/service/pattern_matcher_test.cc @@ -565,6 +565,28 @@ TEST(PatternMatcherTest, LayoutDescribeToAndExplain) { "Layout has format DENSE but expected SPARSE"); } +TEST(PatternMatcherTest, CustomCallTargetMatcherDescribeAndExplain) { + constexpr char kModuleStr[] = R"( + HloModule test_module + + ENTRY test { + ROOT out = f32[] custom-call(), custom_call_target="test_target" + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); + + auto* root = hlo_module->entry_computation()->root_instruction(); + EXPECT_TRUE(Match(root, match::Op().WithCustomCallTarget("test_target"))); + EXPECT_FALSE(Match(root, match::Op().WithCustomCallTarget("other_target"))); + + EXPECT_DESC_AND_EXPLANATION( + root, match::Op().WithCustomCallTarget("other_target"), + "an HloInstruction custom call with target 'other_target'", + "HloInstruction is not a custom call with a target 'other_target'\nin " + "out = f32[] custom-call(), custom_call_target=\"test_target\""); +} + TEST(PatternMatcherTest, ShapeDescribeToAndExplain) { auto shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 2}, {0, 1}); auto layout = shape.layout();