Use getElementTypeOrSelf to simplify a pattern

A new overload for getElementTypeOrSelf is added for Operation*.

PiperOrigin-RevId: 260017027
This commit is contained in:
Lei Zhang 2019-07-25 14:06:50 -07:00 committed by TensorFlower Gardener
parent de4e50504e
commit 25cd82af1e
2 changed files with 17 additions and 8 deletions

View File

@ -42,10 +42,6 @@ limitations under the License.
namespace mlir {
namespace TF {
namespace {
#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_canonicalize.inc"
} // namespace
//===----------------------------------------------------------------------===//
// TF op helper functions
//===----------------------------------------------------------------------===//
@ -75,10 +71,11 @@ static inline bool HasRankAtLeast(Value *value, int64_t rank) {
return ranked_type.getRank() >= rank;
return type.isa<UnrankedTensorType>();
}
// Returns true if the given pair of TensorFlow types can be cast to one
// another. In other words, a single run-time value is legal for both the types.
// For example, tensor<*xf32> and tensor<3xf32> are cast compatible.
bool AreCastCompatible(Type a, Type b) {
static bool AreCastCompatible(Type a, Type b) {
if (TensorCastOp::areCastCompatible(a, b)) return true;
// Variant types may optionally contain subtypes information that need not
@ -89,6 +86,20 @@ bool AreCastCompatible(Type a, Type b) {
getElementTypeOrSelf(b).getKind() == TensorFlowTypes::VARIANT;
}
// Returns either the element type or type of the result of a single result
// operation.
// TODO(antiagainst): We need an overload function, which mandates function
// name. This is temporary. Remove this post variadic operand support is
// improved.
static Type getElementTypeOrSelf(Operation *op) {
if (op->getNumResults() != 1) return {};
return getElementTypeOrSelf(op->getResult(0));
}
namespace {
#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_canonicalize.inc"
} // namespace
//===----------------------------------------------------------------------===//
// AddOp
//===----------------------------------------------------------------------===//

View File

@ -20,9 +20,7 @@ include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
/// TODO(b/130756570): Support OpBase constraints in PatternRewrites.
def SingleResultAndOperandHaveSameElementType : Constraint<
CPred<"$0->getResult(0)->getType().cast<ShapedType>()"
".getElementType() == "
"$1->getType().cast<ShapedType>().getElementType()">>;
CPred<"getElementTypeOrSelf($0) == getElementTypeOrSelf($1)">>;
//===----------------------------------------------------------------------===//
// Add op patterns.