From 25cd82af1e9bf01ada3496e496cc0006f8f99c90 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Thu, 25 Jul 2019 14:06:50 -0700 Subject: [PATCH] Use getElementTypeOrSelf to simplify a pattern A new overload for getElementTypeOrSelf is added for Operation*. PiperOrigin-RevId: 260017027 --- .../compiler/mlir/tensorflow/ir/tf_ops.cc | 21 ++++++++++++++----- .../tensorflow/transforms/canonicalize.td | 4 +--- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc index 41e168b8827..c1824099c3e 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc @@ -42,10 +42,6 @@ limitations under the License. namespace mlir { namespace TF { -namespace { -#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_canonicalize.inc" -} // namespace - //===----------------------------------------------------------------------===// // TF op helper functions //===----------------------------------------------------------------------===// @@ -75,10 +71,11 @@ static inline bool HasRankAtLeast(Value *value, int64_t rank) { return ranked_type.getRank() >= rank; return type.isa(); } + // Returns true if the given pair of TensorFlow types can be cast to one // another. In other words, a single run-time value is legal for both the types. // For example, tensor<*xf32> and tensor<3xf32> are cast compatible. -bool AreCastCompatible(Type a, Type b) { +static bool AreCastCompatible(Type a, Type b) { if (TensorCastOp::areCastCompatible(a, b)) return true; // Variant types may optionally contain subtypes information that need not @@ -89,6 +86,20 @@ bool AreCastCompatible(Type a, Type b) { getElementTypeOrSelf(b).getKind() == TensorFlowTypes::VARIANT; } +// Returns either the element type or type of the result of a single result +// operation. +// TODO(antiagainst): We need an overload function, which mandates function +// name. This is temporary. Remove this post variadic operand support is +// improved. +static Type getElementTypeOrSelf(Operation *op) { + if (op->getNumResults() != 1) return {}; + return getElementTypeOrSelf(op->getResult(0)); +} + +namespace { +#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_canonicalize.inc" +} // namespace + //===----------------------------------------------------------------------===// // AddOp //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td b/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td index 473f69f87e7..0653c1d109e 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td @@ -20,9 +20,7 @@ include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" /// TODO(b/130756570): Support OpBase constraints in PatternRewrites. def SingleResultAndOperandHaveSameElementType : Constraint< - CPred<"$0->getResult(0)->getType().cast()" - ".getElementType() == " - "$1->getType().cast().getElementType()">>; + CPred<"getElementTypeOrSelf($0) == getElementTypeOrSelf($1)">>; //===----------------------------------------------------------------------===// // Add op patterns.