[MLIR][NFC] Adopt variadic isa<> on MLIR Types and Attributes
- Also adopt variadic llvm::isa<> in more places PiperOrigin-RevId: 320206113 Change-Id: Ia03a1503f699fb6be6dff02e90b6630d6d894b19
This commit is contained in:
parent
4e1aa305a1
commit
e2f8269f17
@ -443,8 +443,7 @@ StatusOr<Operation*> BuildConstOp(const tflite::TensorT& tensor,
|
||||
if (auto float_type = elem_type.dyn_cast<mlir::FloatType>()) {
|
||||
TF_ASSIGN_OR_RETURN(value,
|
||||
ConvertFloatBuffer(shaped_type, float_type, buffer));
|
||||
} else if (elem_type.isa<mlir::IntegerType>() ||
|
||||
elem_type.isa<QuantizedType>()) {
|
||||
} else if (elem_type.isa<mlir::IntegerType, QuantizedType>()) {
|
||||
TF_ASSIGN_OR_RETURN(value,
|
||||
ConvertIntBuffer(shaped_type, elem_type, buffer));
|
||||
} else if (elem_type.isa<mlir::TF::StringType>()) {
|
||||
@ -456,8 +455,7 @@ StatusOr<Operation*> BuildConstOp(const tflite::TensorT& tensor,
|
||||
refs.push_back({ref.data(), ref.size()});
|
||||
|
||||
value = mlir::DenseStringElementsAttr::get(shaped_type, refs);
|
||||
} else if (elem_type.isa<mlir::ComplexType>() ||
|
||||
elem_type.isa<mlir::TF::TensorFlowType>()) {
|
||||
} else if (elem_type.isa<mlir::ComplexType, mlir::TF::TensorFlowType>()) {
|
||||
auto dialect = elem_type.getContext()->getRegisteredDialect("tf");
|
||||
tensorflow::TensorProto repr = ConvertTfliteConstTensor(tensor, buffer);
|
||||
std::string mangled = tensorflow::mangling_util::MangleTensor(repr);
|
||||
|
@ -694,8 +694,7 @@ void QuantizationDriver::SetupAllStates() {
|
||||
fn_.walk([&](Operation *op) {
|
||||
if (op->isKnownTerminator() ||
|
||||
op->hasTrait<OpTrait::quant::NoQuantizableResult>() ||
|
||||
llvm::isa<quant::DequantizeCastOp>(op) ||
|
||||
llvm::isa<quant::QuantizeCastOp>(op))
|
||||
llvm::isa<quant::DequantizeCastOp, quant::QuantizeCastOp>(op))
|
||||
return;
|
||||
work_list_.push_back(op);
|
||||
|
||||
|
@ -386,8 +386,7 @@ struct FoldTrivalRequantizeOp : public OpRewritePattern<RQ> {
|
||||
|
||||
Operation* def = pre_quantized.getDefiningOp();
|
||||
if (!def) return failure();
|
||||
if (llvm::isa<FixedOutputRangeInterface>(def) ||
|
||||
llvm::isa<SameScalesOpInterface>(def) ||
|
||||
if (llvm::isa<FixedOutputRangeInterface, SameScalesOpInterface>(def) ||
|
||||
def->hasTrait<OpTrait::quant::NoQuantizableResult>()) {
|
||||
return failure();
|
||||
}
|
||||
|
@ -560,7 +560,7 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
|
||||
return failure();
|
||||
ShapedType filter_type = filter_cst.getType();
|
||||
|
||||
if (llvm::isa<AddOp>(binary_op) || llvm::isa<SubOp>(binary_op)) {
|
||||
if (llvm::isa<AddOp, SubOp>(binary_op)) {
|
||||
auto padding = fc_op.template getAttrOfType<StringAttr>("padding");
|
||||
if (padding && padding.getValue() != "VALID") return failure();
|
||||
|
||||
@ -606,7 +606,7 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
|
||||
rewriter.create<ConstOp>(fc_op.getLoc(), new_bias_type, new_bias);
|
||||
fc_op.setOperand(0, binary_op->getOperand(0));
|
||||
fc_op.setOperand(2, new_bias_op);
|
||||
} else if (llvm::isa<MulOp>(binary_op) || llvm::isa<DivOp>(binary_op)) {
|
||||
} else if (llvm::isa<MulOp, DivOp>(binary_op)) {
|
||||
// The fusion of mul/div is actually applying the following
|
||||
// transformation:
|
||||
// w * (x ' c) + b => (w ' c) x + b
|
||||
|
@ -168,8 +168,7 @@ void ResourceAliasAnalysis::AnalyzeFunction(FuncOp func_op) {
|
||||
var_handle.resource(),
|
||||
GetOrCreateIdForVarHandle(var_handle, &next_unique_id,
|
||||
&var_handle_name_id_map));
|
||||
} else if (llvm::isa<TF::IdentityNOp>(op) ||
|
||||
llvm::isa<TF::IdentityOp>(op)) {
|
||||
} else if (llvm::isa<TF::IdentityNOp, TF::IdentityOp>(op)) {
|
||||
for (auto operand_and_result :
|
||||
llvm::zip(op->getOperands(), op->getResults())) {
|
||||
forward_input_to_output(std::get<0>(operand_and_result),
|
||||
@ -333,7 +332,7 @@ bool OpIsDeclaration(Operation* op,
|
||||
const ResourceAliasAnalysis& alias_analysis) {
|
||||
// TODO(yuanzx): Add other types of resources.
|
||||
return llvm::isa<TF::VarHandleOp>(op) ||
|
||||
((llvm::isa<TF::IdentityNOp>(op) || llvm::isa<TF::IdentityOp>(op)) &&
|
||||
(llvm::isa<TF::IdentityNOp, TF::IdentityOp>(op) &&
|
||||
!FindAccessedResources(op, alias_analysis).empty());
|
||||
}
|
||||
|
||||
|
@ -71,7 +71,7 @@ struct TensorFlowExecutorInlinerInterface : public DialectInlinerInterface {
|
||||
// Allow inlining into tf.island regions if the incoming region has a single
|
||||
// block.
|
||||
return llvm::isa<tf_executor::IslandOp>(dest->getParentOp()) &&
|
||||
std::next(src->begin()) == src->end();
|
||||
llvm::hasSingleElement(*src);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -1168,8 +1168,7 @@ void ConstOp::build(OpBuilder &builder, OperationState &result,
|
||||
ShapedType type;
|
||||
if (auto elem_attr = value.dyn_cast<ElementsAttr>()) {
|
||||
return ConstOp::build(builder, result, elem_attr);
|
||||
} else if (value.isa<BoolAttr>() || value.isa<FloatAttr>() ||
|
||||
value.isa<IntegerAttr>()) {
|
||||
} else if (value.isa<BoolAttr, FloatAttr, IntegerAttr>()) {
|
||||
// All TensorFlow types must be tensor types. In the build() method,
|
||||
// we want to provide more flexibility by allowing attributes of scalar
|
||||
// types. But we need to wrap it up with ElementsAttr to construct
|
||||
|
@ -356,7 +356,7 @@ LogicalResult VerifyExportedFunc(FuncOp func) {
|
||||
LogicalResult TensorFlowSavedModelDialect::verifyOperationAttribute(
|
||||
Operation *op, NamedAttribute named_attr) {
|
||||
if (named_attr.first == "tf_saved_model.exported_names") {
|
||||
if (!isa<FuncOp>(op) && !isa<GlobalTensorOp>(op)) {
|
||||
if (!isa<FuncOp, GlobalTensorOp>(op)) {
|
||||
return op->emitError() << "'tf_saved_model.exported_names' must be on a "
|
||||
"'func' or 'tf_saved_model.global_tensor' op";
|
||||
}
|
||||
|
@ -90,8 +90,7 @@ class TensorFlowType : public Type {
|
||||
|
||||
// Returns true if the specified type is a valid TensorFlow element type.
|
||||
static inline bool IsValidTFElementType(Type type) {
|
||||
return type.isa<ComplexType>() || type.isa<FloatType>() ||
|
||||
type.isa<IntegerType>() || type.isa<TensorFlowType>();
|
||||
return type.isa<ComplexType, FloatType, IntegerType, TensorFlowType>();
|
||||
}
|
||||
|
||||
// Returns true if this is a valid TensorFlow tensor type.
|
||||
|
@ -375,7 +375,7 @@ LogicalResult FindResourceArgUseInfo(
|
||||
info.data_type = assign.value().getType();
|
||||
continue;
|
||||
}
|
||||
if (isa<TF::StackPushV2Op>(user) || isa<TF::StackPopV2Op>(user)) {
|
||||
if (isa<TF::StackPushV2Op, TF::StackPopV2Op>(user)) {
|
||||
// Stacks will be handled by a separate pass.
|
||||
do_not_touch = true;
|
||||
break;
|
||||
|
@ -205,9 +205,9 @@ GetSubtypes(Type type) {
|
||||
// Returns whether type can be further refined.
|
||||
bool CanBeRefined(Type type) {
|
||||
auto shape_type = type.dyn_cast<ShapedType>();
|
||||
return shape_type && (!shape_type.hasStaticShape() ||
|
||||
shape_type.getElementType().isa<TF::ResourceType>() ||
|
||||
shape_type.getElementType().isa<TF::VariantType>());
|
||||
return shape_type &&
|
||||
(!shape_type.hasStaticShape() ||
|
||||
shape_type.getElementType().isa<TF::ResourceType, TF::VariantType>());
|
||||
}
|
||||
|
||||
// Infers the shape from a (Stateful)PartionedCall operation by looking up the
|
||||
@ -712,8 +712,7 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op) {
|
||||
// The shape function of these ops sometimes does not propagate subtypes
|
||||
// (handle shapes) for resource and variant types. We use a simple passthrough
|
||||
// to make sure they are preserved in the output.
|
||||
if (isa<TF::IdentityOp>(op) || isa<TF::IdentityNOp>(op) ||
|
||||
isa<TF::ZerosLikeOp>(op) || isa<TF::WhileOp>(op)) {
|
||||
if (isa<TF::IdentityOp, TF::IdentityNOp, TF::ZerosLikeOp, TF::WhileOp>(op)) {
|
||||
return RefineTypeForPassThroughOperands(op, op->getOperands(),
|
||||
op->getResults());
|
||||
}
|
||||
@ -729,7 +728,7 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op) {
|
||||
|
||||
// Handle call operations by looking up callee and infering return shape as
|
||||
// needed.
|
||||
if (isa<PartitionedCallOp>(op) || isa<StatefulPartitionedCallOp>(op))
|
||||
if (isa<PartitionedCallOp, StatefulPartitionedCallOp>(op))
|
||||
return InferShapeForCall(op);
|
||||
|
||||
// tf.Cast are only inferred if they have at least one user in the TF dialect
|
||||
@ -889,8 +888,7 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op) {
|
||||
};
|
||||
auto new_element_type = shaped_type.getElementType();
|
||||
// Populate the handle shapes for a resource/variant.
|
||||
if (new_element_type.isa<TF::ResourceType>() ||
|
||||
new_element_type.isa<TF::VariantType>()) {
|
||||
if (new_element_type.isa<TF::ResourceType, TF::VariantType>()) {
|
||||
auto handle_shapes_types = c.output_handle_shapes_and_types(output);
|
||||
if (handle_shapes_types) {
|
||||
SmallVector<TensorType, 1> subtypes;
|
||||
|
@ -488,7 +488,7 @@ LogicalResult DecomposeStackOpsInternal(
|
||||
llvm::StringMap<PartitionedCallStackOpsInfo>*
|
||||
decomposed_partitioned_call_callees) {
|
||||
for (auto& op : llvm::make_early_inc_range(block->getOperations())) {
|
||||
if (llvm::isa<TF::IdentityOp>(&op) || llvm::isa<TF::IdentityNOp>(&op)) {
|
||||
if (llvm::isa<TF::IdentityOp, TF::IdentityNOp>(&op)) {
|
||||
// Removes identity nodes in the block. The device computation does not
|
||||
// need such nodes to carry information.
|
||||
op.replaceAllUsesWith(op.getOperands());
|
||||
|
@ -809,7 +809,7 @@ LogicalResult DecomposeTensorArrayOps(
|
||||
llvm::StringMap<PartitionedCallTensorArrayOpsInfo>*
|
||||
decomposed_partitioned_call_callees) {
|
||||
for (auto& op : llvm::make_early_inc_range(block->getOperations())) {
|
||||
if (llvm::isa<TF::IdentityOp>(&op) || llvm::isa<TF::IdentityNOp>(&op)) {
|
||||
if (llvm::isa<TF::IdentityOp, TF::IdentityNOp>(&op)) {
|
||||
op.replaceAllUsesWith(op.getOperands());
|
||||
op.erase();
|
||||
} else if (auto ta = llvm::dyn_cast<TF::TensorArrayV3Op>(&op)) {
|
||||
|
@ -495,8 +495,7 @@ void TPUClusterFormation::runOnFunction() {
|
||||
|
||||
// Remove TPUReplicatedInput and TPUReplicatedOutput nodes.
|
||||
auto remove_result = getFunction().walk([&](Operation* op) {
|
||||
if (!llvm::isa<TF::TPUReplicatedInputOp>(op) &&
|
||||
!llvm::isa<TF::TPUReplicatedOutputOp>(op))
|
||||
if (!llvm::isa<TF::TPUReplicatedInputOp, TF::TPUReplicatedOutputOp>(op))
|
||||
return WalkResult::advance();
|
||||
|
||||
// Forward operand to result. When `num_replicas` attribute is 1, no
|
||||
|
@ -576,9 +576,8 @@ StatusOr<std::unique_ptr<Graph>> Exporter::Convert(
|
||||
// Adds nodes for operations.
|
||||
for (Operation& inst : graph_op.GetBody()) {
|
||||
for (auto type : inst.getResultTypes())
|
||||
if (!type.isa<mlir::TensorType>() &&
|
||||
!type.isa<mlir::tf_executor::ControlType>() &&
|
||||
!type.isa<mlir::tf_executor::TokenType>())
|
||||
if (!type.isa<mlir::TensorType, mlir::tf_executor::ControlType,
|
||||
mlir::tf_executor::TokenType>())
|
||||
return errors::InvalidArgument(
|
||||
"Values must be of tensor type, TensorFlow control type, or "
|
||||
"TensorFlow token type. Found ",
|
||||
|
Loading…
x
Reference in New Issue
Block a user