STT-tensorflow/tensorflow/compiler/xla/service/pattern_matcher_gmock.h
Justin Lebar 19a1dd5268 [XLA] Make pattern_matchers work with gMock.
This lets us unify the HLO pattern matchers and the HLO gmock matchers (in a later patch).

Unifying these two APIs is useful because then we don't have to learn two APIs,
and we don't have to implement features twice.

This change:

 - Adds and tests the DescribeTo and MatchAndExplain APIs (this is the major change)

 - Uses these new gmock matchers in a few tests as a proof of concept.

 - Rewrites the is-constant-scalar API to use a true matcher rather than a std::function predicate matcher.  This is necessary to get a user-friendly DescribeTo message rather than "I don't know what this std::function does."

 - Adds EffectiveScalarConstant helpers along with the old ScalarConstant helpers and then uses these within while_loop_simplifier.

 - Adds some missing simple op matchers: Tuple, Convolution, Pad, etc.

 - Adds a Parameter(n) matcher.

 - Adds Op().Is(), which matches a particular HloInstruction*, which is used in while_loop_simplifier.

 - Updates documentation to reflect new functions (both added here and added in earlier patches).

 - Tightens up the documentation.  It was getting pretty long, and I made it longer.

 - Changes implementation of FooAnyOrder so that it returns an Op rather than an AnyOf.  This lets you do AddAnyOrder(...).IsScalar(), whereas before this was a compile error.

 - Changes the implementation of FooAnyOrder so it uses a custom matcher rather than an AnyOf, in service of better DescribeTo messages.

 - Implements "and" folding, i.e.

     AllOf<AllOf<A, B...>, X, Y, ...> => AllOf<A, B, ..., X, Y, ...>

   in the service of better DescribeTo messages.
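
The folding rule above can be illustrated at the type level. What follows is only a minimal standalone sketch, not the code in pattern_matcher.h: `AllOf` is an empty stand-in for the real conjunction pattern, and `FoldAllOf` and the tag types `A`, `B`, `X`, `Y` are hypothetical names chosen for this example. As in the rule, only a nested `AllOf` in the first position is spliced.

```cpp
#include <type_traits>

// Stand-in for a conjunction-of-patterns type. Only its type list matters
// for this illustration; the real class also stores and runs the patterns.
template <typename... Patterns>
struct AllOf {};

// General case: the first pattern is not an AllOf, so nothing is folded.
template <typename First, typename... Rest>
struct FoldAllOf {
  using type = AllOf<First, Rest...>;
};

// Folding case: the first pattern is itself an AllOf, so its members are
// spliced in front of the remaining patterns.
template <typename... Inner, typename... Rest>
struct FoldAllOf<AllOf<Inner...>, Rest...> {
  using type = AllOf<Inner..., Rest...>;
};

// Hypothetical tag types standing in for individual patterns.
struct A {};
struct B {};
struct X {};
struct Y {};

// AllOf<AllOf<A, B>, X, Y> => AllOf<A, B, X, Y>
static_assert(std::is_same<FoldAllOf<AllOf<A, B>, X, Y>::type,
                           AllOf<A, B, X, Y>>::value,
              "a leading AllOf is flattened");
// A non-AllOf first pattern is left alone.
static_assert(std::is_same<FoldAllOf<A, X>::type, AllOf<A, X>>::value,
              "nothing to fold");
```

With a flat list, a description can enumerate the conjuncts in one level instead of printing a nested "all of" inside another, which is presumably what makes the DescribeTo output easier to read.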

PiperOrigin-RevId: 223451504
2018-11-29 19:14:05 -08:00

/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_GMOCK_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_GMOCK_H_
#include <ostream>
#include <utility>  // std::move, std::forward

#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/core/platform/test.h"
namespace xla {
namespace pattern_matcher_gmock_detail {
// Adapts a pattern_matcher pattern to the interface gMock expects of a
// polymorphic matcher: MatchAndExplain, DescribeTo, and DescribeNegationTo.
template <typename Pattern>
class GmockMatcher {
 public:
  explicit GmockMatcher(Pattern p) : pattern_(std::move(p)) {}

  // In service of better error messages, list out the overloads explicitly
  // rather than just using a template.  gMock's polymorphism plus
  // pattern_matcher yields some pretty gnarly stuff.
  bool MatchAndExplain(const Layout& l,
                       ::testing::MatchResultListener* listener) const {
    return MatchAndExplainImpl(&l, listener);
  }
  bool MatchAndExplain(const Layout* l,
                       ::testing::MatchResultListener* listener) const {
    return MatchAndExplainImpl(l, listener);
  }

  bool MatchAndExplain(const Shape& s,
                       ::testing::MatchResultListener* listener) const {
    return MatchAndExplainImpl(&s, listener);
  }
  bool MatchAndExplain(const Shape* s,
                       ::testing::MatchResultListener* listener) const {
    return MatchAndExplainImpl(s, listener);
  }

  bool MatchAndExplain(const HloInstruction& instr,
                       ::testing::MatchResultListener* listener) const {
    return MatchAndExplainImpl(&instr, listener);
  }
  bool MatchAndExplain(const HloInstruction* instr,
                       ::testing::MatchResultListener* listener) const {
    return MatchAndExplainImpl(instr, listener);
  }

  void DescribeTo(std::ostream* os) const { pattern_.DescribeTo(os); }

  void DescribeNegationTo(std::ostream* os) const {
    *os << "is NOT: ";
    DescribeTo(os);
  }

 private:
  // Runs the pattern against `t`, passing the listener's stream to Match as
  // the destination for the mismatch explanation.
  template <typename T>
  bool MatchAndExplainImpl(const T* t,
                           ::testing::MatchResultListener* listener) const {
    MatchOption options{/*.capture=*/true, /*.explain_os=*/listener->stream()};
    return Match(t, pattern_, options);
  }

  Pattern pattern_;
};
} // namespace pattern_matcher_gmock_detail

// Wraps a pattern_matcher pattern in a gMock polymorphic matcher so that it
// can be passed to EXPECT_THAT and ASSERT_THAT.
template <typename Pattern>
::testing::PolymorphicMatcher<
    pattern_matcher_gmock_detail::GmockMatcher<Pattern>>
GmockMatch(Pattern&& p) {
  return ::testing::MakePolymorphicMatcher(
      pattern_matcher_gmock_detail::GmockMatcher<Pattern>(
          std::forward<Pattern>(p)));
}

} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_GMOCK_H_
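
The adapter in this header forwards three things from gMock to the wrapped pattern: the value to match, a stream for the mismatch explanation, and a description of what was expected. The sketch below mirrors that shape without depending on gMock or XLA, so it can be compiled on its own. Everything in it is a stand-in invented for illustration: `Listener` plays the role of `::testing::MatchResultListener`, `EqualsPattern` plays the role of a pattern_matcher pattern, and `Adapter` and `ExplainMismatch` are hypothetical names.

```cpp
#include <ostream>
#include <sstream>
#include <string>
#include <utility>

// Stand-in for ::testing::MatchResultListener: exposes a stream that may be
// null when the caller does not want an explanation.
class Listener {
 public:
  explicit Listener(std::ostream* os) : os_(os) {}
  std::ostream* stream() const { return os_; }

 private:
  std::ostream* os_;
};

// Stand-in pattern: matches an int equal to `target` and can explain itself.
class EqualsPattern {
 public:
  explicit EqualsPattern(int target) : target_(target) {}

  bool Match(const int* value, std::ostream* explain_os) const {
    if (value == nullptr) {
      if (explain_os != nullptr) *explain_os << "value is null";
      return false;
    }
    if (*value != target_) {
      if (explain_os != nullptr) {
        *explain_os << "value " << *value << " is not " << target_;
      }
      return false;
    }
    return true;
  }

  void DescribeTo(std::ostream* os) const {
    *os << "an int equal to " << target_;
  }

 private:
  int target_;
};

// Same shape as GmockMatcher: explicit reference and pointer overloads that
// forward to the pattern, plus the two description hooks.
template <typename Pattern>
class Adapter {
 public:
  explicit Adapter(Pattern p) : pattern_(std::move(p)) {}

  bool MatchAndExplain(const int& v, Listener* listener) const {
    return pattern_.Match(&v, listener->stream());
  }
  bool MatchAndExplain(const int* v, Listener* listener) const {
    return pattern_.Match(v, listener->stream());
  }

  void DescribeTo(std::ostream* os) const { pattern_.DescribeTo(os); }
  void DescribeNegationTo(std::ostream* os) const {
    *os << "is NOT: ";
    DescribeTo(os);
  }

 private:
  Pattern pattern_;
};

// Usage: returns the explanation produced when `value` is matched against
// EqualsPattern(target); the string is empty when the match succeeds.
std::string ExplainMismatch(int value, int target) {
  Adapter<EqualsPattern> matcher{EqualsPattern(target)};
  std::ostringstream why;
  Listener listener(&why);
  matcher.MatchAndExplain(value, &listener);
  return why.str();
}
```

In the real header the same forwarding is done once per supported type (`Layout`, `Shape`, `HloInstruction`), and the result is handed to `::testing::MakePolymorphicMatcher` rather than used directly.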