Use getElementTypeOrSelf to simplify a pattern
A new overload for getElementTypeOrSelf is added for Operation*. PiperOrigin-RevId: 260017027
This commit is contained in:
parent
de4e50504e
commit
25cd82af1e
@ -42,10 +42,6 @@ limitations under the License.
|
||||
namespace mlir {
|
||||
namespace TF {
|
||||
|
||||
namespace {
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_canonicalize.inc"
|
||||
} // namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TF op helper functions
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -75,10 +71,11 @@ static inline bool HasRankAtLeast(Value *value, int64_t rank) {
|
||||
return ranked_type.getRank() >= rank;
|
||||
return type.isa<UnrankedTensorType>();
|
||||
}
|
||||
|
||||
// Returns true if the given pair of TensorFlow types can be cast to one
|
||||
// another. In other words, a single run-time value is legal for both the types.
|
||||
// For example, tensor<*xf32> and tensor<3xf32> are cast compatible.
|
||||
bool AreCastCompatible(Type a, Type b) {
|
||||
static bool AreCastCompatible(Type a, Type b) {
|
||||
if (TensorCastOp::areCastCompatible(a, b)) return true;
|
||||
|
||||
// Variant types may optionally contain subtypes information that need not
|
||||
@ -89,6 +86,20 @@ bool AreCastCompatible(Type a, Type b) {
|
||||
getElementTypeOrSelf(b).getKind() == TensorFlowTypes::VARIANT;
|
||||
}
|
||||
|
||||
// Returns either the element type or type of the result of a single result
|
||||
// operation.
|
||||
// TODO(antiagainst): We need an overload function, which mandates function
|
||||
// name. This is temporary. Remove this post variadic operand support is
|
||||
// improved.
|
||||
static Type getElementTypeOrSelf(Operation *op) {
|
||||
if (op->getNumResults() != 1) return {};
|
||||
return getElementTypeOrSelf(op->getResult(0));
|
||||
}
|
||||
|
||||
namespace {
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_canonicalize.inc"
|
||||
} // namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AddOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -20,9 +20,7 @@ include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
|
||||
|
||||
/// TODO(b/130756570): Support OpBase constraints in PatternRewrites.
|
||||
def SingleResultAndOperandHaveSameElementType : Constraint<
|
||||
CPred<"$0->getResult(0)->getType().cast<ShapedType>()"
|
||||
".getElementType() == "
|
||||
"$1->getType().cast<ShapedType>().getElementType()">>;
|
||||
CPred<"getElementTypeOrSelf($0) == getElementTypeOrSelf($1)">>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Add op patterns.
|
||||
|
Loading…
Reference in New Issue
Block a user