[MLIR][NFC] Adopt variadic isa<>

PiperOrigin-RevId: 318279074
Change-Id: I9845b0278737a4d91b0e1e6699ae008d78e76556
This commit is contained in:
Rahul Joshi 2020-06-25 08:45:22 -07:00 committed by TensorFlower Gardener
parent 884b55fe86
commit 9a99c02411
24 changed files with 40 additions and 53 deletions

View File

@ -190,9 +190,8 @@ static StatusOr<tflite::TensorType> GetTFLiteType(Type type,
}
// Returns true if `op` produces a constant: a std-dialect ConstantOp, a
// TF ConstOp, or any of the TFLite constant ops (dense, quantized, sparse,
// or sparse-quantized).
static bool IsConst(Operation* op) {
  // Variadic isa<> checks all listed op classes in a single call.
  return isa<mlir::ConstantOp, mlir::TF::ConstOp, tfl::ConstOp, tfl::QConstOp,
             tfl::SparseConstOp, tfl::SparseQConstOp>(op);
}
template <typename T>

View File

@ -289,8 +289,8 @@ class QuantizationDriver {
llvm::errs() << "\n\n\n" << current_op->getName() << "\n";
}
fn_.walk([&](Operation *op) {
if (llvm::isa<quant::QuantizeCastOp>(op) ||
llvm::isa<quant::DequantizeCastOp>(op) || llvm::isa<ConstantOp>(op))
if (llvm::isa<quant::QuantizeCastOp, quant::DequantizeCastOp, ConstantOp>(
op))
return;
if (current_op == op) llvm::errs() << "===>>>";
llvm::errs() << op->getName() << " : (";

View File

@ -172,7 +172,7 @@ struct QuantizationPattern : public RewritePattern {
Value quantized_value = op->getResult(0);
for (Operation* quantized_op : quantized_value.getUsers()) {
// If it is requantize op, we shouldn't rewrite this op.
if (llvm::isa<Q>(quantized_op) || llvm::isa<DQ>(quantized_op)) {
if (llvm::isa<Q, DQ>(quantized_op)) {
return failure();
}
@ -180,8 +180,8 @@ struct QuantizationPattern : public RewritePattern {
// ops dialect, we shouldn't rewrite.
if (quantized_op->isKnownTerminator() ||
quantized_op->hasTrait<OpTrait::quant::NoQuantizableResult>() ||
llvm::isa<quant::QuantizeCastOp>(quantized_op) ||
llvm::isa<quant::DequantizeCastOp>(quantized_op)) {
llvm::isa<quant::QuantizeCastOp, quant::DequantizeCastOp>(
quantized_op)) {
return failure();
}

View File

@ -49,12 +49,10 @@ using mlir::OwningModuleRef;
using stream_executor::port::StatusOr;
// Returns true if `op` is a TensorFlow executor-dialect control-flow v1 op:
// Switch, Merge, Enter, Exit, or either half of a NextIteration pair.
bool IsControlFlowV1Op(Operation* op) {
  // Variadic isa<> checks all listed op classes in a single call.
  return mlir::isa<mlir::tf_executor::SwitchOp, mlir::tf_executor::MergeOp,
                   mlir::tf_executor::EnterOp, mlir::tf_executor::ExitOp,
                   mlir::tf_executor::NextIterationSinkOp,
                   mlir::tf_executor::NextIterationSourceOp>(op);
}
mlir::LogicalResult IsValidGraph(mlir::ModuleOp module) {

View File

@ -110,8 +110,7 @@ void DefaultQuantParamsPass::runOnFunction() {
func.walk([&](Operation *op) {
if (op->isKnownTerminator() ||
op->hasTrait<OpTrait::quant::NoQuantizableResult>() ||
llvm::isa<quant::QuantizeCastOp>(op) ||
llvm::isa<quant::DequantizeCastOp>(op))
llvm::isa<quant::QuantizeCastOp, quant::DequantizeCastOp>(op))
return;
for (auto res : op->getResults()) {

View File

@ -100,8 +100,7 @@ int64_t FindPassthroughArgumentForReturnValue(int64_t return_index,
value = graph.GetFetch().getOperand(res_num);
} else if (auto island = llvm::dyn_cast<tf_executor::IslandOp>(op)) {
value = island.GetYield().getOperand(res_num);
} else if (llvm::isa<TF::IdentityNOp>(op) ||
llvm::isa<TF::IdentityOp>(op)) {
} else if (llvm::isa<TF::IdentityNOp, TF::IdentityOp>(op)) {
value = op->getOperand(res_num);
} else {
return -1;

View File

@ -48,7 +48,7 @@ struct AnnotateParameterReplication
// tf.IdentityOp or a tf.ReadVariableOp.
Value SkipIdentityAndReadVariable(Value v) {
while (auto op = v.getDefiningOp()) {
if (!(isa<TF::IdentityOp>(op) || isa<TF::ReadVariableOp>(op))) break;
if (!isa<TF::IdentityOp, TF::ReadVariableOp>(op)) break;
v = op->getOperand(0);
}
return v;

View File

@ -219,8 +219,7 @@ llvm::Optional<RankedTensorType> GetElementTypeFromAccess(
auto type_from_callee = GetElementTypeFromAccess(
callee.getArgument(use.getOperandNumber()), module, infer_from_op);
if (type_from_callee.hasValue()) return type_from_callee;
} else if (llvm::isa<TF::IdentityOp>(use.getOwner()) ||
llvm::isa<TF::IdentityNOp>(use.getOwner())) {
} else if (llvm::isa<TF::IdentityOp, TF::IdentityNOp>(use.getOwner())) {
auto type_from_alias = GetElementTypeFromAccess(
use.getOwner()->getResult(use.getOperandNumber()), module,
infer_from_op);

View File

@ -49,8 +49,7 @@ LogicalResult ConstantFoldFallbackHook(
}
// Do not execute function calls.
if (llvm::isa<TF::WhileOp>(inst) || llvm::isa<TF::IfOp>(inst) ||
llvm::isa<CallOpInterface>(inst)) {
if (llvm::isa<TF::WhileOp, TF::IfOp, CallOpInterface>(inst)) {
return failure();
}

View File

@ -53,7 +53,7 @@ struct FusedKernelMatcherPass
};
// Returns true if `op` is one of the activation ops recognized by the fused
// kernel matcher: Elu, Relu, or Relu6.
bool IsActivationFunction(Operation *op) {
  // Variadic isa<> checks all listed op classes in a single call.
  return isa<EluOp, ReluOp, Relu6Op>(op);
}
// Finds and returns an activation op that uses the result of `op`. If there are

View File

@ -96,7 +96,7 @@ class ResourceAnalyzer {
}
func.walk([&](Operation* op) {
if (isa<TF::ReadVariableOp>(op) || isa<ReturnOp>(op)) {
if (isa<TF::ReadVariableOp, ReturnOp>(op)) {
return;
}
if (auto assign_variable = dyn_cast<TF::AssignVariableOp>(op)) {

View File

@ -97,8 +97,7 @@ llvm::SmallSet<llvm::StringRef, 1> GetCompositeResourceUserNames(
// the error message are ordered.
llvm::SmallSet<llvm::StringRef, 1> composite_users;
for (Operation* user : resource.getUsers())
if (!llvm::isa<TF::ReadVariableOp>(user) &&
!llvm::isa<TF::AssignVariableOp>(user))
if (!llvm::isa<TF::ReadVariableOp, TF::AssignVariableOp>(user))
composite_users.insert(user->getName().getStringRef());
return composite_users;

View File

@ -53,8 +53,8 @@ struct ReplicateToIslandPass
// Returns whether op requires `_xla_replica_id` attribute.
bool RequiresReplicaIDAttribute(Operation* op) {
  // Only the TPU-embedding enqueue ops need the replica id; variadic isa<>
  // checks both op classes in a single call.
  return llvm::isa<TF::EnqueueTPUEmbeddingSparseTensorBatchOp,
                   TF::EnqueueTPUEmbeddingRaggedTensorBatchOp>(op);
}
// Adds integer attribute that represents replica id for replicated ops that

View File

@ -140,7 +140,7 @@ struct ResourceOpLiftingPass
// such nodes to carry information.
void RemoveIdentity(Block* block) {
for (auto& op : llvm::make_early_inc_range(*block)) {
if (isa<TF::IdentityOp>(&op) || isa<TF::IdentityNOp>(&op)) {
if (isa<TF::IdentityOp, TF::IdentityNOp>(&op)) {
op.replaceAllUsesWith(op.getOperands());
op.erase();
}

View File

@ -114,14 +114,12 @@ Optional<SmallVector<Type, 4>> InferShapeForFunctionReturnType(FuncOp func) {
// Returns if the shape inference pass supports an op outside the TF dialect.
bool IsSupportedNonTFOp(Operation* op) {
  // The supported non-TF ops are the std/tf_device return ops plus the
  // tf_executor dialect ops; variadic isa<> checks them all in one call.
  return isa<ReturnOp, tf_device::ReturnOp, tf_executor::EnterOp,
             tf_executor::ExitOp, tf_executor::FetchOp, tf_executor::GraphOp,
             tf_executor::IslandOp, tf_executor::LoopCondOp,
             tf_executor::MergeOp, tf_executor::NextIterationSinkOp,
             tf_executor::SwitchNOp, tf_executor::SwitchOp,
             tf_executor::YieldOp>(op);
}
// Returns whether a cast back would need to be inserted, e.g., whether the

View File

@ -440,7 +440,7 @@ llvm::SmallDenseMap<int64_t, llvm::SmallVector<string, 4>> AccessedGradients(
};
for (FuncOp func : funcs) {
for (auto& op : func.front().getOperations()) {
if (llvm::isa<TF::IdentityOp>(&op) || llvm::isa<TF::IdentityNOp>(&op)) {
if (llvm::isa<TF::IdentityOp, TF::IdentityNOp>(&op)) {
op.replaceAllUsesWith(op.getOperands());
continue;
}

View File

@ -640,7 +640,7 @@ LogicalResult DecomposeTensorListOpsInternal(
decomposed_partitioned_call_callees) {
for (auto& op : llvm::make_early_inc_range(block->getOperations())) {
// TODO(yuanzx): Add a pass to remove identities in device computation.
if (llvm::isa<TF::IdentityOp>(&op) || llvm::isa<TF::IdentityNOp>(&op)) {
if (llvm::isa<TF::IdentityOp, TF::IdentityNOp>(&op)) {
op.replaceAllUsesWith(op.getOperands());
op.erase();
} else if (auto list = llvm::dyn_cast<TF::EmptyTensorListOp>(&op)) {

View File

@ -52,7 +52,7 @@ Operation* GetOpOfValue(Value value) {
// TODO(b/158596585): Replace this with a cost model analysis.
// Returns true if `op` is a unary op cheap enough to treat as trivial for
// outside-compilation purposes: tf.Cast or tf.Identity.
bool IsTrivialUnaryOperation(Operation* op) {
  // Variadic isa<> checks both op classes in a single call.
  return llvm::isa<TF::CastOp, TF::IdentityOp>(op);
}
// Adds outside compilation attributes to unary ops such as Identity/Cast ops

View File

@ -67,7 +67,7 @@ void GetAdjacentXlaShardingOp(Operation* op,
return;
}
if (llvm::isa<TF::IdentityOp>(op) || llvm::isa<TF::CastOp>(op)) {
if (llvm::isa<TF::IdentityOp, TF::CastOp>(op)) {
for (auto user : op->getUsers())
GetAdjacentXlaShardingOp(user, sharding_op);
}

View File

@ -127,7 +127,7 @@ Value SkipIdentity(Value v, bool allow_other_use,
while (auto result = v.dyn_cast<OpResult>()) {
if (!(allow_other_use || v.hasOneUse())) break;
auto op = result.getDefiningOp();
if (!llvm::isa<TF::IdentityOp>(op) && !llvm::isa<TF::IdentityNOp>(op)) {
if (!llvm::isa<TF::IdentityOp, TF::IdentityNOp>(op)) {
break;
}
v = op->getOperand(result.getResultNumber());

View File

@ -306,9 +306,8 @@ void BreakUpIslands::BreakUpIsland(
llvm::dyn_cast<tf_executor::IslandOp>(owner->getParentOp())) {
(*new_control_inputs)[other_island_op].push_back(sink_island_control);
} else if (owner->getDialect() == island_op.getDialect() &&
!llvm::isa<tf_executor::GraphOp>(owner) &&
!llvm::isa<tf_executor::YieldOp>(owner) &&
!llvm::isa<tf_executor::NextIterationSourceOp>(owner)) {
!llvm::isa<tf_executor::GraphOp, tf_executor::YieldOp,
tf_executor::NextIterationSourceOp>(owner)) {
(*new_control_inputs)[owner].push_back(sink_island_control);
} else {
owner->emitOpError("adding control dependency not supported");

View File

@ -1060,7 +1060,7 @@ LogicalResult ConvertToHloModule::Lower(
return success();
}
if (isa<xla_hlo::ReturnOp>(inst) || isa<mlir::ReturnOp>(inst)) {
if (isa<xla_hlo::ReturnOp, mlir::ReturnOp>(inst)) {
// Construct the return value for the function. If there are multiple
// values returned, then create a tuple, else return value directly.
xla::XlaOp return_value;

View File

@ -193,8 +193,7 @@ mlir::Operation* HoistAndFix(llvm::iplist<mlir::Operation>::iterator begin_op,
const bool any_op_is_loop_variant = [&] {
for (mlir::Operation& op : llvm::make_range(begin_op, end_op)) {
if (mlir::isa<mlir::AffineForOp>(op) ||
mlir::isa<mlir::AffineStoreOp>(op)) {
if (mlir::isa<mlir::AffineForOp, mlir::AffineStoreOp>(op)) {
return true;
}
}

View File

@ -174,8 +174,7 @@ struct DeadTempBufferRemoval
for (auto result : op->getResults()) {
if (!llvm::all_of(result.getUsers(), [&](mlir::Operation* op) {
// Store and Dealloc is OK.
if (llvm::isa<mlir::StoreOp>(op) ||
llvm::isa<mlir::DeallocOp>(op)) {
if (llvm::isa<mlir::StoreOp, mlir::DeallocOp>(op)) {
return true;
}
// Load without uses is also ok.
@ -225,8 +224,8 @@ struct MoveScalarComputationsIntoGpuLaunch
: mlir::PassWrapper<MoveScalarComputationsIntoGpuLaunch,
mlir::FunctionPass> {
static bool isInliningBeneficiary(mlir::Operation* op) {
return llvm::isa<mlir::ConstantOp>(op) || llvm::isa<mlir::DimOp>(op) ||
llvm::isa<mlir::SelectOp>(op) || llvm::isa<mlir::CmpIOp>(op);
return llvm::isa<mlir::ConstantOp, mlir::DimOp, mlir::SelectOp,
mlir::CmpIOp>(op);
}
static bool extractBeneficiaryOps(