[XLA] Make pattern_matchers work with gMock.

This lets us unify the HLO pattern matchers and the HLO gmock matchers (in a later patch).

Unifying these two APIs is useful because then we don't have to learn two APIs,
and we don't have to implement features twice.

This change:

 - Adds and tests the DescribeTo and MatchAndExplain APIs (this is the major change)

 - Uses these new gmock matchers in a few tests as a proof of concept.

 - Rewrites the is-constant-scalar API to use a true matcher rather than a std::function predicate matcher.  This is necessary to get a user-friendly DescribeTo message rather than "I don't know what this std::function does."

 - Adds EffectiveScalarConstant helpers along with the old ScalarConstant helpers and then uses these within while_loop_simplifier.

 - Adds some missing simple op matchers: Tuple, Convolution, Pad, etc.

 - Adds a Parameter(n) matcher.

 - Adds Op().Is(), which matches a particular HloInstruction*, which is used in while_loop_simplifier.

 - Updates documentation to reflect new functions (both added here and added in earlier patches).

 - Tightens up the documentation.  It was getting pretty long, and I made it longer.

 - Changes implementation of FooAnyOrder so that it returns an Op rather than an AnyOf.  This lets you do AddAnyOrder(...).IsScalar(), whereas before this was a compile error.

 - Changes the implementation of FooAnyOrder so it uses a custom matcher rather than an AnyOf, in service of better DescribeTo messages.

 - Implements "and" folding, i.e.

     AllOf<AllOf<A, B...>, X, Y, ...> => AllOf<A, B, ..., X, Y, ...>

   in the service of better DescribeTo messages.

PiperOrigin-RevId: 223451504
This commit is contained in:
Justin Lebar 2018-11-29 19:09:19 -08:00 committed by TensorFlower Gardener
parent af4417be82
commit 19a1dd5268
9 changed files with 1838 additions and 247 deletions

View File

@ -408,9 +408,36 @@ tf_cc_test(
":hlo",
":pattern_matcher",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
"@com_google_absl//absl/strings",
],
)
cc_library(
name = "pattern_matcher_gmock",
testonly = 1,
hdrs = ["pattern_matcher_gmock.h"],
deps = [
":pattern_matcher",
"//tensorflow/compiler/xla:test",
"//tensorflow/core:test",
],
)
tf_cc_test(
name = "pattern_matcher_gmock_test",
srcs = ["pattern_matcher_gmock_test.cc"],
deps = [
":hlo",
":pattern_matcher",
":pattern_matcher_gmock",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
)
@ -2631,6 +2658,8 @@ tf_cc_test(
":hlo",
":hlo_matchers",
":layout_assignment",
":pattern_matcher",
":pattern_matcher_gmock",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_layout",
"//tensorflow/compiler/xla:shape_util",
@ -2775,6 +2804,8 @@ tf_cc_test(
":hlo_matchers",
":hlo_parser",
":hlo_pass",
":pattern_matcher",
":pattern_matcher_gmock",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
@ -3535,6 +3566,8 @@ tf_cc_test(
":hlo_casting_utils",
":hlo_matchers",
":hlo_parser",
":pattern_matcher",
":pattern_matcher_gmock",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:window_util",
"//tensorflow/core:lib",

View File

@ -22,21 +22,22 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/types.h"
namespace op = xla::testing::opcode_matchers;
namespace xla {
namespace {
namespace m = xla::match;
using HloConstantFoldingTest = HloTestBase;
TEST_F(HloConstantFoldingTest, ConvertF32ToS64) {
@ -49,13 +50,14 @@ TEST_F(HloConstantFoldingTest, ConvertF32ToS64) {
auto module = CreateNewVerifiedModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Convert(input));
EXPECT_THAT(computation->root_instruction(),
GmockMatch(m::Convert().WithOperand(0, m::Op().Is(input))));
HloConstantFolding const_folder;
TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
EXPECT_TRUE(result);
EXPECT_THAT(computation->root_instruction(), op::Constant());
EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Constant()));
EXPECT_EQ(computation->root_instruction()->literal().GetFirstElement<int64>(),
42);
}
@ -70,13 +72,14 @@ TEST_F(HloConstantFoldingTest, ConvertS64ToF32) {
auto module = CreateNewVerifiedModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Convert(input));
EXPECT_THAT(computation->root_instruction(),
GmockMatch(m::Convert().WithOperand(0, m::Op().Is(input))));
HloConstantFolding const_folder;
TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
EXPECT_TRUE(result);
EXPECT_THAT(computation->root_instruction(), op::Constant());
EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Constant()));
EXPECT_EQ(computation->root_instruction()->literal().GetFirstElement<float>(),
42.0f);
}
@ -91,13 +94,14 @@ TEST_F(HloConstantFoldingTest, ConvertF32ArrayToS64Array) {
auto module = CreateNewVerifiedModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Convert(input));
EXPECT_THAT(computation->root_instruction(),
GmockMatch(m::Convert().WithOperand(0, m::Op().Is(input))));
HloConstantFolding const_folder;
TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
EXPECT_TRUE(result);
EXPECT_THAT(computation->root_instruction(), op::Constant());
EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Constant()));
EXPECT_EQ(computation->root_instruction()->literal().Get<int64>({0}), 42);
EXPECT_EQ(computation->root_instruction()->literal().Get<int64>({1}), 19);
}
@ -138,7 +142,7 @@ TEST_F(HloConstantFoldingTest, Concatenate) {
EXPECT_TRUE(result);
HloInstruction* root = computation->root_instruction();
EXPECT_THAT(root, op::Constant());
EXPECT_THAT(root, GmockMatch(m::Constant()));
EXPECT_TRUE(ShapeUtil::Equal(root->shape(), shape));
}
}
@ -165,7 +169,7 @@ TEST_F(HloConstantFoldingTest, Slice) {
EXPECT_TRUE(result);
HloInstruction* root = computation->root_instruction();
EXPECT_THAT(root, op::Constant());
EXPECT_THAT(root, GmockMatch(m::Constant()));
EXPECT_TRUE(ShapeUtil::Equal(root->shape(), shape));
}
@ -190,7 +194,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) {
EXPECT_TRUE(result);
HloInstruction* root = computation->root_instruction();
EXPECT_THAT(root, op::Constant());
EXPECT_THAT(root, GmockMatch(m::Constant()));
EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), shape));
using NativeT = typename primitive_util::PrimitiveTypeToNative<F32>::type;
@ -240,7 +244,8 @@ TEST_F(HloConstantFoldingTest, ConstantFoldReduceNoLayout) {
TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(m.get()));
EXPECT_FALSE(result);
EXPECT_THAT(m->entry_computation()->root_instruction(), op::Reduce());
EXPECT_THAT(m->entry_computation()->root_instruction(),
GmockMatch(m::Reduce()));
}
const char* const kConstantFoldLargePad = R"(
@ -260,7 +265,7 @@ TEST_F(HloConstantFoldingTest, DoesNotFoldLargePad) {
EXPECT_FALSE(result);
EXPECT_THAT(module->entry_computation()->root_instruction(),
op::Pad(op::Constant(), op::Constant()));
GmockMatch(m::Pad(m::Constant(), m::Constant())));
}
} // namespace

