[MLIR][NFC] Adopt variadic isa<> on MLIR Types and Attributes

- Also adopt variadic llvm::isa<> in more places

PiperOrigin-RevId: 320206113
Change-Id: Ia03a1503f699fb6be6dff02e90b6630d6d894b19
Author: Rahul Joshi, 2020-07-08 09:46:21 -07:00 (committed by TensorFlower Gardener)
Parent: 4e1aa305a1
Commit: e2f8269f17
15 changed files with 24 additions and 35 deletions
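
For readers skimming the diff: the rewrite applied throughout is purely mechanical. MLIR's Type::isa<> (and Attribute::isa<>, and llvm::isa<> on operations) accept a variadic type list and succeed if the value matches any of the listed classes. A minimal sketch of the before/after shapes, assuming the MLIR headers of this era (the helper names are illustrative, not from the commit):

    #include "mlir/IR/StandardTypes.h"  // mlir::IntegerType, mlir::FloatType

    // Before: a chain of single-type checks joined with ||.
    static bool IsIntOrFloatOld(mlir::Type type) {
      return type.isa<mlir::IntegerType>() || type.isa<mlir::FloatType>();
    }

    // After: one variadic isa<>, true if `type` matches any listed class.
    static bool IsIntOrFloatNew(mlir::Type type) {
      return type.isa<mlir::IntegerType, mlir::FloatType>();
    }

The same any-of semantics hold for llvm::isa<Op1, Op2>(op) from llvm/Support/Casting.h, which is what the quantization and TF passes below adopt.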

@@ -443,8 +443,7 @@ StatusOr<Operation*> BuildConstOp(const tflite::TensorT& tensor,
   if (auto float_type = elem_type.dyn_cast<mlir::FloatType>()) {
     TF_ASSIGN_OR_RETURN(value,
                         ConvertFloatBuffer(shaped_type, float_type, buffer));
-  } else if (elem_type.isa<mlir::IntegerType>() ||
-             elem_type.isa<QuantizedType>()) {
+  } else if (elem_type.isa<mlir::IntegerType, QuantizedType>()) {
     TF_ASSIGN_OR_RETURN(value,
                         ConvertIntBuffer(shaped_type, elem_type, buffer));
   } else if (elem_type.isa<mlir::TF::StringType>()) {
@@ -456,8 +455,7 @@ StatusOr<Operation*> BuildConstOp(const tflite::TensorT& tensor,
       refs.push_back({ref.data(), ref.size()});
     value = mlir::DenseStringElementsAttr::get(shaped_type, refs);
-  } else if (elem_type.isa<mlir::ComplexType>() ||
-             elem_type.isa<mlir::TF::TensorFlowType>()) {
+  } else if (elem_type.isa<mlir::ComplexType, mlir::TF::TensorFlowType>()) {
     auto dialect = elem_type.getContext()->getRegisteredDialect("tf");
     tensorflow::TensorProto repr = ConvertTfliteConstTensor(tensor, buffer);
     std::string mangled = tensorflow::mangling_util::MangleTensor(repr);

@@ -694,8 +694,7 @@ void QuantizationDriver::SetupAllStates() {
   fn_.walk([&](Operation *op) {
     if (op->isKnownTerminator() ||
         op->hasTrait<OpTrait::quant::NoQuantizableResult>() ||
-        llvm::isa<quant::DequantizeCastOp>(op) ||
-        llvm::isa<quant::QuantizeCastOp>(op))
+        llvm::isa<quant::DequantizeCastOp, quant::QuantizeCastOp>(op))
       return;
     work_list_.push_back(op);

@@ -386,8 +386,7 @@ struct FoldTrivalRequantizeOp : public OpRewritePattern<RQ> {
     Operation* def = pre_quantized.getDefiningOp();
     if (!def) return failure();
-    if (llvm::isa<FixedOutputRangeInterface>(def) ||
-        llvm::isa<SameScalesOpInterface>(def) ||
+    if (llvm::isa<FixedOutputRangeInterface, SameScalesOpInterface>(def) ||
         def->hasTrait<OpTrait::quant::NoQuantizableResult>()) {
       return failure();
     }

@@ -560,7 +560,7 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
       return failure();
     ShapedType filter_type = filter_cst.getType();
-    if (llvm::isa<AddOp>(binary_op) || llvm::isa<SubOp>(binary_op)) {
+    if (llvm::isa<AddOp, SubOp>(binary_op)) {
       auto padding = fc_op.template getAttrOfType<StringAttr>("padding");
       if (padding && padding.getValue() != "VALID") return failure();
@@ -606,7 +606,7 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
         rewriter.create<ConstOp>(fc_op.getLoc(), new_bias_type, new_bias);
     fc_op.setOperand(0, binary_op->getOperand(0));
     fc_op.setOperand(2, new_bias_op);
-  } else if (llvm::isa<MulOp>(binary_op) || llvm::isa<DivOp>(binary_op)) {
+  } else if (llvm::isa<MulOp, DivOp>(binary_op)) {
     // The fusion of mul/div is actually applying the following
     // transformation:
     // w * (x ' c) + b => (w ' c) x + b

@@ -168,8 +168,7 @@ void ResourceAliasAnalysis::AnalyzeFunction(FuncOp func_op) {
           var_handle.resource(),
           GetOrCreateIdForVarHandle(var_handle, &next_unique_id,
                                     &var_handle_name_id_map));
-    } else if (llvm::isa<TF::IdentityNOp>(op) ||
-               llvm::isa<TF::IdentityOp>(op)) {
+    } else if (llvm::isa<TF::IdentityNOp, TF::IdentityOp>(op)) {
       for (auto operand_and_result :
            llvm::zip(op->getOperands(), op->getResults())) {
         forward_input_to_output(std::get<0>(operand_and_result),
@@ -333,7 +332,7 @@ bool OpIsDeclaration(Operation* op,
                      const ResourceAliasAnalysis& alias_analysis) {
   // TODO(yuanzx): Add other types of resources.
   return llvm::isa<TF::VarHandleOp>(op) ||
-         ((llvm::isa<TF::IdentityNOp>(op) || llvm::isa<TF::IdentityOp>(op)) &&
+         (llvm::isa<TF::IdentityNOp, TF::IdentityOp>(op) &&
          !FindAccessedResources(op, alias_analysis).empty());
 }

@@ -71,7 +71,7 @@ struct TensorFlowExecutorInlinerInterface : public DialectInlinerInterface {
     // Allow inlining into tf.island regions if the incoming region has a single
     // block.
     return llvm::isa<tf_executor::IslandOp>(dest->getParentOp()) &&
-           std::next(src->begin()) == src->end();
+           llvm::hasSingleElement(*src);
   }
 };
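
The hunk above is the one cleanup in this change that is not an isa<> fold: llvm::hasSingleElement (from llvm/ADT/STLExtras.h) replaces the iterator-arithmetic spelling of "this range has exactly one element". A small sketch of the equivalence on an arbitrary range (the vector is illustrative only, not from the commit):

    #include "llvm/ADT/STLExtras.h"  // llvm::hasSingleElement
    #include <vector>

    bool HasExactlyOne(const std::vector<int>& v) {
      // Equivalent to: v.begin() != v.end() && std::next(v.begin()) == v.end()
      return llvm::hasSingleElement(v);
    }

Here it asks whether the incoming region's block list has exactly one block, which reads more directly than comparing std::next(src->begin()) against src->end().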

@@ -1168,8 +1168,7 @@ void ConstOp::build(OpBuilder &builder, OperationState &result,
   ShapedType type;
   if (auto elem_attr = value.dyn_cast<ElementsAttr>()) {
     return ConstOp::build(builder, result, elem_attr);
-  } else if (value.isa<BoolAttr>() || value.isa<FloatAttr>() ||
-             value.isa<IntegerAttr>()) {
+  } else if (value.isa<BoolAttr, FloatAttr, IntegerAttr>()) {
     // All TensorFlow types must be tensor types. In the build() method,
     // we want to provide more flexibility by allowing attributes of scalar
     // types. But we need to wrap it up with ElementsAttr to construct

@@ -356,7 +356,7 @@ LogicalResult VerifyExportedFunc(FuncOp func) {
 LogicalResult TensorFlowSavedModelDialect::verifyOperationAttribute(
     Operation *op, NamedAttribute named_attr) {
   if (named_attr.first == "tf_saved_model.exported_names") {
-    if (!isa<FuncOp>(op) && !isa<GlobalTensorOp>(op)) {
+    if (!isa<FuncOp, GlobalTensorOp>(op)) {
       return op->emitError() << "'tf_saved_model.exported_names' must be on a "
                                 "'func' or 'tf_saved_model.global_tensor' op";
     }
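
Note how the negated form folds: by De Morgan's law, !isa<FuncOp>(op) && !isa<GlobalTensorOp>(op) is !(isa<FuncOp>(op) || isa<GlobalTensorOp>(op)), which is exactly the variadic !isa<FuncOp, GlobalTensorOp>(op), since the variadic form succeeds when any listed class matches. The TPUReplicatedInput/Output and exporter hunks further down rely on the same equivalence.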

@@ -90,8 +90,7 @@ class TensorFlowType : public Type {
 // Returns true if the specified type is a valid TensorFlow element type.
 static inline bool IsValidTFElementType(Type type) {
-  return type.isa<ComplexType>() || type.isa<FloatType>() ||
-         type.isa<IntegerType>() || type.isa<TensorFlowType>();
+  return type.isa<ComplexType, FloatType, IntegerType, TensorFlowType>();
 }
 
 // Returns true if this is a valid TensorFlow tensor type.

@@ -375,7 +375,7 @@ LogicalResult FindResourceArgUseInfo(
       info.data_type = assign.value().getType();
       continue;
     }
-    if (isa<TF::StackPushV2Op>(user) || isa<TF::StackPopV2Op>(user)) {
+    if (isa<TF::StackPushV2Op, TF::StackPopV2Op>(user)) {
       // Stacks will be handled by a separate pass.
       do_not_touch = true;
       break;

@@ -205,9 +205,9 @@ GetSubtypes(Type type) {
 // Returns whether type can be further refined.
 bool CanBeRefined(Type type) {
   auto shape_type = type.dyn_cast<ShapedType>();
-  return shape_type && (!shape_type.hasStaticShape() ||
-                        shape_type.getElementType().isa<TF::ResourceType>() ||
-                        shape_type.getElementType().isa<TF::VariantType>());
+  return shape_type &&
+         (!shape_type.hasStaticShape() ||
+          shape_type.getElementType().isa<TF::ResourceType, TF::VariantType>());
 }
 
 // Infers the shape from a (Stateful)PartionedCall operation by looking up the
@@ -712,8 +712,7 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op) {
   // The shape function of these ops sometimes does not propagate subtypes
   // (handle shapes) for resource and variant types. We use a simple passthrough
   // to make sure they are preserved in the output.
-  if (isa<TF::IdentityOp>(op) || isa<TF::IdentityNOp>(op) ||
-      isa<TF::ZerosLikeOp>(op) || isa<TF::WhileOp>(op)) {
+  if (isa<TF::IdentityOp, TF::IdentityNOp, TF::ZerosLikeOp, TF::WhileOp>(op)) {
     return RefineTypeForPassThroughOperands(op, op->getOperands(),
                                             op->getResults());
   }
@@ -729,7 +728,7 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op) {
   // Handle call operations by looking up callee and infering return shape as
   // needed.
-  if (isa<PartitionedCallOp>(op) || isa<StatefulPartitionedCallOp>(op))
+  if (isa<PartitionedCallOp, StatefulPartitionedCallOp>(op))
     return InferShapeForCall(op);
 
   // tf.Cast are only inferred if they have at least one user in the TF dialect
@@ -889,8 +888,7 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op) {
   };
   auto new_element_type = shaped_type.getElementType();
   // Populate the handle shapes for a resource/variant.
-  if (new_element_type.isa<TF::ResourceType>() ||
-      new_element_type.isa<TF::VariantType>()) {
+  if (new_element_type.isa<TF::ResourceType, TF::VariantType>()) {
     auto handle_shapes_types = c.output_handle_shapes_and_types(output);
     if (handle_shapes_types) {
       SmallVector<TensorType, 1> subtypes;

@@ -488,7 +488,7 @@ LogicalResult DecomposeStackOpsInternal(
     llvm::StringMap<PartitionedCallStackOpsInfo>*
         decomposed_partitioned_call_callees) {
   for (auto& op : llvm::make_early_inc_range(block->getOperations())) {
-    if (llvm::isa<TF::IdentityOp>(&op) || llvm::isa<TF::IdentityNOp>(&op)) {
+    if (llvm::isa<TF::IdentityOp, TF::IdentityNOp>(&op)) {
       // Removes identity nodes in the block. The device computation does not
       // need such nodes to carry information.
       op.replaceAllUsesWith(op.getOperands());

@@ -809,7 +809,7 @@ LogicalResult DecomposeTensorArrayOps(
     llvm::StringMap<PartitionedCallTensorArrayOpsInfo>*
         decomposed_partitioned_call_callees) {
   for (auto& op : llvm::make_early_inc_range(block->getOperations())) {
-    if (llvm::isa<TF::IdentityOp>(&op) || llvm::isa<TF::IdentityNOp>(&op)) {
+    if (llvm::isa<TF::IdentityOp, TF::IdentityNOp>(&op)) {
       op.replaceAllUsesWith(op.getOperands());
       op.erase();
     } else if (auto ta = llvm::dyn_cast<TF::TensorArrayV3Op>(&op)) {

@@ -495,8 +495,7 @@ void TPUClusterFormation::runOnFunction() {
   // Remove TPUReplicatedInput and TPUReplicatedOutput nodes.
   auto remove_result = getFunction().walk([&](Operation* op) {
-    if (!llvm::isa<TF::TPUReplicatedInputOp>(op) &&
-        !llvm::isa<TF::TPUReplicatedOutputOp>(op))
+    if (!llvm::isa<TF::TPUReplicatedInputOp, TF::TPUReplicatedOutputOp>(op))
       return WalkResult::advance();
 
     // Forward operand to result. When `num_replicas` attribute is 1, no

@@ -576,9 +576,8 @@ StatusOr<std::unique_ptr<Graph>> Exporter::Convert(
   // Adds nodes for operations.
   for (Operation& inst : graph_op.GetBody()) {
     for (auto type : inst.getResultTypes())
-      if (!type.isa<mlir::TensorType>() &&
-          !type.isa<mlir::tf_executor::ControlType>() &&
-          !type.isa<mlir::tf_executor::TokenType>())
+      if (!type.isa<mlir::TensorType, mlir::tf_executor::ControlType,
+                    mlir::tf_executor::TokenType>())
         return errors::InvalidArgument(
             "Values must be of tensor type, TensorFlow control type, or "
             "TensorFlow token type. Found ",