Use getElementTypeOrSelf to simplify a pattern
A new overload for getElementTypeOrSelf is added for Operation*. PiperOrigin-RevId: 260017027
This commit is contained in:
parent
de4e50504e
commit
25cd82af1e
@ -42,10 +42,6 @@ limitations under the License.
|
|||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace TF {
|
namespace TF {
|
||||||
|
|
||||||
namespace {
|
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_canonicalize.inc"
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// TF op helper functions
|
// TF op helper functions
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -75,10 +71,11 @@ static inline bool HasRankAtLeast(Value *value, int64_t rank) {
|
|||||||
return ranked_type.getRank() >= rank;
|
return ranked_type.getRank() >= rank;
|
||||||
return type.isa<UnrankedTensorType>();
|
return type.isa<UnrankedTensorType>();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns true if the given pair of TensorFlow types can be cast to one
|
// Returns true if the given pair of TensorFlow types can be cast to one
|
||||||
// another. In other words, a single run-time value is legal for both the types.
|
// another. In other words, a single run-time value is legal for both the types.
|
||||||
// For example, tensor<*xf32> and tensor<3xf32> are cast compatible.
|
// For example, tensor<*xf32> and tensor<3xf32> are cast compatible.
|
||||||
bool AreCastCompatible(Type a, Type b) {
|
static bool AreCastCompatible(Type a, Type b) {
|
||||||
if (TensorCastOp::areCastCompatible(a, b)) return true;
|
if (TensorCastOp::areCastCompatible(a, b)) return true;
|
||||||
|
|
||||||
// Variant types may optionally contain subtypes information that need not
|
// Variant types may optionally contain subtypes information that need not
|
||||||
@ -89,6 +86,20 @@ bool AreCastCompatible(Type a, Type b) {
|
|||||||
getElementTypeOrSelf(b).getKind() == TensorFlowTypes::VARIANT;
|
getElementTypeOrSelf(b).getKind() == TensorFlowTypes::VARIANT;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns either the element type or type of the result of a single result
|
||||||
|
// operation.
|
||||||
|
// TODO(antiagainst): We need an overload function, which mandates function
|
||||||
|
// name. This is temporary. Remove this post variadic operand support is
|
||||||
|
// improved.
|
||||||
|
static Type getElementTypeOrSelf(Operation *op) {
|
||||||
|
if (op->getNumResults() != 1) return {};
|
||||||
|
return getElementTypeOrSelf(op->getResult(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_canonicalize.inc"
|
||||||
|
} // namespace
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// AddOp
|
// AddOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -20,9 +20,7 @@ include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
|
|||||||
|
|
||||||
/// TODO(b/130756570): Support OpBase constraints in PatternRewrites.
|
/// TODO(b/130756570): Support OpBase constraints in PatternRewrites.
|
||||||
def SingleResultAndOperandHaveSameElementType : Constraint<
|
def SingleResultAndOperandHaveSameElementType : Constraint<
|
||||||
CPred<"$0->getResult(0)->getType().cast<ShapedType>()"
|
CPred<"getElementTypeOrSelf($0) == getElementTypeOrSelf($1)">>;
|
||||||
".getElementType() == "
|
|
||||||
"$1->getType().cast<ShapedType>().getElementType()">>;
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Add op patterns.
|
// Add op patterns.
|
||||||
|
Loading…
Reference in New Issue
Block a user