View File

@ -21,7 +21,8 @@ limitations under the License.
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
@ -29,7 +30,7 @@ limitations under the License.
namespace xla {
namespace {
namespace op = ::xla::testing::opcode_matchers;
namespace m = ::xla::match;
using absl::string_view;
struct TestData {
@ -1893,7 +1894,8 @@ ENTRY ReduceR3ToR2 {
)";
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(original));
ASSERT_NE(module->entry_computation(), nullptr);
EXPECT_THAT(module->entry_computation()->root_instruction(), op::Reduce());
EXPECT_THAT(module->entry_computation()->root_instruction(),
GmockMatch(m::Reduce()));
}
TEST_F(HloParserTest, ParseSharding) {
@ -1953,7 +1955,7 @@ TEST(HloParserSingleOpTest, SingleOp) {
const HloComputation* computation = module->entry_computation();
ASSERT_NE(computation, nullptr);
EXPECT_THAT(computation->root_instruction(),
op::Multiply(op::Parameter(0), op::Parameter(1)));
GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(1))));
}
TEST(HloParserSingleOpTest, SingleOpNoShapeProducesError) {
@ -1981,7 +1983,7 @@ TEST(HloParserSingleOpTest, SingleOpNoNames) {
const HloComputation* computation = module->entry_computation();
ASSERT_NE(computation, nullptr);
EXPECT_THAT(computation->root_instruction(),
op::Multiply(op::Parameter(0), op::Parameter(1)));
GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(1))));
}
TEST(HloParserSingleOpTest, CanonicalOp) {
@ -1990,7 +1992,7 @@ TEST(HloParserSingleOpTest, CanonicalOp) {
const HloComputation* computation = module->entry_computation();
ASSERT_NE(computation, nullptr);
EXPECT_THAT(computation->root_instruction(),
op::Multiply(op::Parameter(0), op::Parameter(1)));
GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(1))));
EXPECT_EQ(
computation->root_instruction()->ToString(HloPrintOptions::Canonical()),
text);
@ -2044,7 +2046,11 @@ TEST(HloParserSingleOpTest, SingleOpWithNested) {
const HloComputation* computation = module->entry_computation();
ASSERT_NE(computation, nullptr);
EXPECT_THAT(computation->root_instruction(),
op::Fusion(op::Parameter(0), op::Parameter(1)));
GmockMatch(m::Op()
.WithOpcode(HloOpcode::kFusion)
.WithNumOperands(2)
.WithOperand(0, m::Parameter(0))
.WithOperand(1, m::Parameter(1))));
}
TEST(HloParserSingleOpTest, SingleOpWithNested_DoesNotExist) {
@ -2088,7 +2094,7 @@ TEST(HloParserSingleOpTest, ConvolutionTrivialFeatureGroupCount) {
const HloComputation* computation = module->entry_computation();
ASSERT_NE(computation, nullptr);
EXPECT_THAT(computation->root_instruction(),
op::Convolution(op::Parameter(0), op::Parameter(1)));
GmockMatch(m::Convolution(m::Parameter(0), m::Parameter(1))));
auto* convolution =
Cast<HloConvolutionInstruction>(computation->root_instruction());
EXPECT_EQ(convolution->feature_group_count(), 1);
@ -2152,8 +2158,10 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
module->schedule().is_computation_scheduled(module->entry_computation()));
EXPECT_THAT(
module->schedule().sequence(module->entry_computation()).instructions(),
::testing::ElementsAre(op::Parameter(), op::Broadcast(), op::Parameter(),
op::Multiply(), op::Parameter(), op::Add()));
::testing::ElementsAre(
GmockMatch(m::Parameter()), GmockMatch(m::Broadcast()),
GmockMatch(m::Parameter()), GmockMatch(m::Multiply()),
GmockMatch(m::Parameter()), GmockMatch(m::Add())));
}
TEST_F(HloParserTest, IsScheduledIsTrueDifferentOrder) {
@ -2179,8 +2187,10 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
module->schedule().is_computation_scheduled(module->entry_computation()));
EXPECT_THAT(
module->schedule().sequence(module->entry_computation()).instructions(),
::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Parameter(),
op::Broadcast(), op::Multiply(), op::Add()));
::testing::ElementsAre(
GmockMatch(m::Parameter()), GmockMatch(m::Parameter()),
GmockMatch(m::Parameter()), GmockMatch(m::Broadcast()),
GmockMatch(m::Multiply()), GmockMatch(m::Add())));
}
TEST_F(HloParserTest, CustomCallWrongNumberofOperandConstraints) {

View File

@ -31,6 +31,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
@ -42,11 +44,10 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace op = xla::testing::opcode_matchers;
namespace xla {
namespace {
namespace m = xla::match;
using ::testing::ElementsAre;
class LayoutAssignmentTest : public HloTestBase {
@ -342,7 +343,8 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) {
// Verify the structure of the HLO graph.
EXPECT_THAT(root,
op::Tuple(op::Tuple(constant), op::Tuple(op::Copy(constant))));
GmockMatch(m::Tuple(m::Tuple(m::Op().Is(constant)),
m::Tuple(m::Copy(m::Op().Is(constant))))));
}
TEST_F(LayoutAssignmentTest, ElementwiseAndReshape) {
@ -946,9 +948,11 @@ TEST_F(LayoutAssignmentTest, CopySliceOperandToAvoidImplicitLayoutChange) {
HloInstruction* root =
compiled_module->entry_computation()->root_instruction();
Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {1, 0});
EXPECT_THAT(root, op::Add(op::Parameter(),
op::Slice(AllOf(op::Copy(op::Parameter(1)),
op::ShapeWithLayout(shape_copy)))));
EXPECT_THAT(
root,
GmockMatch(m::Add(
m::Parameter(),
m::Slice(m::Copy(m::Parameter(1)).WithShapeEqualTo(&shape_copy)))));
}
TEST_F(LayoutAssignmentTest, CopyDSliceOperandToAvoidImplicitLayoutChange) {
@ -976,10 +980,11 @@ TEST_F(LayoutAssignmentTest, CopyDSliceOperandToAvoidImplicitLayoutChange) {
compiled_module->entry_computation()->root_instruction();
Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {1, 0});
EXPECT_THAT(root,
op::Add(op::Parameter(),
op::DynamicSlice(AllOf(op::Copy(op::Parameter(1)),
op::ShapeWithLayout(shape_copy)),
op::Parameter(2))));
GmockMatch(m::Add(
m::Parameter(),
m::DynamicSlice(
m::Copy(m::Parameter(1)).WithShapeEqualTo(&shape_copy),
m::Parameter(2)))));
}
TEST_F(LayoutAssignmentTest, CopyConcatOperandToAvoidImplicitLayoutChange) {
@ -1007,11 +1012,12 @@ TEST_F(LayoutAssignmentTest, CopyConcatOperandToAvoidImplicitLayoutChange) {
HloInstruction* root =
compiled_module->entry_computation()->root_instruction();
Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {3, 5}, {1, 0});
EXPECT_THAT(root,
op::Add(op::Parameter(),
op::Concatenate(AllOf(op::Copy(op::Parameter(1)),
op::ShapeWithLayout(shape_copy)),
op::Parameter(2))));
EXPECT_THAT(
root,
GmockMatch(m::Add(
m::Parameter(),
m::Concatenate(m::Copy(m::Parameter(1)).WithShapeEqualTo(&shape_copy),
m::Parameter(2)))));
}
TEST_F(LayoutAssignmentTest,
@ -1038,7 +1044,8 @@ TEST_F(LayoutAssignmentTest,
.ConsumeValueOrDie();
HloInstruction* root =
compiled_module->entry_computation()->root_instruction();
EXPECT_THAT(root, op::Convolution(op::Parameter(0), op::Parameter(1)));
EXPECT_THAT(root,
GmockMatch(m::Convolution(m::Parameter(0), m::Parameter(1))));
}
TEST_F(LayoutAssignmentTest, PropagatingLayoutFromResultToOperand) {
@ -1062,8 +1069,9 @@ TEST_F(LayoutAssignmentTest, PropagatingLayoutFromResultToOperand) {
HloInstruction* root =
compiled_module->entry_computation()->root_instruction();
Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {0, 1});
EXPECT_THAT(root, op::Slice(AllOf(op::Copy(op::Parameter(0)),
op::ShapeWithLayout(shape_copy))));
EXPECT_THAT(root,
GmockMatch(m::Slice(
m::Copy(m::Parameter(0)).WithShapeEqualTo(&shape_copy))));
}
TEST_F(LayoutAssignmentTest, TupleCopyOnLayoutMismatch) {
@ -1149,7 +1157,7 @@ ENTRY %CustomCallWithNotLayoutConstrained (p: f32[42,2,3]) -> f32[1,2,3,4] {
AssignLayouts(m.get(), &computation_layout);
HloInstruction* root = m->entry_computation()->root_instruction();
ASSERT_THAT(root, op::CustomCall(op::Parameter()));
ASSERT_THAT(root, GmockMatch(m::CustomCall(m::Parameter())));
ExpectLayoutIs(root->shape(), {3, 2, 0, 1});
ExpectLayoutIs(root->operand(0)->shape(), {0, 2, 1});
}
@ -1165,7 +1173,7 @@ ENTRY %CustomCallWithNotLayoutConstrained (p: f32[42,2,3]) -> f32[1,2,3,4] {
AssignLayouts(m.get(), &computation_layout);
HloInstruction* root = m->entry_computation()->root_instruction();
ASSERT_THAT(root, op::CustomCall(op::Parameter()));
ASSERT_THAT(root, GmockMatch(m::CustomCall(m::Parameter())));
ExpectLayoutIs(root->shape(), {0, 2, 3, 1});
ExpectLayoutIs(root->operand(0)->shape(), {0, 1, 2});
}
@ -1196,7 +1204,7 @@ ENTRY %CustomCallWithLayoutConstraints (p0: f32[4,4], p1: f32[2,3]) -> f32[1,2,3
// The custom call should be partially encapsulated in kCopy instructions
// because of the layout mismatches.
ASSERT_THAT(m->entry_computation()->root_instruction(),
op::Copy(op::CustomCall(op::Copy(), op::Parameter())));
GmockMatch(m::Copy(m::CustomCall(m::Copy(), m::Parameter()))));
const HloInstruction* custom_call =
m->entry_computation()->root_instruction()->operand(0);
@ -1222,7 +1230,7 @@ ENTRY %CustomCallLayoutConstrainedZeroOperands () -> f32[1,2,3,4] {
AssignLayouts(m.get(), &computation_layout);
ASSERT_THAT(m->entry_computation()->root_instruction(),
op::Copy(op::CustomCall()));
GmockMatch(m::Copy(m::CustomCall())));
const HloInstruction* custom_call =
m->entry_computation()->root_instruction()->operand(0);
@ -1256,7 +1264,7 @@ ENTRY %CustomCallLayoutConstrainedTupleOperand (p0: f32[4,4], p1: f32[2,3]) -> f
ExpectLayoutIs(root->shape(), {2, 1, 0, 3});
ASSERT_THAT(m->entry_computation()->root_instruction(),
op::Copy(op::CustomCall(op::Tuple())));
GmockMatch(m::Copy(m::CustomCall(m::Tuple()))));
const HloInstruction* custom_call =
m->entry_computation()->root_instruction()->operand(0);

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,92 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_GMOCK_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_GMOCK_H_
#include <ostream>
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/core/platform/test.h"
namespace xla {
namespace pattern_matcher_gmock_detail {
template <typename Pattern>
class GmockMatcher {
public:
explicit GmockMatcher(Pattern p) : pattern_(std::move(p)) {}
// In service of better error messages, list out the overloads explicitly
// rather than just using a template. gMock's polymorphism plus
// pattern_matcher yields some pretty gnarly stuff.
bool MatchAndExplain(const Layout& l,
::testing::MatchResultListener* listener) const {
return MatchAndExplainImpl(&l, listener);
}
bool MatchAndExplain(const Layout* l,
::testing::MatchResultListener* listener) const {
return MatchAndExplainImpl(l, listener);
}
bool MatchAndExplain(const Shape& s,
::testing::MatchResultListener* listener) const {
return MatchAndExplainImpl(&s, listener);
}
bool MatchAndExplain(const Shape* s,
::testing::MatchResultListener* listener) const {
return MatchAndExplainImpl(s, listener);
}
bool MatchAndExplain(const HloInstruction& instr,
::testing::MatchResultListener* listener) const {
return MatchAndExplainImpl(&instr, listener);
}
bool MatchAndExplain(const HloInstruction* instr,
::testing::MatchResultListener* listener) const {
return MatchAndExplainImpl(instr, listener);
}
void DescribeTo(std::ostream* os) const { pattern_.DescribeTo(os); }
void DescribeNegationTo(std::ostream* os) const {
*os << "is NOT: ";
DescribeTo(os);
}
private:
template <typename T>
bool MatchAndExplainImpl(const T* t,
::testing::MatchResultListener* listener) const {
MatchOption options{/*.capture=*/true, /*.explain_os=*/listener->stream()};
return Match(t, pattern_, options);
}
Pattern pattern_;
};
} // namespace pattern_matcher_gmock_detail
template <typename Pattern>
::testing::PolymorphicMatcher<
pattern_matcher_gmock_detail::GmockMatcher<Pattern>>
GmockMatch(Pattern&& p) {
return ::testing::MakePolymorphicMatcher(
pattern_matcher_gmock_detail::GmockMatcher<Pattern>(
std::forward<Pattern>(p)));
}
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_GMOCK_H_

View File

@ -0,0 +1,76 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/core/platform/test.h"
namespace xla {
namespace {
namespace m = ::xla::match;
using ::testing::Eq;
using ::testing::Not;
template <typename MatchedTy>
string Describe(const ::testing::Matcher<MatchedTy>& m) {
std::stringstream ss;
m.DescribeTo(&ss);
return ss.str();
}
template <typename MatchedTy>
string Explain(
const MatchedTy& val,
const ::testing::Matcher<typename std::remove_cv<MatchedTy>::type>& m) {
::testing::StringMatchResultListener listener;
EXPECT_THAT(val, ::testing::Not(m)); // For the error message.
EXPECT_FALSE(m.MatchAndExplain(val, &listener));
return listener.str();
}
// This file tests the GmockMatch function. The actual explanation and
// description returned by matchers is tested in pattern_matchers_test.
TEST(PatternMatcherGmock, MatchShape) {
Shape s = ShapeUtil::MakeShape(F32, {10, 100});
// You can pass const Shape& or a const Shape*.
EXPECT_THAT(s, GmockMatch(m::Shape()));
EXPECT_THAT(&s, Not(GmockMatch(m::Shape().WithElementType(F16))));
EXPECT_THAT(Describe<Shape>(GmockMatch(m::Shape().IsArray())),
"a shape that represents an array");
}
TEST(PatternMatcherGmock, MatchLayout) {
Layout l = LayoutUtil::MakeLayout({0, 1});
EXPECT_THAT(l, GmockMatch(m::Layout()));
EXPECT_THAT(&l, Not(GmockMatch(m::Layout().WithSparseFormat())));
EXPECT_THAT(Describe<Layout>(GmockMatch(m::Layout().WithSparseFormat())),
"a layout with format SPARSE");
}
TEST(PatternMatchGmock, MatchInstruction) {
auto instr =
HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {42}), "p");
EXPECT_THAT(instr.get(), GmockMatch(m::Parameter()));
EXPECT_THAT(*instr, GmockMatch(m::Parameter(0)));
EXPECT_THAT(*instr, Not(GmockMatch(m::Parameter(1))));
EXPECT_THAT(Describe<HloInstruction*>(GmockMatch(m::Parameter())),
"an HloInstruction with opcode parameter");
}
} // anonymous namespace
} // namespace xla

View File

@ -14,14 +14,18 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/core/platform/test.h"
namespace xla {
namespace {
namespace m = match;
TEST(PatternMatcherTest, AddOp) {
constexpr char kModuleStr[] = R"(HloModule two_plus_two_module
ENTRY %two_plus_two_computation () -> f32[] {
@ -229,23 +233,74 @@ TEST(PatternMatcherTest, AnyOf) {
}
TEST(PatternMatcherTest, ConstantScalar) {
using match::ConstantEffectiveScalar;
using match::ConstantScalar;
using match::Op;
using match::Tuple;
constexpr char kModuleStr[] = R"(
HloModule test_module ENTRY test { ROOT constant = f16[] constant(42) })";
HloModule test_module
ENTRY test {
a = s32[] constant(1)
b = s32[1,1] constant(s32[1,1]{{2}})
c = s32[1,2] constant(s32[1,2]{{2,2}})
d = f32[] constant(1)
e = f32[] constant(1.25)
ROOT tuple = (s32[], s32[1,1], s32[1,2], f32[], f32[]) tuple(a,b,c,d,e)
})";
TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr));
auto* root = hlo_module->entry_computation()->root_instruction();
EXPECT_TRUE(Match(root, match::ConstantScalar(42)));
EXPECT_FALSE(Match(root, match::ConstantScalar(41)));
EXPECT_FALSE(Match(root, match::ConstantScalar(0)));
}
const HloInstruction* a = root->operand(0);
const HloInstruction* b = root->operand(1);
const HloInstruction* c = root->operand(2);
const HloInstruction* d = root->operand(3);
const HloInstruction* e = root->operand(4);
EXPECT_TRUE(Match(a, ConstantScalar()));
EXPECT_TRUE(Match(a, ConstantScalar(1)));
EXPECT_TRUE(Match(a, ConstantEffectiveScalar()));
EXPECT_TRUE(Match(a, ConstantEffectiveScalar(1)));
EXPECT_FALSE(Match(a, ConstantScalar(2)));
EXPECT_FALSE(Match(a, ConstantScalar(2.01)));
EXPECT_FALSE(Match(a, ConstantEffectiveScalar(2)));
EXPECT_FALSE(Match(a, ConstantEffectiveScalar(1.01)));
TEST(PatternMatcherTest, NoMatchConstantScalar) {
constexpr char kModuleStr[] = R"(
HloModule test_module ENTRY test { ROOT v = f16[] parameter(0) })";
TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr));
auto* root = hlo_module->entry_computation()->root_instruction();
EXPECT_FALSE(Match(b, ConstantScalar()));
EXPECT_FALSE(Match(b, ConstantScalar(2)));
EXPECT_TRUE(Match(b, ConstantEffectiveScalar()));
EXPECT_TRUE(Match(b, ConstantEffectiveScalar(2)));
EXPECT_FALSE(Match(root, match::ConstantScalar(42)));
EXPECT_FALSE(Match(c, ConstantScalar()));
EXPECT_FALSE(Match(c, ConstantScalar(2)));
EXPECT_FALSE(Match(c, ConstantEffectiveScalar()));
EXPECT_FALSE(Match(c, ConstantEffectiveScalar(2)));
EXPECT_TRUE(Match(d, ConstantScalar(1)));
EXPECT_TRUE(Match(d, ConstantEffectiveScalar(1)));
EXPECT_TRUE(Match(d, ConstantScalar(1.0)));
EXPECT_TRUE(Match(d, ConstantEffectiveScalar(1.0)));
EXPECT_TRUE(Match(e, ConstantScalar(1.25f)));
EXPECT_TRUE(Match(e, ConstantScalar(1.25)));
EXPECT_TRUE(Match(e, ConstantEffectiveScalar(1.25)));
EXPECT_FALSE(Match(e, ConstantScalar(1)));
EXPECT_FALSE(Match(e, ConstantEffectiveScalar(1)));
const HloInstruction* instr = nullptr;
EXPECT_TRUE(Match(a, ConstantScalar(&instr)));
EXPECT_EQ(instr, a);
instr = nullptr;
EXPECT_TRUE(Match(a, ConstantScalar(&instr, 1)));
EXPECT_EQ(instr, a);
instr = nullptr;
EXPECT_TRUE(Match(a, ConstantEffectiveScalar(&instr)));
EXPECT_EQ(instr, a);
instr = nullptr;
EXPECT_TRUE(Match(a, ConstantEffectiveScalar(&instr, 1)));
EXPECT_EQ(instr, a);
}
TEST(PatternMatcherTest, MultiplyAnyOrder) {
@ -267,6 +322,15 @@ TEST(PatternMatcherTest, MultiplyAnyOrder) {
root, MultiplyAnyOrder(&instr, ConstantScalar(42), ConstantScalar(52))));
EXPECT_TRUE(Match(
root, MultiplyAnyOrder(&instr, ConstantScalar(52), ConstantScalar(42))));
// Check that MultiplyAnyOrder exposes the same API as Op(), so we can call
// e.g. IsNonConstant() on it.
EXPECT_TRUE(Match(
root, MultiplyAnyOrder(&instr, ConstantScalar(42), ConstantScalar(52))
.IsNonConstant()));
EXPECT_TRUE(
Match(root, MultiplyAnyOrder(ConstantScalar(42), ConstantScalar(52))
.IsNonConstant()));
}
TEST(PatternMatcherTest, AnyOfShortCircuit) {
@ -315,14 +379,22 @@ TEST(PatternMatcherTest, AllOf) {
TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr));
auto* root = hlo_module->entry_computation()->root_instruction();
auto f16_scalar = ShapeUtil::MakeShape(F16, {});
auto f16_pattern = Constant().WithShapeEqualTo(&f16_scalar);
auto f16_compatible_pattern = Constant().WithShapeCompatibleTo(&f16_scalar);
auto scalar_pattern = Constant().WithShape(match::Shape().IsScalar());
auto f16_pattern = Constant().WithShape(match::Shape().WithElementType(F16));
ASSERT_TRUE(Match(root, scalar_pattern));
ASSERT_TRUE(Match(root, f16_pattern));
EXPECT_TRUE(Match(root, AllOf<HloInstruction>(scalar_pattern, f16_pattern)));
EXPECT_TRUE(Match(root, AllOf<HloInstruction>(f16_pattern, scalar_pattern)));
ASSERT_TRUE(Match(root, f16_compatible_pattern));
EXPECT_TRUE(Match(root, AllOf<HloInstruction>(scalar_pattern, f16_pattern,
f16_compatible_pattern)));
EXPECT_TRUE(
Match(root, AllOf<HloInstruction>(f16_pattern, f16_compatible_pattern,
scalar_pattern)));
EXPECT_FALSE(
Match(root, AllOf<HloInstruction>(Broadcast(Op()), f16_pattern)));
EXPECT_FALSE(Match(
root, AllOf<HloInstruction>(Broadcast(Op()), f16_compatible_pattern)));
EXPECT_FALSE(
Match(root, AllOf<HloInstruction>(Broadcast(Op()), scalar_pattern)));
}
@ -431,5 +503,377 @@ TEST(PatternMatcherTest, TestConcat) {
Reshape(ConstantScalar(4)))));
}
template <typename Pattern>
string Description(const Pattern& pattern) {
std::stringstream ss;
pattern.DescribeTo(&ss);
return ss.str();
}
template <typename Elem, typename Pattern>
string Explanation(Elem* elem, const Pattern& pattern) {
std::stringstream ss;
MatchOption options{/*.capture=*/true, /*.explain_os=*/&ss};
Match(elem, pattern, options);
return ss.str();
}
template <typename Elem, typename Pattern>
string Explanation(const std::unique_ptr<Elem>& elem, const Pattern& pattern) {
return Explanation(elem.get(), pattern);
}
template <typename Elem, typename Pattern>
string Explanation(const Elem& elem, const Pattern& pattern) {
return Explanation(&elem, pattern);
}
// Helper macro for checking a pattern's description and the explanation printed
// when attempting to match (and presumably failing) on a given object.
//
// We use a macro rather than a function because we want good line numbers in
// errors. We use this rather than writing a helper that returns a pair of
// (description, explanation) and doing something like
//
// EXPECT_THAT(DescAndExplanation(...), ::testing::Pair(..., ...));
//
// because EXPECT_EQ prints a unified diff if multiline string comparison fails,
// while EXPECT_THAT does not. This unified diff makes the errors much easier
// to read.
#define EXPECT_DESC_AND_EXPLANATION(elem, pattern, expected_desc, \
expected_explanation) \
do { \
EXPECT_EQ(Description(pattern), (expected_desc)); \
EXPECT_EQ(Explanation((elem), (pattern)), expected_explanation); \
} while (0)
TEST(PatternMatcherTest, LayoutDescribeToAndExplain) {
auto layout = LayoutUtil::MakeLayout({1, 2});
auto layout2 = LayoutUtil::MakeLayout({2, 2});
EXPECT_DESC_AND_EXPLANATION(static_cast<const Layout*>(nullptr), m::Layout(),
"a layout", "Layout is null");
EXPECT_DESC_AND_EXPLANATION(layout2, m::Layout().EqualTo(&layout),
"a layout equal to {1,2}",
"Layout {2,2} is not equal to expected {1,2}");
EXPECT_DESC_AND_EXPLANATION(layout2, m::Layout().WithSparseFormat(),
"a layout with format SPARSE",
"Layout has format DENSE but expected SPARSE");
EXPECT_DESC_AND_EXPLANATION(layout,
m::Layout().EqualTo(&layout).WithSparseFormat(),
"a layout:\n"
" * equal to {1,2} AND\n"
" * with format SPARSE",
"Layout has format DENSE but expected SPARSE");
}
TEST(PatternMatcherTest, ShapeDescribeToAndExplain) {
auto shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 2}, {0, 1});
auto layout = shape.layout();
EXPECT_DESC_AND_EXPLANATION(static_cast<const Shape*>(nullptr), m::Shape(),
"a shape", "Shape is null");
EXPECT_DESC_AND_EXPLANATION(
ShapeUtil::MakeShapeWithLayout(F32, {1, 2}, {1, 0}),
m::Shape().EqualTo(&shape), "a shape equal to f32[1,2]{0,1}",
"Shape not equal to f32[1,2]{0,1}\n"
"in f32[1,2]{1,0}");
EXPECT_DESC_AND_EXPLANATION(ShapeUtil::MakeShape(F32, {2, 2}),
m::Shape().CompatibleTo(&shape),
"a shape compatible with f32[1,2]",
"Shape not compatible with f32[1,2]\n"
"in f32[2,2]{1,0}");
EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().WithElementType(F16),
"a shape with element type F16",
"Shape does not have element type F16\n"
"in f32[1,2]{0,1}");
EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().IsScalar(),
"a shape that represents a scalar",
"Shape is not a scalar\n"
"in f32[1,2]{0,1}");
EXPECT_DESC_AND_EXPLANATION(ShapeUtil::MakeNil(), m::Shape().IsArray(),
"a shape that represents an array",
"Shape is not an array\n"
"in ()");
EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().IsTuple(),
"a shape that represents a tuple",
"Shape is not a tuple\n"
"in f32[1,2]{0,1}");
EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().IsEffectiveScalar(),
"a shape that is an effective scalar",
"Shape is not an effective scalar\n"
"in f32[1,2]{0,1}");
EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().WithRank(42),
"a shape that has 42 dimensions",
"Shape does not have rank 42\n"
"in f32[1,2]{0,1}");
EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().WithRank(0),
"a shape that is a scalar",
"Shape is not a scalar\n"
"in f32[1,2]{0,1}");
EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().WithRank(1).IsArray(),
"a shape:\n"
" * that has 1 dimension AND\n"
" * that represents an array",
"Shape does not have rank 1\n"
"in f32[1,2]{0,1}");
EXPECT_DESC_AND_EXPLANATION(ShapeUtil::MakeNil(),
m::Shape().IsArray().WithRank(1),
"a shape:\n"
" * that represents an array AND\n"
" * that has 1 dimension",
"Shape is not an array\n"
"in ()");
EXPECT_DESC_AND_EXPLANATION(
ShapeUtil::MakeShapeWithLayout(F32, {1, 2}, {1, 0}),
m::Shape().WithLayoutEqualTo(&layout),
"a shape with\n a layout equal to {0,1}",
"Layout {1,0} is not equal to expected {0,1}\n"
"in f32[1,2]{1,0}");
EXPECT_DESC_AND_EXPLANATION(
shape, m::Shape().WithLayout(m::Layout().WithSparseFormat()),
"a shape with\n a layout with format SPARSE",
"Layout has format DENSE but expected SPARSE\n"
"in f32[1,2]{0,1}");
EXPECT_DESC_AND_EXPLANATION(shape,
m::Shape().WithSubshapeEqualTo({10}, &shape),
"a shape with subshape at index {10} which is\n"
" a shape equal to f32[1,2]{0,1}",
"No subshape at {10}\n"
"in f32[1,2]{0,1}");
EXPECT_DESC_AND_EXPLANATION(
ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {2, 2})}),
m::Shape().WithSubshapeEqualTo({0}, &shape),
"a shape with subshape at index {0} which is\n"
" a shape equal to f32[1,2]{0,1}",
"Shape not equal to f32[1,2]{0,1}\n"
"in f32[2,2]{1,0}\n"
"in subshape at {0}\n"
"in (f32[2,2])");
EXPECT_DESC_AND_EXPLANATION(shape,
m::Shape().WithSubshapeCompatibleTo({10}, &shape),
"a shape with subshape at index {10} which is\n"
" a shape compatible with f32[1,2]",
"No subshape at {10}\n"
"in f32[1,2]{0,1}");
EXPECT_DESC_AND_EXPLANATION(
ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {2, 2})}),
m::Shape().WithSubshapeCompatibleTo({0}, &shape),
"a shape with subshape at index {0} which is\n"
" a shape compatible with f32[1,2]",
"Shape not compatible with f32[1,2]\n"
"in f32[2,2]{1,0}\n"
"in subshape at {0}\n"
"in (f32[2,2])");
EXPECT_DESC_AND_EXPLANATION(
ShapeUtil::MakeTupleShape({ShapeUtil::MakeTupleShape({shape})}),
m::Shape().WithSubshape({0, 0}, m::Shape().IsScalar()),
"a shape with subshape at index {0,0} which is\n"
" a shape that represents a scalar",
"Shape is not a scalar\n"
"in f32[1,2]{0,1}\n"
"in subshape at {0,0}\n"
"in ((f32[1,2]))");
}
std::unique_ptr<HloInstruction> SetName(absl::string_view name,
std::unique_ptr<HloInstruction> instr) {
instr->SetAndSanitizeName(string(name));
return instr;
}
TEST(PatternMatcherTest, HloInstructionDescribeToAndExplain) {
std::unique_ptr<HloInstruction> iota =
SetName("i", HloInstruction::CreateIota(ShapeUtil::MakeShape(S32, {42}),
/*iota_dimension=*/0));
std::unique_ptr<HloInstruction> constant =
SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(0)));
EXPECT_DESC_AND_EXPLANATION(static_cast<const HloInstruction*>(nullptr),
m::Op(), "an HloInstruction",
"HloInstruction* is null");
EXPECT_DESC_AND_EXPLANATION(iota, m::Op().WithName("foo"),
"an HloInstruction named \"foo\"",
"HloInstruction not named \"foo\"\n"
"in i = s32[42]{0} iota(), iota_dimension=0");
EXPECT_DESC_AND_EXPLANATION(iota, m::Op().WithOpcode(HloOpcode::kAdd),
"an HloInstruction with opcode add",
"HloInstruction doesn't have opcode add\n"
"in i = s32[42]{0} iota(), iota_dimension=0");
EXPECT_DESC_AND_EXPLANATION(
constant, m::Op().IsNonConstant(),
"an HloInstruction with any opcode other than constant",
"HloInstruction has opcode constant, expected anything else\n"
"in c = s32[] constant(0)");
EXPECT_DESC_AND_EXPLANATION(iota, m::Op().WithNumOperands(42),
"an HloInstruction with 42 operands",
"HloInstruction doesn't have 42 operands\n"
"in i = s32[42]{0} iota(), iota_dimension=0");
EXPECT_DESC_AND_EXPLANATION(iota, m::Op().WithShape(m::Shape().IsTuple()),
"an HloInstruction outputting\n"
" a shape that represents a tuple",
"Shape is not a tuple\n"
"in s32[42]{0}\n"
"in output shape\n"
"in i = s32[42]{0} iota(), iota_dimension=0");
EXPECT_DESC_AND_EXPLANATION(
iota, m::Op().WithOperand(2, m::Op().WithOpcode(HloOpcode::kAdd)),
"an HloInstruction with operand 2 which is:\n"
" an HloInstruction with opcode add",
"desired operand index 2 is out of bounds\n"
"in i = s32[42]{0} iota(), iota_dimension=0");
EXPECT_DESC_AND_EXPLANATION(
SetName("a", HloInstruction::CreateBinary(ShapeUtil::MakeShape(S32, {}),
HloOpcode::kAdd, constant.get(),
constant.get())),
m::Op().WithOperand(1, m::Op().IsNonConstant()),
"an HloInstruction with operand 1 which is:\n"
" an HloInstruction with any opcode other than constant",
"HloInstruction has opcode constant, expected anything else\n"
"in c = s32[] constant(0)\n"
"in operand 1\n"
"in a = s32[] add(s32[] c, s32[] c)");
EXPECT_DESC_AND_EXPLANATION(
iota, m::Op().WithFusionKind(HloInstruction::FusionKind::kLoop),
"an HloInstruction with fusion kind kLoop",
"HloInstruction does not have fusion kind kLoop; it's not a fusion\n"
"in i = s32[42]{0} iota(), iota_dimension=0");
EXPECT_DESC_AND_EXPLANATION(
iota, m::Op().WithTupleIndex(42),
"an HloInstruction which is a GTE with index 42",
"HloInstruction is not a GTE with index 42; it's not a GTE at all\n"
"in i = s32[42]{0} iota(), iota_dimension=0");
EXPECT_DESC_AND_EXPLANATION(iota, m::Op().IsConstantScalar(),
"an HloInstruction which is a constant scalar",
"HloInstruction is not a constant\n"
"in i = s32[42]{0} iota(), iota_dimension=0");
EXPECT_DESC_AND_EXPLANATION(
SetName("c", HloInstruction::CreateConstant(
LiteralUtil::CreateR1<int>({1, 2}))),
m::Op().IsConstantEffectiveScalar(),
"an HloInstruction which is a constant effective scalar",
"HloInstruction is not an effective scalar\n"
"in c = s32[2]{0} constant({1, 2})");
EXPECT_DESC_AND_EXPLANATION(
SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(10))),
m::Op().IsConstantScalar(42),
"an HloInstruction which is a constant scalar with value 42",
"HloInstruction's constant value 10 did not match expected value 42\n"
"in c = s32[] constant(10)");
EXPECT_DESC_AND_EXPLANATION(
SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.25))),
m::Op().IsConstantEffectiveScalar(1.25),
"an HloInstruction which is a constant effective scalar with value 1.25",
"HloInstruction's constant value 2.25 did not match expected value 1.25\n"
"in c = f64[] constant(2.25)");
EXPECT_DESC_AND_EXPLANATION(
constant, m::Op().Is(iota.get()),
absl::StrCat("an HloInstruction which is 0x", absl::Hex(iota.get()), " (",
iota->ToShortString(), ")"),
absl::StrCat("HloInstruction 0x", absl::Hex(constant.get()), " is not 0x",
absl::Hex(iota.get()), " (", iota->ToShortString(), ")\n",
"in c = s32[] constant(0)"));
}
TEST(PatternMatcherTest, HloInstructionMatcherAnyOrderDescribeTo) {
auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
EXPECT_DESC_AND_EXPLANATION(
SetName("a", HloInstruction::CreateBinary(
scalar_s32, HloOpcode::kAdd,
SetName("b", HloInstruction::CreateConstant(
LiteralUtil::CreateR0(0)))
.get(),
SetName("c", HloInstruction::CreateConstant(
LiteralUtil::CreateR0(0)))
.get())),
m::AddAnyOrder(m::Op().WithName("b"), m::Op().WithName("bar")),
"an HloInstruction:\n"
" * with opcode add AND\n"
" * with two operands in either order:\n"
" - an HloInstruction named \"b\"\n"
" - an HloInstruction named \"bar\"",
"HloInstruction's operands (ignoring order) did not match second "
"matcher. Specifically,\n"
" - an HloInstruction named \"bar\"\n"
"does not match LHS:\n"
" - HloInstruction not named \"bar\"\n"
" in b = s32[] constant(0)\n"
"does not match RHS:\n"
" - HloInstruction not named \"bar\"\n"
" in c = s32[] constant(0)\n"
"in a = s32[] add(s32[] b, s32[] c)");
EXPECT_DESC_AND_EXPLANATION(
SetName("a",
HloInstruction::CreateBinary(
scalar_s32, HloOpcode::kAdd,
HloInstruction::CreateParameter(0, scalar_s32, "p").get(),
SetName("c", HloInstruction::CreateConstant(
LiteralUtil::CreateR0(0)))
.get())),
m::AddAnyOrder(m::Op().IsConstantScalar(), m::Op().IsConstant()),
"an HloInstruction:\n"
" * with opcode add AND\n"
" * with two operands in either order:\n"
" - an HloInstruction which is a constant scalar\n"
" - an HloInstruction with opcode constant",
"HloInstruction's LHS operand did not match either of the two matchers. "
"Specifically,\n"
" - an HloInstruction which is a constant scalar\n"
"does not match LHS:\n"
" - HloInstruction is not a constant\n"
" in p = s32[] parameter(0)\n"
"and\n"
" - an HloInstruction with opcode constant\n"
"does not match LHS:\n"
" - HloInstruction doesn't have opcode constant\n"
" in p = s32[] parameter(0)\n"
"in a = s32[] add(s32[] p, s32[] c)");
}
TEST(PatternMatcherTest, AnyOfMatcherDescribeToAndExplain) {
EXPECT_DESC_AND_EXPLANATION(
SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))),
m::AnyOf<HloInstruction>(m::Op().WithName("foo"),
m::Op().WithName("bar")),
"any of:\n"
" - an HloInstruction named \"foo\" OR\n"
" - an HloInstruction named \"bar\"",
"None of the following matchers succeeded:\n"
"Matcher #1\n"
" - an HloInstruction named \"foo\"\n"
"failed with\n"
" - HloInstruction not named \"foo\"\n"
" in c = s32[] constant(0)\n"
"Matcher #2\n"
" - an HloInstruction named \"bar\"\n"
"failed with\n"
" - HloInstruction not named \"bar\"\n"
" in c = s32[] constant(0)");
}
TEST(PatternMatcherTest, Parameter) {
auto param =
HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {}), "p1");
auto non_param =
SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(0)));
EXPECT_FALSE(Match(param.get(), m::Parameter(0)));
EXPECT_TRUE(Match(param.get(), m::Parameter()));
EXPECT_TRUE(Match(param.get(), m::Parameter(1)));
EXPECT_FALSE(Match(non_param.get(), m::Parameter()));
EXPECT_FALSE(Match(non_param.get(), m::Parameter(1)));
EXPECT_DESC_AND_EXPLANATION(non_param, m::Parameter(1),
"an HloInstruction:\n"
" * with opcode parameter AND\n"
" * which is parameter 1",
"HloInstruction doesn't have opcode parameter\n"
"in c = s32[] constant(0)");
EXPECT_EQ(Explanation(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(F32, {}), "p0"),
m::Parameter(1)),
"HloInstruction is not parameter 1\n"
"in p0 = f32[] parameter(0)");
}
} // namespace
} // namespace xla

View File

@ -526,16 +526,14 @@ static StatusOr<bool> TryPropagateConstant(HloInstruction* while_op) {
// performance by forcing us to copy constants.
absl::flat_hash_map<int, const HloInstruction*> index_to_constant;
for (int i = 0; i < root_operands.size(); i++) {
HloInstruction* instr = root_operands[i];
if (instr->opcode() == HloOpcode::kGetTupleElement &&
instr->tuple_index() == i && instr->operand(0) == while_body_param &&
ShapeUtil::IsScalar(instr->shape())) {
auto tuple_element = while_init->operand(i);
if (tuple_element->IsConstant()) {
VLOG(3) << "Found loop invariant tuple element " << i << " "
<< tuple_element->ToString();
index_to_constant[i] = tuple_element;
}
const HloInstruction* init_tuple_elem = nullptr;
if (Match(root_operands[i],
m::GetTupleElement(m::Op().Is(while_body_param), i)
.WithShape(m::Shape().IsScalar())) &&
Match(while_init->operand(i), m::Constant(&init_tuple_elem))) {
VLOG(3) << "Found loop invariant tuple element " << i << " "
<< init_tuple_elem->ToString();
index_to_constant[i] = init_tuple_elem;
}
}
@ -793,16 +791,11 @@ static StatusOr<HloInstruction*> TryMergeInductionVariables(
// Maps the tuple index of each induction variable to its constant increment.
absl::flat_hash_map<int64, const HloConstantInstruction*> induction_vars;
for (int64 i = 0; i < while_body_root->operand_count(); ++i) {
const auto& elem_shape = while_body_root->operand(i)->shape();
if (!ShapeUtil::IsEffectiveScalar(elem_shape) ||
elem_shape.element_type() != elem_ty) {
continue;
}
HloInstruction* constant;
if (!Match(while_body_root->mutable_operand(i),
m::AddAnyOrder(m::GetTupleElement(m::Parameter(), i),
m::Constant(&constant)))) {
m::ConstantScalar(&constant))
.WithShape(m::Shape().WithElementType(elem_ty)))) {
continue;
}
if (!trip_counter && constant->literal().IsAll(1) &&